diff --git a/modnet/sklearn.py b/modnet/sklearn.py index ba59d15..50a929c 100644 --- a/modnet/sklearn.py +++ b/modnet/sklearn.py @@ -24,6 +24,7 @@ (see https://scikit-learn.org/stable/developers/develop.html#instantiation). """ +import numpy as np from sklearn.base import BaseEstimator from sklearn.base import RegressorMixin from sklearn.base import TransformerMixin @@ -32,7 +33,9 @@ get_features_relevance_redundancy, get_cross_nmi, nmi_target, + merge_ranked, ) +from modnet.utils import LOG class MODNetFeaturizer(TransformerMixin, BaseEstimator): @@ -77,7 +80,12 @@ class RR(TransformerMixin, BaseEstimator): """ def __init__( - self, n_feat: Union[None, int] = None, rr_parameters: Union[None, Dict] = None + self, + n_feat: Union[None, int] = None, + rr_parameters: Union[None, Dict] = None, + n_jobs: Union[None, int] = None, + cross_nmi_kwargs: Union[None, Dict] = None, + target_nmi_kwargs: Union[None, Dict] = None, ): """Constructor for RR transformer. @@ -87,10 +95,19 @@ def __init__( to constant values instead of using the dynamical evaluation. Expects to find keys `"p"` and `"c"`, containing either a callable that takes `n` as an argument and returns the desired `p` or `c`, or another dictionary containing the key `"value"` that stores a constant value of `p` or `c`. + n_jobs: max number of processes to use when calculating cross NMI. + cross_nmi_kwargs: Keyword arguments to be passed down to the modnet.preprocessing.get_cross_nmi + target_nmi_kwargs: Keyword arguments to be passed down to the modnet.preprocessing.nmi_target """ self.n_feat = n_feat self.rr_parameters = rr_parameters self.optimal_descriptors = [] + self.optimal_features_by_target = {} + self.n_jobs = n_jobs + self.cross_nmi_kwargs = cross_nmi_kwargs if cross_nmi_kwargs is not None else {} + self.target_nmi_kwargs = ( + target_nmi_kwargs if target_nmi_kwargs is not None else {} + ) def fit(self, X, y, nmi_feats_target=None, cross_nmi_feats=None): """Ranking of the features. This is based on relevance and redundancy provided as NMI dataframes. @@ -108,18 +125,52 @@ def fit(self, X, y, nmi_feats_target=None, cross_nmi_feats=None): Fitted RR transformer """ + ranked_lists = [] + if cross_nmi_feats is None: - cross_nmi_feats = get_cross_nmi(X) + cross_nmi_feats = get_cross_nmi( + X, n_jobs=self.n_jobs, **self.cross_nmi_kwargs + ) + if nmi_feats_target is None: - nmi_feats_target = nmi_target(X, y) + for name in list(y): + LOG.info(f"Starting NMI computations for target {name}") + X_temp = X.copy() + y_temp = y[[name]] + + nmi_feats_target = nmi_target(X_temp, y_temp, **self.target_nmi_kwargs) + + missing = [ + x for x in nmi_feats_target.index if x not in cross_nmi_feats.index + ] + nmi_feats_target = nmi_feats_target.drop(missing, axis=0) + nmi_feats_target = nmi_feats_target.astype(np.float64) + + rr_results = get_features_relevance_redundancy( + nmi_feats_target, + cross_nmi_feats, + n_feat=self.n_feat, + rr_parameters=self.rr_parameters, + ) + + self.optimal_features_by_target[name] = [ + x["feature"] for x in rr_results + ] + + ranked_lists.append(self.optimal_features_by_target[name]) + + if ranked_lists: + self.optimal_descriptors = merge_ranked(ranked_lists) + else: + rr_results = get_features_relevance_redundancy( + nmi_feats_target, + cross_nmi_feats, + n_feat=self.n_feat, + rr_parameters=self.rr_parameters, + ) + self.optimal_descriptors = [x["feature"] for x in rr_results] - rr_results = get_features_relevance_redundancy( - nmi_feats_target, - cross_nmi_feats, - n_feat=self.n_feat, - rr_parameters=self.rr_parameters, - ) - self.optimal_descriptors = [x["feature"] for x in rr_results] + return self def transform(self, X, y=None): """Transform the inputs X based on a fitted RR analysis. The best n_feat features are kept and returned.