Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 61 additions & 10 deletions modnet/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
(see https://scikit-learn.org/stable/developers/develop.html#instantiation).
"""

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin
from sklearn.base import TransformerMixin
Expand All @@ -32,7 +33,9 @@
get_features_relevance_redundancy,
get_cross_nmi,
nmi_target,
merge_ranked,
)
from modnet.utils import LOG


class MODNetFeaturizer(TransformerMixin, BaseEstimator):
Expand Down Expand Up @@ -77,7 +80,12 @@ class RR(TransformerMixin, BaseEstimator):
"""

def __init__(
self, n_feat: Union[None, int] = None, rr_parameters: Union[None, Dict] = None
self,
n_feat: Union[None, int] = None,
rr_parameters: Union[None, Dict] = None,
n_jobs: Union[None, int] = None,
cross_nmi_kwargs: Union[None, Dict] = None,
target_nmi_kwargs: Union[None, Dict] = None,
):
"""Constructor for RR transformer.

Expand All @@ -87,10 +95,19 @@ def __init__(
to constant values instead of using the dynamical evaluation. Expects to find keys `"p"` and `"c"`,
containing either a callable that takes `n` as an argument and returns the desired `p` or `c`,
or another dictionary containing the key `"value"` that stores a constant value of `p` or `c`.
n_jobs: Maximum number of processes to use when calculating the cross NMI.
cross_nmi_kwargs: Keyword arguments passed down to `modnet.preprocessing.get_cross_nmi`.
target_nmi_kwargs: Keyword arguments passed down to `modnet.preprocessing.nmi_target`.
"""
self.n_feat = n_feat
self.rr_parameters = rr_parameters
self.optimal_descriptors = []
self.optimal_features_by_target = {}
self.n_jobs = n_jobs
self.cross_nmi_kwargs = cross_nmi_kwargs if cross_nmi_kwargs is not None else {}
self.target_nmi_kwargs = (
target_nmi_kwargs if target_nmi_kwargs is not None else {}
)

def fit(self, X, y, nmi_feats_target=None, cross_nmi_feats=None):
"""Ranking of the features. This is based on relevance and redundancy provided as NMI dataframes.
Expand All @@ -108,18 +125,52 @@ def fit(self, X, y, nmi_feats_target=None, cross_nmi_feats=None):
Fitted RR transformer
"""

ranked_lists = []

if cross_nmi_feats is None:
cross_nmi_feats = get_cross_nmi(X)
cross_nmi_feats = get_cross_nmi(
X, n_jobs=self.n_jobs, **self.cross_nmi_kwargs
)

if nmi_feats_target is None:
nmi_feats_target = nmi_target(X, y)
for name in list(y):
LOG.info(f"Starting NMI computations for target {name}")
X_temp = X.copy()
y_temp = y[[name]]

nmi_feats_target = nmi_target(X_temp, y_temp, **self.target_nmi_kwargs)

missing = [
x for x in nmi_feats_target.index if x not in cross_nmi_feats.index
]
nmi_feats_target = nmi_feats_target.drop(missing, axis=0)
nmi_feats_target = nmi_feats_target.astype(np.float64)

rr_results = get_features_relevance_redundancy(
nmi_feats_target,
cross_nmi_feats,
n_feat=self.n_feat,
rr_parameters=self.rr_parameters,
)

self.optimal_features_by_target[name] = [
x["feature"] for x in rr_results
]

ranked_lists.append(self.optimal_features_by_target[name])

if ranked_lists:
self.optimal_descriptors = merge_ranked(ranked_lists)
else:
rr_results = get_features_relevance_redundancy(
nmi_feats_target,
cross_nmi_feats,
n_feat=self.n_feat,
rr_parameters=self.rr_parameters,
)
self.optimal_descriptors = [x["feature"] for x in rr_results]

rr_results = get_features_relevance_redundancy(
nmi_feats_target,
cross_nmi_feats,
n_feat=self.n_feat,
rr_parameters=self.rr_parameters,
)
self.optimal_descriptors = [x["feature"] for x in rr_results]
return self

def transform(self, X, y=None):
"""Transform the inputs X based on a fitted RR analysis. The best n_feat features are kept and returned.
Expand Down