@@ -16,18 +16,6 @@ output_path = output_dir("mitigation_strategies")
1616www_path = www_dir("mitigation_strategies")
1717```
1818
19- ``` {julia}
20- max_obs = 1000
21- catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)
22- choices = [
23- :linearly_separable,
24- :overlapping,
25- :circles,
26- :moons,
27- ]
28- data_sets = filter(p -> p[1] in choices, catalogue)
29- ```
30-
3119``` {julia}
3220models = [
3321 :LogisticRegression,
@@ -42,6 +30,20 @@ generators = Dict(
4230)
4331```
4432
33+ ## Synthetic
34+
35+ ``` {julia}
36+ max_obs = 1000
37+ catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)
38+ choices = [
39+ :linearly_separable,
40+ :overlapping,
41+ :circles,
42+ :moons,
43+ ]
44+ data_sets = filter(p -> p[1] in choices, catalogue)
45+ ```
46+
4547``` {julia}
4648experiments = set_up_experiments(data_sets,models,generators)
4749```
@@ -82,8 +84,6 @@ plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices
8284savefig(plt, joinpath(www_path,"models_train_before.png"))
8385```
8486
85- ## Experiments
86-
8787``` {julia}
8888n_evals = 5
8989n_rounds = 50
@@ -173,3 +173,115 @@ for (data_name, res) in results
173173 errorbar_charts[data_name] = plot(res,50)
174174end
175175```
176+
177+ ## Real World
178+
179+ ``` {julia}
180+ max_obs = 1000
181+ data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
182+ ```
183+
184+ ``` {julia}
185+ using CounterfactualExplanations.DataPreprocessing: unpack
186+ bs = 50
187+ function data_loader(data::CounterfactualData)
188+ X, y = unpack(data)
189+ data = Flux.DataLoader((X,y),batchsize=bs)
190+ return data
191+ end
192+ model_params = (batch_norm=false,n_hidden=32,n_layers=3,dropout=true,p_dropout=0.25)
193+ ```
194+
195+ ``` {julia}
196+ experiments = set_up_experiments(
197+ data_sets,models,generators;
198+ pre_train_models=100, model_params=model_params,
199+ data_loader=data_loader
200+ )
201+ ```
202+
203+ ``` {julia}
204+ n_evals = 5
205+ n_rounds = 50
206+ evaluate_every = Int(round(n_rounds/n_evals))
207+ n_folds = 5
208+ n_bootstrap = 1
209+ T = 100
210+ using Serialization
211+ results = run_experiments(
212+ experiments;
213+ save_path=output_path,
214+ evaluate_every=evaluate_every,
215+ n_rounds=n_rounds,
216+ n_folds=n_folds,
217+ n_bootstrap=n_bootstrap,
218+ T=T
219+ )
220+ Serialization.serialize(joinpath(output_path,"results.jls"),results)
221+ ```
222+
223+ ``` {julia}
224+ line_charts = Dict()
225+ errorbar_charts = Dict()
226+ for (data_name, res) in results
227+ line_charts[data_name] = plot(res)
228+ errorbar_charts[data_name] = plot(res,50)
229+ end
230+ ```
231+
232+ ### Latent Space Search
233+
234+
235+
236+ ``` {julia}
237+ generators = Dict(
238+ :REVISE=>GenericGenerator(decision_threshold=0.5),
239+ :REVISE_conservative=>GenericGenerator(decision_threshold=0.9),
240+ :Gravitational=>GravitationalGenerator(),
241+ :ROAR=>EndoROARGenerator()
242+ )
243+ ```
244+
245+
246+ ``` {julia}
247+ generative_model_params = (epochs=250, latent_dim=8)
248+ experiments = set_up_experiments(
249+ data_sets,models,generators;
250+ pre_train_models=100, model_params=model_params,
251+ data_loader=data_loader
252+ )
253+ ```
254+
255+ ``` {julia}
256+ n_evals = 5
257+ n_rounds = 50
258+ evaluate_every = Int(round(n_rounds/n_evals))
259+ n_folds = 5
260+ n_bootstrap = 1
261+ T = 100
262+ using Serialization
263+ results = run_experiments(
264+ experiments;
265+ save_path=output_path,
266+ save_name_suffix="latent",
267+ evaluate_every=evaluate_every,
268+ n_rounds=n_rounds,
269+ n_folds=n_folds,
270+ n_bootstrap=n_bootstrap,
271+ T=T,
272+ latent_space=true,
273+ generative_model_params=generative_model_params
274+ )
275+ Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
276+ ```
277+
278+ ### Plots
279+
280+ ``` {julia}
281+ line_charts = Dict()
282+ errorbar_charts = Dict()
283+ for (data_name, res) in results
284+ line_charts[data_name] = plot(res)
285+ errorbar_charts[data_name] = plot(res,50)
286+ end
287+ ```
0 commit comments