Skip to content

Commit c4f9293

Browse files
zzzzwjzzzzwjxuyxu
authored
feat(logging): add tensorboard logging (#61)
* add logger by tensorboard * Fix small issue of Tensorboard Logger * reformat the code with flake8 & black * Add tensorboard to requirements.txt * Reformat code by black * Refactor the tensorboard logging module * fix conflicts * Add unit test script of tensorboard-logging * Reformat the code * Update CHANGELOG.rst * update doc * pin tensorboard version * revert default value Co-authored-by: zzzzwj <zwj@smail.nju.edu.cn> Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
1 parent f24e689 commit c4f9293

19 files changed

Lines changed: 586 additions & 61 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Changelog
44
Ver 0.1.*
55
---------
66

7+
* |Feature| |API| Support TensorBoard logging in :meth:`set_logger` | `@zzzzwj <https://github.com/zzzzwj>`__
78
* |Enhancement| |API| Add ``use_reduction_sum`` parameter for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
89
* |Feature| |API| Improve the functionality of :meth:`evaluate` and :meth:`predict` | `@xuyxu <https://github.com/xuyxu>`__
910
* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__

docs/quick_start.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ Ensemble-PyTorch uses a global logger to track and print the intermediate inform
5757
5858
logger = set_logger("classification_mnist_mlp")
5959
60-
Using the logger, all intermediate information will be printed on the command line and saved to the specified text file: classification_mnist_mlp.
60+
With this logger, all intermediate information will be printed on the command line and saved to the specified text file: ``classification_mnist_mlp``.
61+
62+
In addition, when passing ``use_tb_logger=True`` into the method :meth:`set_logger`, you can use tensorboard to have a better visualization result on training and evaluating the ensemble.
63+
64+
.. code-block:: bash
65+
66+
tensorboard --logdir=logs/
6167
6268
Choose the Ensemble
6369
-------------------

examples/classification_cifar10_cnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def forward(self, x):
9696
shuffle=True,
9797
)
9898

99-
logger = set_logger("classification_cifar10_cnn")
99+
logger = set_logger("classification_cifar10_cnn", use_tb_logger=True)
100100

101101
# FusionClassifier
102102
model = FusionClassifier(

examples/fast_geometric_ensemble_cifar10_resnet18.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def forward(self, x):
136136
)
137137

138138
# Set the Logger
139-
logger = set_logger("FastGeometricClassifier_cifar10_resnet")
139+
logger = set_logger(
140+
"FastGeometricClassifier_cifar10_resnet", use_tb_logger=True
141+
)
140142

141143
# Choose the Ensemble Method
142144
model = FastGeometricClassifier(
@@ -155,7 +157,11 @@ def forward(self, x):
155157
model.set_scheduler("CosineAnnealingLR", T_max=epochs)
156158

157159
# Train
158-
estimator = model.fit(train_loader, epochs=epochs, test_loader=test_loader)
160+
estimator = model.fit(
161+
train_loader,
162+
epochs=epochs,
163+
test_loader=test_loader,
164+
)
159165

160166
# Ensemble
161167
model.ensemble(

examples/regression_YearPredictionMSD_mlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def forward(self, x):
9898
train_loader, test_loader = load_data(batch_size)
9999
print("Finish loading data...\n")
100100

101-
logger = set_logger("regression_YearPredictionMSD_mlp")
101+
logger = set_logger("regression_YearPredictionMSD_mlp", use_tb_logger=True)
102102

103103
# FusionRegressor
104104
model = FusionRegressor(

examples/snapshot_ensemble_cifar10_resnet18.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def forward(self, x):
136136
)
137137

138138
# Set the Logger
139-
logger = set_logger("snapshot_ensemble_cifar10_resnet18")
139+
logger = set_logger(
140+
"snapshot_ensemble_cifar10_resnet18", use_tb_logger=True
141+
)
140142

141143
# Choose the Ensemble Method
142144
model = SnapshotEnsembleClassifier(
@@ -152,4 +154,8 @@ def forward(self, x):
152154
)
153155

154156
# Train and Evaluate
155-
model.fit(train_loader, epochs=epochs, test_loader=test_loader)
157+
model.fit(
158+
train_loader,
159+
epochs=epochs,
160+
test_loader=test_loader,
161+
)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch>=1.4.0
22
torchvision>=0.2.2
3-
scikit-learn>=0.23.0
3+
scikit-learn>=0.23.0
4+
tensorboard==2.*

setup.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
here = path.abspath(path.dirname(__file__))
99

1010
# Get the long description from README.rst
11-
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
11+
with open(path.join(here, "README.rst"), encoding="utf-8") as f:
1212
long_description = f.read()
1313

1414
# get the dependencies and installs
15-
with open(path.join(here, 'requirements.txt'), encoding='utf-8') as f:
16-
all_reqs = f.read().split('\n')
15+
with open(path.join(here, "requirements.txt"), encoding="utf-8") as f:
16+
all_reqs = f.read().split("\n")
1717

18-
install_requires = [x.strip() for x in all_reqs if 'git+' not in x]
18+
install_requires = [x.strip() for x in all_reqs if "git+" not in x]
1919

2020
cmdclass = {}
2121

@@ -28,12 +28,12 @@ def run(self):
2828
Clean.run(self)
2929
# Remove c files if we are not within a sdist package
3030
cwd = os.path.abspath(os.path.dirname(__file__))
31-
remove_c_files = not os.path.exists(os.path.join(cwd, 'PKG-INFO'))
31+
remove_c_files = not os.path.exists(os.path.join(cwd, "PKG-INFO"))
3232
if remove_c_files:
33-
print('Will remove generated .c files')
34-
if os.path.exists('build'):
35-
shutil.rmtree('build')
36-
for dirpath, dirnames, filenames in os.walk('sklearn'):
33+
print("Will remove generated .c files")
34+
if os.path.exists("build"):
35+
shutil.rmtree("build")
36+
for dirpath, dirnames, filenames in os.walk("sklearn"):
3737
for filename in filenames:
3838
if any(
3939
filename.endswith(suffix)
@@ -42,50 +42,50 @@ def run(self):
4242
os.unlink(os.path.join(dirpath, filename))
4343
continue
4444
extension = os.path.splitext(filename)[1]
45-
if remove_c_files and extension in ['.c', '.cpp']:
46-
pyx_file = str.replace(filename, extension, '.pyx')
45+
if remove_c_files and extension in [".c", ".cpp"]:
46+
pyx_file = str.replace(filename, extension, ".pyx")
4747
if os.path.exists(os.path.join(dirpath, pyx_file)):
4848
os.unlink(os.path.join(dirpath, filename))
4949
for dirname in dirnames:
50-
if dirname == '__pycache__':
50+
if dirname == "__pycache__":
5151
shutil.rmtree(os.path.join(dirpath, dirname))
5252

5353

54-
cmdclass.update({'clean': CleanCommand})
54+
cmdclass.update({"clean": CleanCommand})
5555

5656

5757
setup(
58-
name='torchensemble',
59-
maintainer='Yi-Xuan Xu',
60-
maintainer_email='xuyx@lamda.nju.edu.cn',
58+
name="torchensemble",
59+
maintainer="Yi-Xuan Xu",
60+
maintainer_email="xuyx@lamda.nju.edu.cn",
6161
description=(
62-
'Implementations of scikit-learn like ensemble methods in Pytorch'
62+
"Implementations of scikit-learn like ensemble methods in Pytorch"
6363
),
64-
license='BSD 3-Clause',
65-
url='https://github.com/xuyxu/Ensemble-Pytorch',
64+
license="BSD 3-Clause",
65+
url="https://github.com/xuyxu/Ensemble-Pytorch",
6666
project_urls={
67-
'Bug Tracker': 'https://github.com/xuyxu/Ensemble-Pytorch/issues',
68-
'Documentation': 'https://ensemble-pytorch.readthedocs.io',
69-
'Source Code': 'https://github.com/xuyxu/Ensemble-Pytorch',
67+
"Bug Tracker": "https://github.com/xuyxu/Ensemble-Pytorch/issues",
68+
"Documentation": "https://ensemble-pytorch.readthedocs.io",
69+
"Source Code": "https://github.com/xuyxu/Ensemble-Pytorch",
7070
},
71-
version='0.1.2',
71+
version="0.1.2",
7272
long_description=long_description,
7373
classifiers=[
74-
'Intended Audience :: Science/Research',
75-
'Intended Audience :: Developers',
76-
'Topic :: Software Development',
77-
'Topic :: Scientific/Engineering',
78-
'Development Status :: 4 - Beta',
79-
'Operating System :: Microsoft :: Windows',
80-
'Operating System :: POSIX',
81-
'Operating System :: Unix',
82-
'Operating System :: MacOS',
83-
'Programming Language :: Python :: 3',
84-
'Programming Language :: Python :: 3.6',
85-
'Programming Language :: Python :: 3.7',
86-
'Programming Language :: Python :: 3.8',
74+
"Intended Audience :: Science/Research",
75+
"Intended Audience :: Developers",
76+
"Topic :: Software Development",
77+
"Topic :: Scientific/Engineering",
78+
"Development Status :: 4 - Beta",
79+
"Operating System :: Microsoft :: Windows",
80+
"Operating System :: POSIX",
81+
"Operating System :: Unix",
82+
"Operating System :: MacOS",
83+
"Programming Language :: Python :: 3",
84+
"Programming Language :: Python :: 3.6",
85+
"Programming Language :: Python :: 3.7",
86+
"Programming Language :: Python :: 3.8",
8787
],
88-
keywords=['PyTorch', 'Ensemble Learning'],
88+
keywords=["PyTorch", "Ensemble Learning"],
8989
packages=find_packages(),
9090
cmdclass=cmdclass,
9191
python_requires=">=3.6",

torchensemble/_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.nn as nn
88

99
from . import _constants as const
10+
from .utils.logging import get_tb_logger
1011

1112

1213
def torchensemble_model_doc(header="", item="model"):
@@ -77,6 +78,7 @@ def __init__(
7778
self.device = torch.device("cuda" if cuda else "cpu")
7879
self.n_jobs = n_jobs
7980
self.logger = logging.getLogger()
81+
self.tb_logger = get_tb_logger()
8082

8183
self.estimators_ = nn.ModuleList()
8284
self.use_scheduler_ = False

torchensemble/adversarial_training.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,12 @@ def _forward(estimators, data):
368368
" % | Historical Best: {:.3f} %"
369369
)
370370
self.logger.info(msg.format(epoch, acc, best_acc))
371+
if self.tb_logger:
372+
self.tb_logger.add_scalar(
373+
"adversarial_training/Validation_Acc",
374+
acc,
375+
epoch,
376+
)
371377

372378
# Update the scheduler
373379
with warnings.catch_warnings():
@@ -541,6 +547,12 @@ def _forward(estimators, data):
541547
" {:.5f} | Historical Best: {:.5f}"
542548
)
543549
self.logger.info(msg.format(epoch, mse, best_mse))
550+
if self.tb_logger:
551+
self.tb_logger.add_scalar(
552+
"adversirial_training/Validation_MSE",
553+
mse,
554+
epoch,
555+
)
544556

545557
# Update the scheduler
546558
with warnings.catch_warnings():

0 commit comments

Comments
 (0)