Commit 9d26a9d

feat: add neural forest ensemble (#110)
* feat: add NeuralForestClassifier
* feat: add NeuralForestClassifier
* update code
* update code
* update code
* Update requirements.txt
* add doc
* Update README.rst
* update code
1 parent f532498 commit 9d26a9d

13 files changed: 607 additions & 4 deletions

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Changelog
 Ver 0.1.*
 ---------
 
+* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
 * |Fix| Relax check on input dataloader | `@xuyxu <https://github.com/xuyxu>`__
 * |Feature| |API| Support arbitrary training criteria for all ensembles except Gradient Boosting | `@by256 <https://github.com/by256>`__ and `@xuyxu <https://github.com/xuyxu>`__
 * |Fix| Fix missing functionality of ``save_model`` for :meth:`fit` of Soft Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__

README.rst

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,8 @@ Supported Ensemble
 +------------------------------+------------+---------------------------+
 | Voting [1]_                  | Parallel   | voting.py                 |
 +------------------------------+------------+---------------------------+
+| Neural Forest                | Parallel   | voting.py                 |
++------------------------------+------------+---------------------------+
 | Bagging [2]_                 | Parallel   | bagging.py                |
 +------------------------------+------------+---------------------------+
 | Gradient Boosting [3]_       | Sequential | gradient_boosting.py      |

docs/parameters.rst

Lines changed: 23 additions & 0 deletions
@@ -92,6 +92,29 @@ GradientBoostingRegressor
 .. autoclass:: torchensemble.gradient_boosting.GradientBoostingRegressor
     :members:
 
+Neural Tree Ensemble
+--------------------
+
+Neural tree ensembles are extensions of voting and gradient boosting that
+use the neural tree as the base estimator. Neural trees are differentiable
+trees that use logistic regression in internal nodes to split samples into
+child nodes with different probabilities. Model details are available in
+`Distilling a neural network into a soft decision tree
+<https://arxiv.org/abs/1711.09784>`_.
+
+
+NeuralForestClassifier
+**********************
+
+.. autoclass:: torchensemble.voting.NeuralForestClassifier
+    :members:
+
+
+NeuralForestRegressor
+*********************
+
+.. autoclass:: torchensemble.voting.NeuralForestRegressor
+    :members:
+
 Snapshot Ensemble
 -----------------
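
For intuition, here is a minimal standalone sketch (not part of the commit) of the soft split described in the docs above: a single internal node routes a sample to its two children with complementary probabilities produced by a logistic unit. The input size and random weights are illustrative assumptions only.

    import torch

    x = torch.randn(1, 784)            # one flattened sample (assumed size)
    w = torch.randn(784, 1)            # weights of a single internal node
    b = torch.zeros(1)                 # bias of the node

    p_left = torch.sigmoid(x @ w + b)  # probability of taking the left branch
    p_right = 1.0 - p_left             # probability of taking the right branch

    # A leaf's path probability is the product of the branch probabilities
    # along its root-to-leaf path; leaf outputs are mixed with these weights.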

docs/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -4,4 +4,5 @@ sphinxemoji==0.1.8
 sphinx-copybutton
 m2r2==0.2.7
 mistune==0.8.4
-Jinja2<3.1
+Jinja2<3.1
+Numpy
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+import time
+import torch
+from torchvision import datasets, transforms
+
+from torchensemble import NeuralForestClassifier
+from torchensemble.utils.logging import set_logger
+
+
+if __name__ == "__main__":
+
+    # Hyper-parameters
+    n_estimators = 5
+    depth = 5
+    lamda = 1e-3
+    lr = 1e-3
+    weight_decay = 5e-4
+    epochs = 50
+
+    # Utils
+    cuda = False
+    n_jobs = 1
+    batch_size = 128
+    data_dir = "../../Dataset/mnist"  # MODIFY THIS IF YOU WANT
+
+    # Load data
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST(
+            data_dir,
+            train=True,
+            download=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            ),
+        ),
+        batch_size=batch_size,
+        shuffle=True,
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST(
+            data_dir,
+            train=False,
+            download=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            ),
+        ),
+        batch_size=batch_size,
+        shuffle=True,
+    )
+
+    logger = set_logger(
+        "classification_mnist_tree_ensemble", use_tb_logger=False
+    )
+
+    model = NeuralForestClassifier(
+        n_estimators=n_estimators,
+        depth=depth,
+        lamda=lamda,
+        cuda=cuda,
+        n_jobs=n_jobs,
+    )
+
+    model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)
+
+    tic = time.time()
+    model.fit(train_loader, epochs=epochs, test_loader=test_loader)
+    toc = time.time()
+    training_time = toc - tic
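
As a follow-up to the script above, here is a hedged sketch of how the fitted ensemble and the recorded `training_time` could be reported. It assumes `NeuralForestClassifier` exposes the same `evaluate` API as the package's other classifiers; that method is not shown in this diff.

    # Hypothetical continuation of the example script above.
    accuracy = model.evaluate(test_loader)  # assumed to return test accuracy
    logger.info(
        "Testing accuracy: {:.3f} | Training time: {:.3f} s".format(
            accuracy, training_time
        )
    )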

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -23,4 +23,4 @@ exclude = '''
   | dist
   | docs
 )/
-'''
+'''

torchensemble/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,8 @@
 from .fusion import FusionRegressor
 from .voting import VotingClassifier
 from .voting import VotingRegressor
+from .voting import NeuralForestClassifier
+from .voting import NeuralForestRegressor
 from .bagging import BaggingClassifier
 from .bagging import BaggingRegressor
 from .gradient_boosting import GradientBoostingClassifier
@@ -21,6 +23,8 @@
     "FusionRegressor",
     "VotingClassifier",
     "VotingRegressor",
+    "NeuralForestClassifier",
+    "NeuralForestRegressor",
     "BaggingClassifier",
     "BaggingRegressor",
     "GradientBoostingClassifier",

torchensemble/_base.py

Lines changed: 173 additions & 0 deletions
@@ -29,6 +29,7 @@ def get_doc(item):
     __doc = {
         "model": const.__model_doc,
         "seq_model": const.__seq_model_doc,
+        "tree_ensmeble_model": const.__tree_ensemble_doc,
         "fit": const.__fit_doc,
         "predict": const.__predict_doc,
         "set_optimizer": const.__set_optimizer_doc,
@@ -199,6 +200,50 @@ def predict(self, *x):
         return pred
 
 
+class BaseTreeEnsemble(BaseModule):
+    def __init__(
+        self,
+        n_estimators,
+        depth=5,
+        lamda=1e-3,
+        cuda=False,
+        n_jobs=None,
+    ):
+        super(BaseModule, self).__init__()
+        self.base_estimator_ = BaseTree
+        self.n_estimators = n_estimators
+        self.depth = depth
+        self.lamda = lamda
+
+        self.device = torch.device("cuda" if cuda else "cpu")
+        self.n_jobs = n_jobs
+        self.logger = logging.getLogger()
+        self.tb_logger = get_tb_logger()
+
+        self.estimators_ = nn.ModuleList()
+        self.use_scheduler_ = False
+
+    def _decidce_n_inputs(self, train_loader):
+        """Decide the input dimension according to the `train_loader`."""
+        for _, elem in enumerate(train_loader):
+            data = elem[0]
+            n_samples = data.size(0)
+            data = data.view(n_samples, -1)
+            return data.size(1)
+
+    def _make_estimator(self):
+        """Make and configure a soft decision tree."""
+        estimator = BaseTree(
+            input_dim=self.n_inputs,
+            output_dim=self.n_outputs,
+            depth=self.depth,
+            lamda=self.lamda,
+            cuda=self.device == torch.device("cuda"),
+        )
+
+        return estimator.to(self.device)
+
+
 class BaseClassifier(BaseModule):
     """Base class for all ensemble classifiers.
 
@@ -285,3 +330,131 @@ def evaluate(self, test_loader):
             loss += self._criterion(output, target)
 
         return float(loss) / len(test_loader)
+
+
+class BaseTree(nn.Module):
+    """Fast implementation of soft decision tree in PyTorch, copied from:
+    `https://github.com/xuyxu/Soft-Decision-Tree/blob/master/SDT.py`
+    """
+
+    def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False):
+        super(BaseTree, self).__init__()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.depth = depth
+        self.lamda = lamda
+        self.device = torch.device("cuda" if cuda else "cpu")
+
+        self._validate_parameters()
+
+        self.internal_node_num_ = 2 ** self.depth - 1
+        self.leaf_node_num_ = 2 ** self.depth
+
+        # Different penalty coefficients for nodes in different layers
+        self.penalty_list = [
+            self.lamda * (2 ** (-depth)) for depth in range(0, self.depth)
+        ]
+
+        # Initialize internal nodes and leaf nodes; the input dimension of
+        # internal nodes is increased by 1 to serve as the bias.
+        self.inner_nodes = nn.Sequential(
+            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),
+            nn.Sigmoid(),
+        )
+
+        self.leaf_nodes = nn.Linear(
+            self.leaf_node_num_, self.output_dim, bias=False
+        )
+
+    def forward(self, X, is_training_data=False):
+        _mu, _penalty = self._forward(X)
+        y_pred = self.leaf_nodes(_mu)
+
+        # When `X` is the training data, the model also returns the penalty
+        # to compute the training loss.
+        if is_training_data:
+            return y_pred, _penalty
+        else:
+            return y_pred
+
+    def _forward(self, X):
+        """Implementation of the data forwarding process."""
+
+        batch_size = X.size()[0]
+        X = self._data_augment(X)
+
+        path_prob = self.inner_nodes(X)
+        path_prob = torch.unsqueeze(path_prob, dim=2)
+        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)
+
+        _mu = X.data.new(batch_size, 1, 1).fill_(1.0)
+        _penalty = torch.tensor(0.0).to(self.device)
+
+        # Iterate through internal nodes in each layer to compute the final
+        # path probabilities and the regularization term.
+        begin_idx = 0
+        end_idx = 1
+
+        for layer_idx in range(0, self.depth):
+            _path_prob = path_prob[:, begin_idx:end_idx, :]
+
+            # Extract internal nodes in the current layer to compute the
+            # regularization term
+            _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob)
+            _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2)
+
+            _mu = _mu * _path_prob  # update path probabilities
+
+            begin_idx = end_idx
+            end_idx = begin_idx + 2 ** (layer_idx + 1)
+
+        mu = _mu.view(batch_size, self.leaf_node_num_)
+
+        return mu, _penalty
+
+    def _cal_penalty(self, layer_idx, _mu, _path_prob):
+        """Compute the regularization term for internal nodes"""
+
+        penalty = torch.tensor(0.0).to(self.device)
+
+        batch_size = _mu.size()[0]
+        _mu = _mu.view(batch_size, 2 ** layer_idx)
+        _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))
+
+        for node in range(0, 2 ** (layer_idx + 1)):
+            alpha = torch.sum(
+                _path_prob[:, node] * _mu[:, node // 2], dim=0
+            ) / torch.sum(_mu[:, node // 2], dim=0)
+
+            coeff = self.penalty_list[layer_idx]
+
+            penalty -= 0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha))
+
+        return penalty
+
+    def _data_augment(self, X):
+        """Add a constant input `1` onto the front of each sample."""
+        batch_size = X.size()[0]
+        X = X.view(batch_size, -1)
+        bias = torch.ones(batch_size, 1).to(self.device)
+        X = torch.cat((bias, X), 1)
+
+        return X
+
+    def _validate_parameters(self):
+
+        if not self.depth > 0:
+            msg = (
+                "The tree depth should be strictly positive, but got {}"
+                " instead."
+            )
+            raise ValueError(msg.format(self.depth))
+
+        if not self.lamda >= 0:
+            msg = (
+                "The coefficient of the regularization term should not be"
+                " negative, but got {} instead."
+            )
+            raise ValueError(msg.format(self.lamda))
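
A quick sanity-check sketch for the `BaseTree` defined above (not part of the commit; assumes `BaseTree` is importable from `torchensemble._base`). The expected shapes follow directly from the code: `depth=3` gives `2**3 - 1 = 7` internal nodes and `2**3 = 8` leaf nodes.

    import torch
    from torchensemble._base import BaseTree  # assumed import path

    tree = BaseTree(input_dim=784, output_dim=10, depth=3, lamda=1e-3)

    X = torch.randn(32, 784)          # a batch of flattened samples
    y_pred, penalty = tree(X, is_training_data=True)

    assert y_pred.shape == (32, 10)   # one output vector per sample
    assert penalty.dim() == 0         # scalar regularization term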

torchensemble/_constants.py

Lines changed: 30 additions & 0 deletions
@@ -58,6 +58,36 @@
 """
 
 
+__tree_ensemble_doc = """
+    Parameters
+    ----------
+    n_estimators : int
+        The number of neural trees in the ensemble.
+    depth : int, default=5
+        The depth of the neural tree. A tree of depth ``d`` has :math:`2^d`
+        leaf nodes and :math:`2^d-1` internal nodes.
+    lamda : float, default=1e-3
+        The coefficient of the regularization term when training neural
+        trees, proposed in the paper: `Distilling a neural network into a
+        soft decision tree <https://arxiv.org/abs/1711.09784>`_.
+    cuda : bool, default=True
+
+        - If ``True``, use GPU to train and evaluate the ensemble.
+        - If ``False``, use CPU to train and evaluate the ensemble.
+    n_jobs : int, default=None
+        The number of workers for training the ensemble. This input
+        argument is used for parallel ensemble methods such as
+        :mod:`voting` and :mod:`bagging`. Setting it to an integer larger
+        than ``1`` enables ``n_jobs`` base estimators to be trained
+        simultaneously.
+
+    Attributes
+    ----------
+    estimators_ : torch.nn.ModuleList
+        An internal container that stores all fitted base estimators.
+"""
+
+
 __set_optimizer_doc = """
     Parameters
     ----------
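
To make the role of ``lamda`` concrete, here is a small sketch (not part of the commit) of the per-layer penalty coefficients that `BaseTree` builds in its constructor: layer ``d`` receives coefficient ``lamda * 2**(-d)``, so the penalty halves at each deeper layer.

    lamda, depth = 1e-3, 3
    penalty_list = [lamda * (2 ** (-d)) for d in range(depth)]
    print(penalty_list)  # [0.001, 0.0005, 0.00025]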
