How do I know whether a sklearn scaler is already fitted or not?
Question:
For example, ss is an sklearn.preprocessing.StandardScaler object. If ss is fitted already, I want to use it to transform my data. If ss is not fitted yet, I want to use my data to fit it and then transform my data. Is there a way to know whether ss is already fitted or not?
Answers:
For your example, one way to determine whether the object is a fitted scaler is to check whether it has the attribute n_features_in_, which fit() sets on the estimator.
from sklearn.preprocessing import StandardScaler

data = [[0, 0], [0, 0], [1, 1], [1, 1]]
scaler = StandardScaler()                  # never fitted
scaler_fit = StandardScaler().fit(data)    # fitted

def is_fit_called(obj):
    # fit() sets n_features_in_, so its presence means fit() was called
    return hasattr(obj, "n_features_in_")

print(is_fit_called(scaler))      # False
print(is_fit_called(scaler_fit))  # True
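Applied to the fit-or-transform pattern from the question, the hasattr check can be wrapped in a small helper. This is a sketch; fit_or_transform is a hypothetical name, not part of scikit-learn.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def is_fit_called(obj):
    # fit() sets n_features_in_ on scikit-learn estimators
    return hasattr(obj, "n_features_in_")

def fit_or_transform(scaler, X):
    # Hypothetical helper: reuse an existing fit, otherwise fit on X first
    if is_fit_called(scaler):
        return scaler.transform(X)
    return scaler.fit_transform(X)

ss = StandardScaler()
first = fit_or_transform(ss, np.array([[0.0, 0.0], [2.0, 2.0]]))  # fits: mean_ = [1, 1]
second = fit_or_transform(ss, np.array([[4.0, 4.0]]))            # reuses that fit
print(first)
print(second)
```

The second call does not refit, so the new data is scaled with the mean and standard deviation learned from the first batch.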
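scikit-learn provides the check_is_fitted function (in sklearn.utils.validation) to check whether any estimator is fitted; it raises NotFittedError if not, and it works with StandardScaler:
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted

ss = StandardScaler()
try:
    check_is_fitted(ss)  # raises NotFittedError: ss has not been fitted yet
except NotFittedError as e:
    print("Not fitted:", e)
ss.fit([[1, 2, 3]])
check_is_fitted(ss)  # no error once fit() has been called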
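To answer the original question directly, check_is_fitted can drive the fit-or-transform decision by catching NotFittedError. This is a sketch; fit_or_transform is a hypothetical helper name, not a scikit-learn API.

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_is_fitted

def fit_or_transform(scaler, X):
    """Hypothetical helper: transform with an already-fitted scaler, else fit first."""
    try:
        check_is_fitted(scaler)
    except NotFittedError:
        # Not fitted yet: learn mean/std from X, then transform X
        return scaler.fit_transform(X)
    # Already fitted: reuse the stored statistics
    return scaler.transform(X)

ss = StandardScaler()
X_train = np.array([[1.0, 10.0], [3.0, 30.0]])
X_new = np.array([[5.0, 50.0]])
a = fit_or_transform(ss, X_train)  # fits: mean_ = [2, 20], scale_ = [1, 10]
b = fit_or_transform(ss, X_new)    # transforms with the existing fit
print(a)
print(b)
```

Catching the exception, rather than checking a specific attribute, leaves the definition of "fitted" to scikit-learn itself.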