Skip to content

Commit 3728174

Browse files
authored
chore: code cleaning up (#71)
* chore: code cleaning up * update * Update adversarial_training.py
1 parent ced4695 commit 3728174

8 files changed

Lines changed: 83 additions & 177 deletions

File tree

torchensemble/_base.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -95,37 +95,12 @@ def __getitem__(self, index):
9595
"""Return the `index`-th base estimator in the ensemble."""
9696
return self.estimators_[index]
9797

98-
def _decide_n_outputs(self, train_loader, is_classification=True):
99-
"""
100-
Decide the number of outputs according to the `train_loader`.
101-
102-
- If `is_classification` is True, the number of outputs equals the
103-
number of distinct classes.
104-
- If `is_classification` is False, the number of outputs equals the
105-
number of target variables (e.g., `1` in univariate regression).
106-
"""
107-
if is_classification:
108-
if hasattr(train_loader.dataset, "classes"):
109-
n_outputs = len(train_loader.dataset.classes)
110-
# Infer `n_outputs` from the dataloader
111-
else:
112-
labels = []
113-
for _, (_, target) in enumerate(train_loader):
114-
labels.append(target)
115-
labels = torch.unique(torch.cat(labels))
116-
n_outputs = labels.size(0)
117-
else:
118-
for _, (_, target) in enumerate(train_loader):
119-
if len(target.size()) == 1:
120-
n_outputs = 1
121-
else:
122-
n_outputs = target.size(1)
123-
break
124-
125-
return n_outputs
98+
@abc.abstractmethod
99+
def _decide_n_outputs(self, train_loader):
100+
"""Decide the number of outputs according to the `train_loader`."""
126101

127102
def _make_estimator(self):
128-
"""Make and configure a copy of the `self.base_estimator_`."""
103+
"""Make and configure a copy of `self.base_estimator_`."""
129104

130105
# Call `deepcopy` to make a base estimator
131106
if not isinstance(self.base_estimator_, type):
@@ -161,17 +136,16 @@ def _validate_parameters(self, epochs, log_interval):
161136
self.logger.error(msg.format(log_interval))
162137
raise ValueError(msg.format(log_interval))
163138

164-
@abc.abstractmethod
165139
def set_optimizer(self, optimizer_name, **kwargs):
166-
"""
167-
Implementation on setting the parameter optimizer.
168-
"""
140+
"""Set the parameter optimizer."""
141+
self.optimizer_name = optimizer_name
142+
self.optimizer_args = kwargs
169143

170-
@abc.abstractmethod
171144
def set_scheduler(self, scheduler_name, **kwargs):
172-
"""
173-
Implementation on setting the learning rate scheduler.
174-
"""
145+
"""Set the learning rate scheduler."""
146+
self.scheduler_name = scheduler_name
147+
self.scheduler_args = kwargs
148+
self.use_scheduler_ = True
175149

176150
@abc.abstractmethod
177151
def forward(self, x):
@@ -227,6 +201,24 @@ class BaseClassifier(BaseModule):
227201
Please use the derived classes instead.
228202
"""
229203

204+
def _decide_n_outputs(self, train_loader):
205+
"""
206+
Decide the number of outputs according to the `train_loader`.
207+
The number of outputs equals the number of distinct classes for
208+
classifiers.
209+
"""
210+
if hasattr(train_loader.dataset, "classes"):
211+
n_outputs = len(train_loader.dataset.classes)
212+
# Infer `n_outputs` from the dataloader
213+
else:
214+
labels = []
215+
for _, (_, target) in enumerate(train_loader):
216+
labels.append(target)
217+
labels = torch.unique(torch.cat(labels))
218+
n_outputs = labels.size(0)
219+
220+
return n_outputs
221+
230222
@torch.no_grad()
231223
def evaluate(self, test_loader, return_loss=False):
232224
"""Docstrings decorated by downstream models."""
@@ -260,6 +252,21 @@ class BaseRegressor(BaseModule):
260252
Please use the derived classes instead.
261253
"""
262254

255+
def _decide_n_outputs(self, train_loader):
256+
"""
257+
Decide the number of outputs according to the `train_loader`.
258+
The number of outputs equals the number of target variables for
259+
regressors (e.g., `1` in univariate regression).
260+
"""
261+
for _, (_, target) in enumerate(train_loader):
262+
if len(target.size()) == 1:
263+
n_outputs = 1
264+
else:
265+
n_outputs = target.size(1)
266+
break
267+
268+
return n_outputs
269+
263270
@torch.no_grad()
264271
def evaluate(self, test_loader):
265272
"""Docstrings decorated by downstream models."""

torchensemble/adversarial_training.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@
2323
from .utils import operator as op
2424

2525

26-
__all__ = [
27-
"_BaseAdversarialTraining",
28-
"AdversarialTrainingClassifier",
29-
"AdversarialTrainingRegressor",
30-
]
26+
__all__ = ["AdversarialTrainingClassifier", "AdversarialTrainingRegressor"]
3127

3228

3329
__fit_doc = """
@@ -218,10 +214,6 @@ def _validate_parameters(self, epochs, epsilon, log_interval):
218214
"model",
219215
)
220216
class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
221-
def __init__(self, **kwargs):
222-
super().__init__(**kwargs)
223-
self.is_classification = True
224-
225217
@torchensemble_model_doc(
226218
"""Implementation on the data forwarding in AdversarialTrainingClassifier.""", # noqa: E501
227219
"classifier_forward",
@@ -240,17 +232,14 @@ def forward(self, x):
240232
"set_optimizer",
241233
)
242234
def set_optimizer(self, optimizer_name, **kwargs):
243-
self.optimizer_name = optimizer_name
244-
self.optimizer_args = kwargs
235+
super().set_optimizer(optimizer_name, **kwargs)
245236

246237
@torchensemble_model_doc(
247238
"""Set the attributes on scheduler for AdversarialTrainingClassifier.""", # noqa: E501
248239
"set_scheduler",
249240
)
250241
def set_scheduler(self, scheduler_name, **kwargs):
251-
self.scheduler_name = scheduler_name
252-
self.scheduler_args = kwargs
253-
self.use_scheduler_ = True
242+
super().set_scheduler(scheduler_name, **kwargs)
254243

255244
@_adversarial_training_model_doc(
256245
"""Implementation on the training stage of AdversarialTrainingClassifier.""", # noqa: E501
@@ -268,7 +257,7 @@ def fit(
268257
):
269258

270259
self._validate_parameters(epochs, epsilon, log_interval)
271-
self.n_outputs = self._decide_n_outputs(train_loader, True)
260+
self.n_outputs = self._decide_n_outputs(train_loader)
272261

273262
# Instantiate a pool of base estimators, optimizers, and schedulers.
274263
estimators = []
@@ -404,10 +393,6 @@ def predict(self, X, return_numpy=True):
404393
"model",
405394
)
406395
class AdversarialTrainingRegressor(_BaseAdversarialTraining, BaseRegressor):
407-
def __init__(self, **kwargs):
408-
super().__init__(**kwargs)
409-
self.is_classification = False
410-
411396
@torchensemble_model_doc(
412397
"""Implementation on the data forwarding in AdversarialTrainingRegressor.""", # noqa: E501
413398
"regressor_forward",
@@ -424,17 +409,14 @@ def forward(self, x):
424409
"set_optimizer",
425410
)
426411
def set_optimizer(self, optimizer_name, **kwargs):
427-
self.optimizer_name = optimizer_name
428-
self.optimizer_args = kwargs
412+
super().set_optimizer(optimizer_name, **kwargs)
429413

430414
@torchensemble_model_doc(
431415
"""Set the attributes on scheduler for AdversarialTrainingRegressor.""", # noqa: E501
432416
"set_scheduler",
433417
)
434418
def set_scheduler(self, scheduler_name, **kwargs):
435-
self.scheduler_name = scheduler_name
436-
self.scheduler_args = kwargs
437-
self.use_scheduler_ = True
419+
super().set_scheduler(scheduler_name, **kwargs)
438420

439421
@_adversarial_training_model_doc(
440422
"""Implementation on the training stage of AdversarialTrainingRegressor.""", # noqa: E501
@@ -452,7 +434,7 @@ def fit(
452434
):
453435

454436
self._validate_parameters(epochs, epsilon, log_interval)
455-
self.n_outputs = self._decide_n_outputs(train_loader, True)
437+
self.n_outputs = self._decide_n_outputs(train_loader)
456438

457439
# Instantiate a pool of base estimators, optimizers, and schedulers.
458440
estimators = []

torchensemble/bagging.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,14 @@ def forward(self, x):
115115
"set_optimizer",
116116
)
117117
def set_optimizer(self, optimizer_name, **kwargs):
118-
self.optimizer_name = optimizer_name
119-
self.optimizer_args = kwargs
118+
super().set_optimizer(optimizer_name, **kwargs)
120119

121120
@torchensemble_model_doc(
122121
"""Set the attributes on scheduler for BaggingClassifier.""",
123122
"set_scheduler",
124123
)
125124
def set_scheduler(self, scheduler_name, **kwargs):
126-
self.scheduler_name = scheduler_name
127-
self.scheduler_args = kwargs
128-
self.use_scheduler_ = True
125+
super().set_scheduler(scheduler_name, **kwargs)
129126

130127
@torchensemble_model_doc(
131128
"""Implementation on the training stage of BaggingClassifier.""", "fit"
@@ -141,7 +138,7 @@ def fit(
141138
):
142139

143140
self._validate_parameters(epochs, log_interval)
144-
self.n_outputs = self._decide_n_outputs(train_loader, True)
141+
self.n_outputs = self._decide_n_outputs(train_loader)
145142

146143
# Instantiate a pool of base estimators, optimizers, and schedulers.
147144
estimators = []
@@ -289,17 +286,14 @@ def forward(self, x):
289286
"set_optimizer",
290287
)
291288
def set_optimizer(self, optimizer_name, **kwargs):
292-
self.optimizer_name = optimizer_name
293-
self.optimizer_args = kwargs
289+
super().set_optimizer(optimizer_name, **kwargs)
294290

295291
@torchensemble_model_doc(
296292
"""Set the attributes on scheduler for BaggingRegressor.""",
297293
"set_scheduler",
298294
)
299295
def set_scheduler(self, scheduler_name, **kwargs):
300-
self.scheduler_name = scheduler_name
301-
self.scheduler_args = kwargs
302-
self.use_scheduler_ = True
296+
super().set_scheduler(scheduler_name, **kwargs)
303297

304298
@torchensemble_model_doc(
305299
"""Implementation on the training stage of BaggingRegressor.""", "fit"
@@ -315,7 +309,7 @@ def fit(
315309
):
316310

317311
self._validate_parameters(epochs, log_interval)
318-
self.n_outputs = self._decide_n_outputs(train_loader, False)
312+
self.n_outputs = self._decide_n_outputs(train_loader)
319313

320314
# Instantiate a pool of base estimators, optimizers, and schedulers.
321315
estimators = []

torchensemble/fast_geometric.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424
from .utils.logging import get_tb_logger
2525

2626

27-
__all__ = [
28-
"_BaseFastGeometric",
29-
"FastGeometricClassifier",
30-
"FastGeometricRegressor",
31-
]
27+
__all__ = ["FastGeometricClassifier", "FastGeometricRegressor"]
3228

3329

3430
__fit_doc = """
@@ -175,10 +171,6 @@ def set_scheduler(self, scheduler_name, **kwargs):
175171
"""Implementation on the FastGeometricClassifier.""", "seq_model"
176172
)
177173
class FastGeometricClassifier(_BaseFastGeometric, BaseClassifier):
178-
def __init__(self, **kwargs):
179-
super().__init__(**kwargs)
180-
self.is_classification = True
181-
182174
@torchensemble_model_doc(
183175
"""Implementation on the data forwarding in FastGeometricClassifier.""", # noqa: E501
184176
"classifier_forward",
@@ -227,9 +219,7 @@ def fit(
227219
save_dir=None,
228220
):
229221
self._validate_parameters(epochs, log_interval)
230-
self.n_outputs = self._decide_n_outputs(
231-
train_loader, self.is_classification
232-
)
222+
self.n_outputs = self._decide_n_outputs(train_loader)
233223

234224
# ====================================================================
235225
# Train the dummy estimator (estimator_)
@@ -426,10 +416,6 @@ def predict(self, X, return_numpy=True):
426416
"""Implementation on the FastGeometricRegressor.""", "seq_model"
427417
)
428418
class FastGeometricRegressor(_BaseFastGeometric, BaseRegressor):
429-
def __init__(self, **kwargs):
430-
super().__init__(**kwargs)
431-
self.is_classification = False
432-
433419
@torchensemble_model_doc(
434420
"""Implementation on the data forwarding in FastGeometricRegressor.""", # noqa: E501
435421
"regressor_forward",
@@ -477,9 +463,7 @@ def fit(
477463
save_dir=None,
478464
):
479465
self._validate_parameters(epochs, log_interval)
480-
self.n_outputs = self._decide_n_outputs(
481-
train_loader, self.is_classification
482-
)
466+
self.n_outputs = self._decide_n_outputs(train_loader)
483467

484468
# ====================================================================
485469
# Train the dummy estimator (estimator_)

torchensemble/fusion.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,14 @@ def forward(self, x):
4949
"set_optimizer",
5050
)
5151
def set_optimizer(self, optimizer_name, **kwargs):
52-
self.optimizer_name = optimizer_name
53-
self.optimizer_args = kwargs
52+
super().set_optimizer(optimizer_name, **kwargs)
5453

5554
@torchensemble_model_doc(
5655
"""Set the attributes on scheduler for FusionClassifier.""",
5756
"set_scheduler",
5857
)
5958
def set_scheduler(self, scheduler_name, **kwargs):
60-
self.scheduler_name = scheduler_name
61-
self.scheduler_args = kwargs
62-
self.use_scheduler_ = True
59+
super().set_scheduler(scheduler_name, **kwargs)
6360

6461
@torchensemble_model_doc(
6562
"""Implementation on the training stage of FusionClassifier.""", "fit"
@@ -78,7 +75,7 @@ def fit(
7875
for _ in range(self.n_estimators):
7976
self.estimators_.append(self._make_estimator())
8077
self._validate_parameters(epochs, log_interval)
81-
self.n_outputs = self._decide_n_outputs(train_loader, True)
78+
self.n_outputs = self._decide_n_outputs(train_loader)
8279
optimizer = set_module.set_optimizer(
8380
self, self.optimizer_name, **self.optimizer_args
8481
)
@@ -192,17 +189,14 @@ def forward(self, x):
192189
"set_optimizer",
193190
)
194191
def set_optimizer(self, optimizer_name, **kwargs):
195-
self.optimizer_name = optimizer_name
196-
self.optimizer_args = kwargs
192+
super().set_optimizer(optimizer_name, **kwargs)
197193

198194
@torchensemble_model_doc(
199195
"""Set the attributes on scheduler for FusionRegressor.""",
200196
"set_scheduler",
201197
)
202198
def set_scheduler(self, scheduler_name, **kwargs):
203-
self.scheduler_name = scheduler_name
204-
self.scheduler_args = kwargs
205-
self.use_scheduler_ = True
199+
super().set_scheduler(scheduler_name, **kwargs)
206200

207201
@torchensemble_model_doc(
208202
"""Implementation on the training stage of FusionRegressor.""", "fit"
@@ -220,7 +214,7 @@ def fit(
220214
for _ in range(self.n_estimators):
221215
self.estimators_.append(self._make_estimator())
222216
self._validate_parameters(epochs, log_interval)
223-
self.n_outputs = self._decide_n_outputs(train_loader, False)
217+
self.n_outputs = self._decide_n_outputs(train_loader)
224218
optimizer = set_module.set_optimizer(
225219
self, self.optimizer_name, **self.optimizer_args
226220
)

0 commit comments

Comments (0)