Skip to content

Commit 226962d

Browse files
committed
doc: update example
1 parent 0465162 commit 226962d

1 file changed

Lines changed: 38 additions & 19 deletions

File tree

README.rst

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,49 @@ Example
4949

5050
.. code:: python
5151
52-
from torchensemble import VotingClassifier # Voting is a classic ensemble strategy
52+
from torchensemble import VotingClassifier # voting is a classic ensemble strategy
5353
5454
# Load data
5555
train_loader = DataLoader(...)
5656
test_loader = DataLoader(...)
5757
58-
# Define the ensemble
59-
model = VotingClassifier(estimator=base_estimator, # your deep learning model
60-
n_estimators=10) # the number of base estimators
61-
62-
# Set the optimizer
63-
model.set_optimizer("Adam", # parameter optimizer
64-
lr=learning_rate, # learning rate of the optimizer
65-
weight_decay=weight_decay) # weight decay of the optimizer
66-
67-
# Set the scheduler
68-
model.set_scheduler("CosineAnnealingLR", T_max=epochs) # (optional) learning rate scheduler
69-
70-
# Train
71-
model.fit(train_loader,
72-
epochs=epochs) # the number of training epochs
73-
74-
# Evaluate
75-
acc = model.predict(test_loader) # testing accuracy
58+
'''
59+
[Step-1] Define the ensemble
60+
'''
61+
model = VotingClassifier(
62+
estimator=base_estimator, # here is your deep learning model
63+
n_estimators=10, # number of base estimators
64+
)
65+
66+
'''
67+
[Step-2] Set the parameter optimizer
68+
'''
69+
model.set_optimizer(
70+
"Adam", # type of parameter optimizer
71+
lr=learning_rate, # learning rate of parameter optimizer
72+
weight_decay=weight_decay, # weight decay of parameter optimizer
73+
)
74+
75+
'''
76+
[Step-3] Set the learning rate scheduler
77+
'''
78+
model.set_scheduler(
79+
"CosineAnnealingLR", # type of learning rate scheduler
80+
T_max=epochs, # additional arguments on the scheduler
81+
)
82+
83+
'''
84+
[Step-4] Train the ensemble
85+
'''
86+
model.fit(
87+
train_loader,
88+
epochs=epochs, # number of training epochs
89+
)
90+
91+
'''
92+
[Step-5] Evaluate the ensemble
93+
'''
94+
acc = model.predict(test_loader) # testing accuracy
7695
7796
Supported Ensemble
7897
------------------

0 commit comments

Comments
 (0)