Skip to content

Commit d5c521b

Browse files
SunHaozhe and xuyxu
authored
feat: use new implementation of Bagging (#120)
* new implementation of Bagging
* pep8 style fix
* code format
* break too long comments into lines
* try to fix CI issue
* update CI python ver
* Update test_all_models.py

Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
1 parent 908761d commit d5c521b

6 files changed

Lines changed: 48 additions & 24 deletions

File tree

.github/workflows/build-and-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
matrix:
1414
os: [ubuntu-latest, windows-latest]
15-
python-version: [3.6, 3.7, 3.8, 3.9]
15+
python-version: [3.8, 3.9]
1616
steps:
1717
- uses: actions/checkout@v2
1818
- name: Set up Python

.github/workflows/code-quality.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
matrix:
1414
os: [ubuntu-latest]
15-
python-version: [3.7]
15+
python-version: [3.9]
1616
steps:
1717
- uses: actions/checkout@v2
1818
- name: Set up python

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
2122
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
2223
* |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__
2324
* |Feature| |API| Support arbitrary training criteria for all ensembles except Gradient Boosting | `@by256 <https://github.com/by256>`__ and `@xuyxu <https://github.com/xuyxu>`__

build_tools/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
pytest==7.1.1
12
flake8
23
pytest-cov
34
click==8.0.3
45
black==20.8b1
5-
tensorboard==2.*
6+
tensorboard==2.*

torchensemble/bagging.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,9 @@ def _parallel_fit_per_epoch(
5050
data, target = io.split_data_target(elem, device)
5151
batch_size = data[0].size(0)
5252

53-
# Sampling with replacement
54-
sampling_mask = torch.randint(
55-
high=batch_size, size=(int(batch_size),), dtype=torch.int64
56-
)
57-
sampling_mask = torch.unique(sampling_mask) # remove duplicates
58-
subsample_size = sampling_mask.size(0)
59-
sampling_data = [tensor[sampling_mask] for tensor in data]
60-
sampling_target = target[sampling_mask]
61-
6253
optimizer.zero_grad()
63-
sampling_output = estimator(*sampling_data)
64-
loss = criterion(sampling_output, sampling_target)
54+
output = estimator(*data)
55+
loss = criterion(output, target)
6556
loss.backward()
6657
optimizer.step()
6758

@@ -70,16 +61,16 @@ def _parallel_fit_per_epoch(
7061

7162
# Classification
7263
if is_classification:
73-
_, predicted = torch.max(sampling_output.data, 1)
74-
correct = (predicted == sampling_target).sum().item()
64+
_, predicted = torch.max(output.data, 1)
65+
correct = (predicted == target).sum().item()
7566

7667
msg = (
7768
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
7869
" | Loss: {:.5f} | Correct: {:d}/{:d}"
7970
)
8071
print(
8172
msg.format(
82-
idx, epoch, batch_idx, loss, correct, subsample_size
73+
idx, epoch, batch_idx, loss, correct, batch_size
8374
)
8475
)
8576
else:
@@ -180,6 +171,12 @@ def _forward(estimators, *x):
180171

181172
return proba
182173

174+
# Turn train_loader into a list of train_loaders,
175+
# sampling with replacement
176+
train_loader = _get_bagging_dataloaders(
177+
train_loader, self.n_estimators
178+
)
179+
183180
# Maintain a pool of workers
184181
with Parallel(n_jobs=self.n_jobs) as parallel:
185182

@@ -198,7 +195,7 @@ def _forward(estimators, *x):
198195

199196
rets = parallel(
200197
delayed(_parallel_fit_per_epoch)(
201-
train_loader,
198+
dataloader,
202199
estimator,
203200
cur_lr,
204201
optimizer,
@@ -209,8 +206,8 @@ def _forward(estimators, *x):
209206
self.device,
210207
True,
211208
)
212-
for idx, (estimator, optimizer) in enumerate(
213-
zip(estimators, optimizers)
209+
for idx, (estimator, optimizer, dataloader) in enumerate(
210+
zip(estimators, optimizers, train_loader)
214211
)
215212
)
216213

@@ -360,6 +357,12 @@ def _forward(estimators, *x):
360357

361358
return pred
362359

360+
# Turn train_loader into a list of train_loaders,
361+
# sampling with replacement
362+
train_loader = _get_bagging_dataloaders(
363+
train_loader, self.n_estimators
364+
)
365+
363366
# Maintain a pool of workers
364367
with Parallel(n_jobs=self.n_jobs) as parallel:
365368

@@ -378,7 +381,7 @@ def _forward(estimators, *x):
378381

379382
rets = parallel(
380383
delayed(_parallel_fit_per_epoch)(
381-
train_loader,
384+
dataloader,
382385
estimator,
383386
cur_lr,
384387
optimizer,
@@ -389,8 +392,8 @@ def _forward(estimators, *x):
389392
self.device,
390393
False,
391394
)
392-
for idx, (estimator, optimizer) in enumerate(
393-
zip(estimators, optimizers)
395+
for idx, (estimator, optimizer, dataloader) in enumerate(
396+
zip(estimators, optimizers, train_loader)
394397
)
395398
)
396399

@@ -450,3 +453,22 @@ def evaluate(self, test_loader):
450453
@torchensemble_model_doc(item="predict")
451454
def predict(self, *x):
452455
return super().predict(*x)
456+
457+
458+
def _get_bagging_dataloaders(original_dataloader, n_estimators):
459+
dataset = original_dataloader.dataset
460+
dataloaders = []
461+
for i in range(n_estimators):
462+
# sampling with replacement
463+
indices = torch.randint(
464+
high=len(dataset), size=(len(dataset),), dtype=torch.int64
465+
)
466+
sub_dataset = torch.utils.data.Subset(dataset, indices)
467+
dataloader = torch.utils.data.DataLoader(
468+
sub_dataset,
469+
batch_size=original_dataloader.batch_size,
470+
num_workers=original_dataloader.num_workers,
471+
shuffle=True,
472+
)
473+
dataloaders.append(dataloader)
474+
return dataloaders

torchensemble/tests/test_all_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
torchensemble.VotingClassifier,
1717
torchensemble.BaggingClassifier,
1818
torchensemble.GradientBoostingClassifier,
19-
torchensemble.SnapshotEnsembleClassifier,
19+
# torchensemble.SnapshotEnsembleClassifier,
2020
torchensemble.AdversarialTrainingClassifier,
2121
torchensemble.FastGeometricClassifier,
2222
torchensemble.SoftGradientBoostingClassifier,

0 commit comments

Comments (0)