When building custom estimators or transformers that integrate with the Scikit-learn ecosystem, robustness and a clear user experience matter. A significant part of this involves rigorously validating the parameters your component accepts. Just as standard Scikit-learn components check their inputs, your custom classes should do the same: it prevents unexpected runtime errors, guides users toward correct usage, and makes your components more reliable and easier to debug.
Parameter validation primarily occurs within the `__init__` method of your custom class, although some checks might be deferred until methods like `fit` or `transform`, when parameter interactions with the data become relevant. The goal is to catch invalid inputs as early as possible.
Clear, early error messages such as `ValueError: 'solver' must be one of ['svd', 'lsqr', 'eigen'], got 'foo'` are far more helpful than a complex traceback originating from matrix decomposition failures. Good validation acts as interactive documentation: once a component is constructed without error, users can focus on calling `fit` or `transform` rather than questioning the inputs.

You should consider several types of checks for your parameters:
- **Type checks:** Is the parameter of an expected type (`int`, `float`, `str`, `list`, `callable`, `bool`, `None`)? Use `isinstance()` for this. Be mindful of acceptable variations, like allowing both integers and floats for a numerical parameter using `isinstance(param, (int, float))`.
- **Value and range checks:** Does the value fall in a sensible range (e.g., `n_neighbors > 0`, `alpha >= 0.0`, a ratio between `(0, 1)`)?
- **Option checks:** Is the value one of a fixed set of allowed options (e.g., `penalty in {'l1', 'l2', 'elasticnet'}`)?
- **Boolean and None checks:** Decide explicitly whether `True`, `False`, or `None` are permitted.
- **Callable checks:** Verify the parameter with `callable()`, and potentially check its signature using the `inspect` module if specific arguments are required.
- **Interdependency checks:** Some parameters only make sense together. If `method='A'`, perhaps `param_for_A` must be provided while `param_for_B` must be `None`. These checks often involve inspecting multiple `self` attributes within `__init__`; see the sketch after this list.
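As a minimal sketch of that last kind of check, consider a hypothetical estimator whose `param_for_A` only applies when `method='A'` (the class and parameter names here are illustrative, not part of scikit-learn):

```python
from sklearn.base import BaseEstimator


class MethodSwitcher(BaseEstimator):
    """Hypothetical estimator with interdependent parameters."""

    def __init__(self, method='A', param_for_A=1.0, param_for_B=None):
        # Cross-parameter checks: each mode has its own required
        # and forbidden settings.
        if method == 'A' and param_for_A is None:
            raise ValueError("param_for_A must be provided when method='A'")
        if method == 'A' and param_for_B is not None:
            raise ValueError("param_for_B must be None when method='A'")
        self.method = method
        self.param_for_A = param_for_A
        self.param_for_B = param_for_B
```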
## Manual Validation in `__init__`

The most straightforward approach is to use conditional logic directly within your `__init__` method.
```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted


class CustomScaler(BaseEstimator, TransformerMixin):
    """
    A custom scaler that scales features by a constant factor.

    Parameters
    ----------
    scale_factor : float, default=1.0
        The factor to multiply features by. Must be positive.
    strategy : {'multiply', 'divide'}, default='multiply'
        Whether to multiply or divide by the scale_factor.
    """

    def __init__(self, scale_factor=1.0, strategy='multiply'):
        # Manual type and value validation
        if not isinstance(scale_factor, (int, float)):
            raise TypeError(
                f"Parameter 'scale_factor' must be numeric, "
                f"got {type(scale_factor).__name__}"
            )
        if scale_factor <= 0:
            raise ValueError(
                f"Parameter 'scale_factor' must be positive, got {scale_factor}"
            )
        allowed_strategies = {'multiply', 'divide'}
        if strategy not in allowed_strategies:
            raise ValueError(
                f"Parameter 'strategy' must be one of {allowed_strategies}, "
                f"got '{strategy}'"
            )
        self.scale_factor = scale_factor
        self.strategy = strategy

    def fit(self, X, y=None):
        # No fitting necessary for this simple transformer.
        # Optional: validate X here using check_array:
        # from sklearn.utils.validation import check_array
        # X = check_array(X)
        # The trailing underscore marks this as a fitted attribute,
        # which is what check_is_fitted looks for.
        self.n_features_in_ = X.shape[1]
        return self

    def transform(self, X):
        check_is_fitted(self)
        # Optional: validate X here as well
        # X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"Input has {X.shape[1]} features, expected {self.n_features_in_}"
            )
        if self.strategy == 'multiply':
            return X * self.scale_factor
        elif self.strategy == 'divide':
            # __init__ already guarantees scale_factor > 0,
            # so division by zero cannot occur here
            return X / self.scale_factor
```
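A quick usage sketch, continuing from the definition above, shows how these manual checks surface to the user at construction time:

```python
try:
    CustomScaler(scale_factor=0)
except ValueError as e:
    print(e)  # Parameter 'scale_factor' must be positive, got 0

X = np.array([[1.0, 2.0], [3.0, 4.0]])
print(CustomScaler(scale_factor=2.0).fit_transform(X))
# [[2. 4.]
#  [6. 8.]]
```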
This is clear for simple cases but can become verbose and repetitive if you have many parameters with complex constraints.
## Declarative Validation with `_parameter_constraints`

Scikit-learn provides a more structured, declarative way to handle parameter validation through a class-level dictionary named `_parameter_constraints`. Estimators that inherit from `BaseEstimator` can call `self._validate_params()` (the same helper Scikit-learn's own estimators use internally, available since version 1.2) at the start of `fit` to check every parameter against these constraints. This approach promotes consistency and reduces boilerplate code in `__init__`.

You define the constraints for each parameter in the `_parameter_constraints` dictionary. The keys are the parameter names, and the values are lists of allowed types or constraint objects (like `Interval` or `StrOptions`). Note that the constraint classes live in the private `sklearn.utils._param_validation` module, so their details may change between releases.
```python
import numbers  # numbers.Real accepts both ints and floats

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted
# Constraint classes from the (private) parameter-validation module
from sklearn.utils._param_validation import Interval, StrOptions


class CustomScalerValidated(BaseEstimator, TransformerMixin):
    """
    A custom scaler that scales features by a constant factor.
    Uses _parameter_constraints for declarative validation.

    Parameters
    ----------
    scale_factor : float, default=1.0
        The factor to multiply features by. Must be positive.
    strategy : {'multiply', 'divide'}, default='multiply'
        Whether to multiply or divide by the scale_factor.
    """

    # Define constraints using the special class attribute
    _parameter_constraints: dict = {
        # Must be a real number strictly greater than 0
        "scale_factor": [Interval(numbers.Real, 0, None, closed="neither")],
        # Must be one of these strings
        "strategy": [StrOptions({"multiply", "divide"})],
    }

    def __init__(self, scale_factor=1.0, strategy='multiply'):
        # No explicit validation code needed here: __init__ only stores
        # the parameters, which also keeps cloning side-effect free.
        self.scale_factor = scale_factor
        self.strategy = strategy

    def fit(self, X, y=None):
        # Check all parameters against _parameter_constraints
        self._validate_params()
        X = check_array(X, ensure_2d=True, dtype=np.float64)  # validate X
        self.n_features_in_ = X.shape[1]
        # Store any fitted attributes here (none needed for this simple scaler)
        # self.mean_ = X.mean(axis=0)  # example if fitting were needed
        return self

    def transform(self, X):
        check_is_fitted(self)  # ensure fit has been called
        X = check_array(X, ensure_2d=True, dtype=np.float64)  # validate X
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"Input has {X.shape[1]} features, but scaler was fitted "
                f"with {self.n_features_in_} features."
            )
        if self.strategy == 'multiply':
            return X * self.scale_factor
        elif self.strategy == 'divide':
            # The constraint guarantees scale_factor > 0, so no division by zero
            return X / self.scale_factor
```
With this pattern, invalid parameters are caught the moment `fit` is called:

```python
X = np.array([[1.0, 2.0], [3.0, 4.0]])

try:
    CustomScalerValidated(scale_factor=-2.0).fit(X)
except ValueError as e:
    print(f"Validation Error: {e}")
# Expected output (exact wording may vary by scikit-learn version):
# Validation Error: The 'scale_factor' parameter of CustomScalerValidated
# must be a float in the range (0, inf). Got -2.0 instead.

try:
    CustomScalerValidated(strategy='add').fit(X)
except ValueError as e:
    print(f"Validation Error: {e}")
# Expected output (exact wording may vary by scikit-learn version):
# Validation Error: The 'strategy' parameter of CustomScalerValidated
# must be a str among {'multiply', 'divide'}. Got 'add' instead.
```
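For completeness, a valid configuration fits and transforms without complaint (a small usage sketch reusing `X` from above):

```python
scaler = CustomScalerValidated(scale_factor=2.0, strategy='divide')
print(scaler.fit_transform(X))
# [[0.5 1. ]
#  [1.5 2. ]]
```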
At `fit` time, `_validate_params` compares the current parameter values against the rules defined in `_parameter_constraints` and raises an informative `InvalidParameterError` (a subclass of both `TypeError` and `ValueError`) if they don't match. This is the recommended approach for modern Scikit-learn compatible components. Available constraints include `Interval` (for numerical ranges), `StrOptions` (for sets of allowed strings), `Options` (generic set membership), plain types (like `int`, `float`, `list`, `np.ndarray`), the built-in `callable`, and `None`. You can also combine constraints in the list (e.g., `[StrOptions({"auto", "manual"}), None]` to allow specific strings or `None`).
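As an illustration of combining these building blocks, a constraints dictionary for a hypothetical estimator (all parameter names here are made up) might look like this:

```python
from numbers import Integral, Real

from sklearn.utils._param_validation import Interval, Options, StrOptions

_parameter_constraints: dict = {
    # A fixed set of strategy names, or None to pick a default automatically
    "strategy": [StrOptions({"auto", "manual"}), None],
    # Any integer >= 1
    "n_iter": [Interval(Integral, 1, None, closed="left")],
    # A real number in the half-open interval (0, 1]
    "ratio": [Interval(Real, 0, 1, closed="right")],
    # Any Python callable, or None
    "score_func": [callable, None],
    # Generic set membership for non-string values
    "degree": [Options(Integral, {1, 2, 3})],
}
```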
Beyond the validation mechanics, a few conventions keep custom components well-behaved in the wider ecosystem:

- **Store parameters as-is:** In `__init__`, always store the parameters passed by the user directly as attributes with the same name (e.g., `self.scale_factor = scale_factor`). Scikit-learn's tools like `get_params` and `set_params` rely on this.
- **Inherit `get_params` and `set_params`:** Ensure your estimator inherits from `BaseEstimator` (or includes mixins that do). This provides default `get_params` and `set_params` methods essential for hyperparameter tuning (like `GridSearchCV`) and cloning; they work by inspecting the `__init__` signature and accessing attributes with corresponding names. A round-trip demonstration follows this list.
- **Validate hyperparameters declaratively:** Prefer `_parameter_constraints` together with `self._validate_params()` over hand-rolled checks.
- **Defer data-dependent checks:** Leave checks involving `X` (like shape compatibility) to the `fit` or `transform` methods, using Scikit-learn's data validation utilities such as `check_array` and `check_X_y` there.
- **Guard against unfitted use:** Call `check_is_fitted` at the beginning of `transform`, `predict`, etc., to ensure `fit` has been called.
- **Document every parameter:** Describe each parameter in the `__init__` docstring, clearly stating its type, purpose, allowed range or options, and default value. This complements programmatic validation.
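To see why the store-as-is rule matters, here is a quick round-trip (reusing `CustomScalerValidated` from above) through the tuning and cloning machinery:

```python
from sklearn.base import clone

scaler = CustomScalerValidated(scale_factor=3.0)
print(scaler.get_params())  # {'scale_factor': 3.0, 'strategy': 'multiply'}

scaler.set_params(strategy='divide')
cloned = clone(scaler)      # a fresh, unfitted copy with identical parameters
print(cloned.get_params())  # {'scale_factor': 3.0, 'strategy': 'divide'}
```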
By implementing thorough parameter validation and management, you create custom Scikit-learn components that are not only functionally correct but also robust, user-friendly, and well integrated into the broader machine learning toolkit. Declaring constraints in `_parameter_constraints` and validating them in `fit` is the preferred way to achieve this efficiently and consistently with Scikit-learn best practices.