@@ -29,6 +29,7 @@ def get_doc(item):
     __doc = {
         "model": const.__model_doc,
         "seq_model": const.__seq_model_doc,
+        "tree_ensemble_model": const.__tree_ensemble_doc,
         "fit": const.__fit_doc,
         "predict": const.__predict_doc,
         "set_optimizer": const.__set_optimizer_doc,
@@ -199,6 +200,50 @@ def predict(self, *x):
         return pred
 
 
+class BaseTreeEnsemble(BaseModule):
+    def __init__(
+        self,
+        n_estimators,
+        depth=5,
+        lamda=1e-3,
+        cuda=False,
+        n_jobs=None,
+    ):
+        # `super(BaseModule, self)` resolves to nn.Module in the MRO,
+        # so BaseModule.__init__ is skipped here.
+        super(BaseModule, self).__init__()
+        self.base_estimator_ = BaseTree
+        self.n_estimators = n_estimators
+        self.depth = depth
+        self.lamda = lamda
+
+        self.device = torch.device("cuda" if cuda else "cpu")
+        self.n_jobs = n_jobs
+        self.logger = logging.getLogger()
+        self.tb_logger = get_tb_logger()
+
+        self.estimators_ = nn.ModuleList()
+        self.use_scheduler_ = False
+
+    def _decide_n_inputs(self, train_loader):
+        """Decide the input dimension according to the `train_loader`."""
+        for _, elem in enumerate(train_loader):
+            data = elem[0]
+            n_samples = data.size(0)
+            data = data.view(n_samples, -1)
+            # One batch suffices to infer the flattened input dimension.
+            return data.size(1)
+
+    def _make_estimator(self):
+        """Make and configure a soft decision tree."""
+        estimator = BaseTree(
+            input_dim=self.n_inputs,
+            output_dim=self.n_outputs,
+            depth=self.depth,
+            lamda=self.lamda,
+            cuda=self.device == torch.device("cuda"),
+        )
+
+        return estimator.to(self.device)
+
+
 class BaseClassifier(BaseModule):
     """Base class for all ensemble classifiers.
 
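Before the next hunk, a hedged sketch of how these helpers would plausibly be wired together at fit time. The function `_sketch_setup`, its `n_outputs` argument, and the call order are assumptions; `_decide_n_inputs`, `_make_estimator`, and the attributes they touch come from the diff above.

# Hypothetical fit-time setup: infer the input dimension from one
# batch, then populate the ensemble with configured soft trees.
def _sketch_setup(ensemble, train_loader, n_outputs):
    ensemble.n_inputs = ensemble._decide_n_inputs(train_loader)
    ensemble.n_outputs = n_outputs
    for _ in range(ensemble.n_estimators):
        ensemble.estimators_.append(ensemble._make_estimator())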
@@ -285,3 +330,131 @@ def evaluate(self, test_loader):
             loss += self._criterion(output, target)
 
         return float(loss) / len(test_loader)
+
+
+class BaseTree(nn.Module):
+    """Fast implementation of a soft decision tree in PyTorch, copied from:
+    `https://github.com/xuyxu/Soft-Decision-Tree/blob/master/SDT.py`
+    """
+
+    def __init__(self, input_dim, output_dim, depth=5, lamda=1e-3, cuda=False):
+        super(BaseTree, self).__init__()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.depth = depth
+        self.lamda = lamda
+        self.device = torch.device("cuda" if cuda else "cpu")
+
+        self._validate_parameters()
+
+        self.internal_node_num_ = 2 ** self.depth - 1
+        self.leaf_node_num_ = 2 ** self.depth
+
+        # Different penalty coefficients for nodes in different layers.
+        self.penalty_list = [
+            self.lamda * (2 ** (-depth)) for depth in range(0, self.depth)
+        ]
+
+        # Initialize internal nodes and leaf nodes. The input dimension on
+        # internal nodes is increased by 1, serving as the bias.
+        self.inner_nodes = nn.Sequential(
+            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),
+            nn.Sigmoid(),
+        )
+
+        self.leaf_nodes = nn.Linear(
+            self.leaf_node_num_, self.output_dim, bias=False
+        )
+
+    def forward(self, X, is_training_data=False):
+        _mu, _penalty = self._forward(X)
+        y_pred = self.leaf_nodes(_mu)
+
+        # When `X` is the training data, the model also returns the penalty
+        # used to compute the training loss.
+        if is_training_data:
+            return y_pred, _penalty
+        else:
+            return y_pred
+
+    def _forward(self, X):
+        """Implementation of the data forwarding process."""
+        batch_size = X.size()[0]
+        X = self._data_augment(X)
+
+        # Probability of routing left at each internal node; the
+        # complement is the probability of routing right.
+        path_prob = self.inner_nodes(X)
+        path_prob = torch.unsqueeze(path_prob, dim=2)
+        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)
+
+        _mu = X.data.new(batch_size, 1, 1).fill_(1.0)
+        _penalty = torch.tensor(0.0).to(self.device)
+
+        # Iterate through internal nodes in each layer to compute the final
+        # path probabilities and the regularization term.
+        begin_idx = 0
+        end_idx = 1
+
+        for layer_idx in range(0, self.depth):
+            _path_prob = path_prob[:, begin_idx:end_idx, :]
+
+            # Extract internal nodes in the current layer to compute the
+            # regularization term.
+            _penalty = _penalty + self._cal_penalty(layer_idx, _mu, _path_prob)
+            _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2)
+
+            _mu = _mu * _path_prob  # update path probabilities
+
+            begin_idx = end_idx
+            end_idx = begin_idx + 2 ** (layer_idx + 1)
+
+        mu = _mu.view(batch_size, self.leaf_node_num_)
+
+        return mu, _penalty
+
+    def _cal_penalty(self, layer_idx, _mu, _path_prob):
+        """Compute the regularization term for internal nodes."""
+        penalty = torch.tensor(0.0).to(self.device)
+
+        batch_size = _mu.size()[0]
+        _mu = _mu.view(batch_size, 2 ** layer_idx)
+        _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))
+
+        for node in range(0, 2 ** (layer_idx + 1)):
+            # Batch-averaged probability of routing into this node,
+            # weighted by the probability of reaching its parent.
+            alpha = torch.sum(
+                _path_prob[:, node] * _mu[:, node // 2], dim=0
+            ) / torch.sum(_mu[:, node // 2], dim=0)
+
+            coeff = self.penalty_list[layer_idx]
+
+            penalty -= 0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha))
+
+        return penalty
+
+    def _data_augment(self, X):
+        """Add a constant input `1` onto the front of each sample."""
+        batch_size = X.size()[0]
+        X = X.view(batch_size, -1)
+        bias = torch.ones(batch_size, 1).to(self.device)
+        X = torch.cat((bias, X), 1)
+
+        return X
+
+    def _validate_parameters(self):
+        if not self.depth > 0:
+            msg = (
+                "The tree depth should be strictly positive, but got {}"
+                " instead."
+            )
+            raise ValueError(msg.format(self.depth))
+
+        if not self.lamda >= 0:
+            msg = (
+                "The coefficient of the regularization term should not be"
+                " negative, but got {} instead."
+            )
+            raise ValueError(msg.format(self.lamda))
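To make the forward contract concrete, here is a minimal standalone training sketch for `BaseTree` on random data. The combined loss (cross-entropy plus the returned penalty) follows the Soft-Decision-Tree reference linked in the docstring; the shapes, optimizer, and hyperparameters below are illustrative assumptions, not part of the diff.

import torch
import torch.nn as nn

# One optimization step on random data (illustrative shapes).
tree = BaseTree(input_dim=784, output_dim=10, depth=5, lamda=1e-3)
optimizer = torch.optim.Adam(tree.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

X = torch.randn(32, 784)           # a batch of flattened samples
y = torch.randint(0, 10, (32,))    # class labels

# With `is_training_data=True`, forward() also returns the penalty.
y_pred, penalty = tree(X, is_training_data=True)
loss = criterion(y_pred, y) + penalty
optimizer.zero_grad()
loss.backward()
optimizer.step()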