Skip to content

Commit 0e42386

Browse files
committed
mitigation strategies also show some effect for real world data. Finally now also trying in combination with revise
1 parent 65bad24 commit 0e42386

2 files changed

Lines changed: 128 additions & 47 deletions

File tree

dev/notebooks/mitigation_strategies.qmd

Lines changed: 126 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,6 @@ output_path = output_dir("mitigation_strategies")
1616
www_path = www_dir("mitigation_strategies")
1717
```
1818

19-
```{julia}
20-
max_obs = 1000
21-
catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)
22-
choices = [
23-
:linearly_separable,
24-
:overlapping,
25-
:circles,
26-
:moons,
27-
]
28-
data_sets = filter(p -> p[1] in choices, catalogue)
29-
```
30-
3119
```{julia}
3220
models = [
3321
:LogisticRegression,
@@ -42,6 +30,20 @@ generators = Dict(
4230
)
4331
```
4432

33+
## Synthetic
34+
35+
```{julia}
36+
max_obs = 1000
37+
catalogue = AlgorithmicRecourseDynamics.Data.load_synthetic(max_obs)
38+
choices = [
39+
:linearly_separable,
40+
:overlapping,
41+
:circles,
42+
:moons,
43+
]
44+
data_sets = filter(p -> p[1] in choices, catalogue)
45+
```
46+
4547
```{julia}
4648
experiments = set_up_experiments(data_sets,models,generators)
4749
```
@@ -82,8 +84,6 @@ plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices
8284
savefig(plt, joinpath(www_path,"models_train_before.png"))
8385
```
8486

85-
## Experiments
86-
8787
```{julia}
8888
n_evals = 5
8989
n_rounds = 50
@@ -173,3 +173,115 @@ for (data_name, res) in results
173173
errorbar_charts[data_name] = plot(res,50)
174174
end
175175
```
176+
177+
## Real World
178+
179+
```{julia}
180+
max_obs = 1000
181+
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
182+
```
183+
184+
```{julia}
185+
using CounterfactualExplanations.DataPreprocessing: unpack
186+
bs = 50
187+
function data_loader(data::CounterfactualData)
188+
X, y = unpack(data)
189+
data = Flux.DataLoader((X,y),batchsize=bs)
190+
return data
191+
end
192+
model_params = (batch_norm=false,n_hidden=32,n_layers=3,dropout=true,p_dropout=0.25)
193+
```
194+
195+
```{julia}
196+
experiments = set_up_experiments(
197+
data_sets,models,generators;
198+
pre_train_models=100, model_params=model_params,
199+
data_loader=data_loader
200+
)
201+
```
202+
203+
```{julia}
204+
n_evals = 5
205+
n_rounds = 50
206+
evaluate_every = Int(round(n_rounds/n_evals))
207+
n_folds = 5
208+
n_bootstrap = 1
209+
T = 100
210+
using Serialization
211+
results = run_experiments(
212+
experiments;
213+
save_path=output_path,
214+
evaluate_every=evaluate_every,
215+
n_rounds=n_rounds,
216+
n_folds=n_folds,
217+
n_bootstrap=n_bootstrap,
218+
T=T
219+
)
220+
Serialization.serialize(joinpath(output_path,"results.jls"),results)
221+
```
222+
223+
```{julia}
224+
line_charts = Dict()
225+
errorbar_charts = Dict()
226+
for (data_name, res) in results
227+
line_charts[data_name] = plot(res)
228+
errorbar_charts[data_name] = plot(res,50)
229+
end
230+
```
231+
232+
### Latent Space Search
233+
234+
235+
236+
```{julia}
237+
generators = Dict(
238+
:REVISE=>GenericGenerator(decision_threshold=0.5),
239+
:REVISE_conservative=>GenericGenerator(decision_threshold=0.9),
240+
:Gravitational=>GravitationalGenerator(),
241+
:ROAR=>EndoROARGenerator()
242+
)
243+
```
244+
245+
246+
```{julia}
247+
generative_model_params = (epochs=250, latent_dim=8)
248+
experiments = set_up_experiments(
249+
data_sets,models,generators;
250+
pre_train_models=100, model_params=model_params,
251+
data_loader=data_loader
252+
)
253+
```
254+
255+
```{julia}
256+
n_evals = 5
257+
n_rounds = 50
258+
evaluate_every = Int(round(n_rounds/n_evals))
259+
n_folds = 5
260+
n_bootstrap = 1
261+
T = 100
262+
using Serialization
263+
results = run_experiments(
264+
experiments;
265+
save_path=output_path,
266+
save_name_suffix="latent",
267+
evaluate_every=evaluate_every,
268+
n_rounds=n_rounds,
269+
n_folds=n_folds,
270+
n_bootstrap=n_bootstrap,
271+
T=T,
272+
latent_space=true,
273+
generative_model_params=generative_model_params
274+
)
275+
Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
276+
```
277+
278+
### Plots
279+
280+
```{julia}
281+
line_charts = Dict()
282+
errorbar_charts = Dict()
283+
for (data_name, res) in results
284+
line_charts[data_name] = plot(res)
285+
errorbar_charts[data_name] = plot(res,50)
286+
end
287+
```

dev/notebooks/real_world.qmd

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -63,42 +63,13 @@ n_rounds = 50
6363
evaluate_every = Int(round(n_rounds/n_evals))
6464
n_folds = 5
6565
n_bootstrap = 1
66-
T = 250
66+
T = 100
6767
generative_model_params = (epochs=250, latent_dim=8)
6868
using Serialization
6969
results = run_experiments(
7070
experiments;
7171
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T,
72-
convergence=:strict, generative_model_params=generative_model_params, save_name_suffix="strict"
73-
)
74-
Serialization.serialize(joinpath(output_path,"results_strict.jls"),results)
75-
```
76-
77-
```{julia}
78-
using Serialization
79-
results = Serialization.deserialize(joinpath(output_path,"results_strict.jls"))
80-
```
81-
82-
```{julia}
83-
line_charts = Dict()
84-
errorbar_charts = Dict()
85-
for (data_name, res) in results
86-
line_charts[data_name] = plot(res)
87-
errorbar_charts[data_name] = plot(res,50)
88-
end
89-
```
90-
91-
## Simple convergence
92-
93-
```{julia}
94-
n_evals = 5
95-
n_rounds = 50
96-
evaluate_every = Int(round(n_rounds/n_evals))
97-
n_folds = 5
98-
n_bootstrap = 1
99-
using Serialization
100-
results = run_experiments(
101-
experiments;save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap
72+
generative_model_params=generative_model_params
10273
)
10374
Serialization.serialize(joinpath(output_path,"results.jls"),results)
10475
```
@@ -108,8 +79,6 @@ using Serialization
10879
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
10980
```
11081

111-
### Plots
112-
11382
```{julia}
11483
line_charts = Dict()
11584
errorbar_charts = Dict()

0 commit comments

Comments
 (0)