Skip to content

Commit b914d3f

Browse files
authored
feat(fge): simplify the training workflow (#62)
* initial commit * update example
1 parent c4f9293 commit b914d3f

7 files changed

Lines changed: 52 additions & 355 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Changelog
44
Ver 0.1.*
55
---------
66

7+
* |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
78
* |Feature| |API| Support TensorBoard logging in :meth:`set_logger` | `@zzzzwj <https://github.com/zzzzwj>`__
89
* |Enhancement| |API| Add ``use_reduction_sum`` parameter for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
910
* |Feature| |API| Improve the functionality of :meth:`evaluate` and :meth:`predict` | `@xuyxu <https://github.com/xuyxu>`__

docs/parameters.rst

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,6 @@ Reference:
147147
T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
148148
Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.
149149

150-
Notice that unlike all ensembles above, using fast geometric ensemble (FGE) is
151-
**a two-staged process**. Concretely, you first need to call :meth:`fit` to
152-
build a dummy base estimator that will be used to generate ensembles. Second,
153-
you need to call :meth:`ensemble` to generate real base estimators in the
154-
ensemble. The pipeline is shown in the following code snippet:
155-
156-
.. code:: python
157-
158-
model = FastGeometricClassifier(**ensemble_related_args)
159-
estimator = model.fit(train_loader, **base_estimator_related_args) # train the base estimator
160-
model.ensemble(estimator, train_loader, **fge_related_args) # generate the ensemble using the base estimator
161-
162-
You can refer to scripts in `examples <https://github.com/xuyxu/Ensemble-Pytorch/tree/master/examples>`__ for
163-
a detailed example.
164-
165150
FastGeometricClassifier
166151
***********************
167152

examples/fast_geometric_ensemble_cifar10_resnet18.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,16 +161,9 @@ def forward(self, x):
161161
train_loader,
162162
epochs=epochs,
163163
test_loader=test_loader,
164-
)
165-
166-
# Ensemble
167-
model.ensemble(
168-
estimator,
169-
train_loader,
170164
cycle=4,
171165
lr_1=5e-2,
172166
lr_2=5e-4,
173-
test_loader=test_loader,
174167
)
175168

176169
# Evaluate

0 commit comments

Comments
 (0)