11import copy
22import numpy as np
33from sklearn .tree import DecisionTreeClassifier
4+ from sklearn .ensemble import RandomForestClassifier
45
5- from adapt .parameter_based import TransferTreeClassifier
6+ from adapt .parameter_based import TransferTreeClassifier , TransferForestClassifier
67
78methods = [
89 'relab' ,
910 'ser' ,
1011 'strut' ,
1112 'ser_nr' ,
13+ 'ser_no_ext' ,
1214 'ser_nr_lambda' ,
1315 'strut_nd' ,
1416 'strut_lambda' ,
15- 'strut_lambda_np'
17+ 'strut_np'
18+ 'strut_lambda_np' ,
19+ 'strut_lambda_np2'
1620# 'strut_hi'
1721]
18- labels = [
19- 'relab' ,
20- '$SER$' ,
21- '$STRUT$' ,
22- '$SER_{NP}$' ,
23- '$SER_{NP}(\lambda)$' ,
24- '$STRUT_{ND}$' ,
25- '$STRUT(\lambda)$' ,
26- '$STRUT_{NP}(\lambda)$'
27- # 'STRUT$^{*}$',
28- #'STRUT$^{*}$',
29- ]
22+
3023
3124def test_transfer_tree ():
3225
3326 np .random .seed (0 )
3427
35- plot_step = 0.01
3628 # Generate training source data
3729 ns = 200
3830 ns_perclass = ns // 2
@@ -65,12 +57,15 @@ def test_transfer_tree():
6557 yt_test [nt_test_perclass :] = 1
6658
6759 # Source classifier
68- clf_source = DecisionTreeClassifier (max_depth = None )
69- clf_source .fit (Xs , ys )
70- score_src_src = clf_source .score (Xs , ys )
71- score_src_trgt = clf_source .score (Xt_test , yt_test )
72- print ('Training score Source model: {:.3f}' .format (score_src_src ))
73- print ('Testing score Source model: {:.3f}' .format (score_src_trgt ))
60+ RF_SIZE = 10
61+ clf_source_dt = DecisionTreeClassifier (max_depth = None )
62+ clf_source_rf = RandomForestClassifier (n_estimators = RF_SIZE )
63+ clf_source_dt .fit (Xs , ys )
64+ clf_source_rf .fit (Xs , ys )
65+ #score_src_src = clf_source.score(Xs, ys)
66+ #score_src_trgt = clf_source.score(Xt_test, yt_test)
67+ #print('Training score Source model: {:.3f}'.format(score_src_src))
68+ #print('Testing score Source model: {:.3f}'.format(score_src_trgt))
7469 clfs = []
7570 scores = []
7671 # Transfer with SER
@@ -79,7 +74,7 @@ def test_transfer_tree():
7974
8075 for method in methods :
8176 Nkmin = sum (yt == 0 )
82- root_source_values = clf_source .tree_ .value [0 ].reshape (- 1 )
77+ root_source_values = clf_source_dt .tree_ .value [0 ].reshape (- 1 )
8378 props_s = root_source_values
8479 props_s = props_s / sum (props_s )
8580 props_t = np .zeros (props_s .size )
@@ -88,43 +83,105 @@ def test_transfer_tree():
8883
8984 coeffs = np .divide (props_t , props_s )
9085
91- clf_transfer = copy .deepcopy (clf_source )
86+ clf_transfer_dt = copy .deepcopy (clf_source_dt )
87+ clf_transfer_rf = copy .deepcopy (clf_source_rf )
88+
9289 if method == 'relab' :
93- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "" )
90+ #decision tree
91+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "" )
9492 transferred_dt .fit (Xt ,yt )
93+ #random forest
94+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "" ,bootstrap = True )
95+ transferred_rf .fit (Xt ,yt )
9596 if method == 'ser' :
96- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "ser" )
97+ #decision tree
98+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" ,max_depth = 10 )
9799 transferred_dt .fit (Xt ,yt )
98- #transferred_dt._ser(Xt, yt, node=0, original_ser=True)
99- #ser.SER(0, clf_transfer, Xt, yt, original_ser=True)
100+ #random forest
101+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
102+ transferred_rf .fit (Xt ,yt )
100103 if method == 'ser_nr' :
101- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "ser" )
104+ #decision tree
105+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
102106 transferred_dt ._ser (Xt , yt ,node = 0 ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ])
107+ #random forest
108+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
109+ transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ])
110+ if method == 'ser_no_ext' :
111+ #decision tree
112+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
113+ transferred_dt ._ser (Xt , yt ,node = 0 ,original_ser = False ,no_ext_on_cl = True ,cl_no_red = [0 ],ext_cond = True )
114+ #random forest
115+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
116+ transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_ext_on_cl = True ,cl_no_ext = [0 ],ext_cond = True )
103117 if method == 'ser_nr_lambda' :
104- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "ser" )
118+ #decision tree
119+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "ser" )
105120 transferred_dt ._ser (Xt , yt ,node = 0 ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ],
106121 leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,
107122 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
108- #ser.SER(0, clf_transfer, Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
123+ #random forest
124+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "ser" )
125+ transferred_rf ._ser_rf (Xt , yt ,original_ser = False ,no_red_on_cl = True ,cl_no_red = [0 ],
126+ leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,
127+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
109128 if method == 'strut' :
110- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "strut" )
129+ #decision tree
130+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "strut" )
111131 transferred_dt .fit (Xt ,yt )
112- #transferred_dt._strut(Xt, yt,node=0)
132+ #random forest
133+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
134+ transferred_rf .fit (Xt ,yt )
113135 if method == 'strut_nd' :
114- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "strut" )
136+ #decision tree
137+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_rf ,algo = "strut" )
115138 transferred_dt ._strut (Xt , yt ,node = 0 ,use_divergence = False )
139+ #random forest
140+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
141+ transferred_rf ._strut_rf (Xt , yt ,use_divergence = False )
116142 if method == 'strut_lambda' :
117- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "strut" )
143+ #decision tree
144+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "strut" )
118145 transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = True ,root_source_values = root_source_values ,
119146 Nkmin = Nkmin ,coeffs = coeffs )
147+ #random forest
148+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
149+ transferred_rf ._strut_rf (Xt , yt ,adapt_prop = True ,root_source_values = root_source_values ,
150+ Nkmin = Nkmin ,coeffs = coeffs )
151+ if method == 'strut_np' :
152+ #decision tree
153+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "strut" )
154+ transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
155+ leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
156+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
157+ #random forest
158+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
159+ transferred_rf ._strut (Xt , yt ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
160+ leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
161+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
120162 if method == 'strut_lambda_np' :
121- transferred_dt = TransferTreeClassifier (estimator = clf_transfer ,algo = "strut" )
163+ #decision tree
164+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "strut" )
165+ transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
166+ leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
167+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
168+ #random forest
169+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
170+ transferred_rf ._strut (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
171+ leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
172+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
173+ if method == 'strut_lambda_np2' :
174+ #decision tree
175+ transferred_dt = TransferTreeClassifier (estimator = clf_transfer_dt ,algo = "strut" )
122176 transferred_dt ._strut (Xt , yt ,node = 0 ,adapt_prop = False ,no_prune_on_cl = True ,cl_no_prune = [0 ],
123177 leaf_loss_quantify = False ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = False ,
124178 root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
125- #if method == 'strut_hi':
126- #transferred_dt._strut(Xt, yt,node=0,no_prune_on_cl=False,adapt_prop=True,coeffs=[0.2, 1])
127- #strut.STRUT(clf_transfer, 0, Xt, yt, Xt, yt,pruning_updated_node=True,no_prune_on_cl=False,adapt_prop=True,simple_weights=False,coeffs=[0.2, 1])
179+ #random forest
180+ transferred_rf = TransferForestClassifier (estimator = clf_transfer_rf ,algo = "strut" )
181+ transferred_rf ._strut (Xt , yt ,adapt_prop = True ,no_prune_on_cl = True ,cl_no_prune = [0 ],
182+ leaf_loss_quantify = True ,leaf_loss_threshold = 0.5 ,no_prune_with_translation = True ,
183+ root_source_values = root_source_values ,Nkmin = Nkmin ,coeffs = coeffs )
184+
128185 score = transferred_dt .estimator .score (Xt_test , yt_test )
129186 #score = clf_transfer.score(Xt_test, yt_test)
130187 print ('Testing score transferred model ({}) : {:.3f}' .format (method , score ))
0 commit comments