Skip to content

Commit a867ea7

Browse files
committed
just missing the tables now
1 parent 307a243 commit a867ea7

6 files changed

Lines changed: 1270 additions & 1528 deletions

File tree

_freeze/dev/notebooks/appendix/execute-results/html.json

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

build/dev/notebooks/appendix.html

Lines changed: 1111 additions & 1422 deletions
Large diffs are not rendered by default.

dev/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
66
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
77
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
88
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
9+
Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a"
910
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
1011
LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

dev/notebooks/appendix.qmd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ format:
88
echo: true
99
eval: false
1010
warning: false
11+
toc: true
1112
jupyter: julia-1.7
1213
---
1314

dev/notebooks/experiments/mitigation_strategies.qmd

Lines changed: 122 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ using Pkg; Pkg.activate("dev")
99
#| eval: true
1010
include("dev/utils.jl")
1111
using AlgorithmicRecourseDynamics
12-
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
12+
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra, Images
1313
theme(:wong)
1414
output_path = output_dir("mitigation_strategies")
15-
www_path = www_dir("mitigation_strategies")
15+
www_path = www_dir("mitigation_strategies");
1616
```
1717

1818
```{julia}
@@ -48,42 +48,6 @@ data_sets = filter(p -> p[1] in choices, catalogue)
4848
experiments = set_up_experiments(data_sets,models,generators)
4949
```
5050

51-
```{julia}
52-
using AlgorithmicRecourseDynamics.Models: model_evaluation
53-
plts = []
54-
for (exp_name, exp_) in experiments
55-
for (M_name, M) in exp_.models
56-
score = round(model_evaluation(M, exp_.test_data),digits=2)
57-
plt = plot(M, exp_.test_data, title="$exp_name;\n $M_name ($score)")
58-
# Errors:
59-
ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
60-
x_wrongly_labelled = exp_.test_data.X[:,ids]
61-
scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
62-
plts = vcat(plts..., plt)
63-
end
64-
end
65-
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
66-
savefig(plt, joinpath(www_path,"models_test_before.png"))
67-
```
68-
69-
```{julia}
70-
using AlgorithmicRecourseDynamics.Models: model_evaluation
71-
plts = []
72-
for (exp_name, exp_) in experiments
73-
for (M_name, M) in exp_.models
74-
score = round(model_evaluation(M, exp_.train_data),digits=2)
75-
plt = plot(M, exp_.train_data, title="$exp_name;\n $M_name ($score)")
76-
# Errors:
77-
ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
78-
x_wrongly_labelled = exp_.train_data.X[:,ids]
79-
scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
80-
plts = vcat(plts..., plt)
81-
end
82-
end
83-
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
84-
savefig(plt, joinpath(www_path,"models_train_before.png"))
85-
```
86-
8751
```{julia}
8852
n_evals = 5
8953
n_rounds = 50
@@ -99,64 +63,7 @@ results = run_experiments(
9963
Serialization.serialize(joinpath(output_path,"results_synthetic.jls"),results)
10064
```
10165

102-
```{julia}
103-
using AlgorithmicRecourseDynamics.Models: model_evaluation
104-
plot_dict = Dict(key => Dict() for (key,val) in results)
105-
fold = 1
106-
for (name, res) in results
107-
exp_ = res.experiment
108-
plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
109-
rec_sys = exp_.recourse_systems[fold]
110-
sys_ids = collect(exp_.system_identifiers)
111-
M = length(rec_sys)
112-
for m in 1:M
113-
model_name, generator_name = sys_ids[m]
114-
M = rec_sys[m].model
115-
score = round(model_evaluation(M, exp_.test_data),digits=2)
116-
plt = plot(M, exp_.test_data, title="$name;\n $model_name ($score)")
117-
# Errors:
118-
ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
119-
x_wrongly_labelled = exp_.test_data.X[:,ids]
120-
scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
121-
plot_dict[name][generator_name] = vcat(plot_dict[name][generator_name], plt)
122-
end
123-
end
124-
plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
125-
for (name, plts) in plot_dict
126-
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
127-
savefig(plt, joinpath(www_path,"models_test_after_$(name).png"))
128-
end
129-
```
130-
131-
```{julia}
132-
using AlgorithmicRecourseDynamics.Models: model_evaluation
133-
plot_dict = Dict(key => Dict() for (key,val) in results)
134-
fold = 1
135-
for (name, res) in results
136-
exp_ = res.experiment
137-
plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
138-
rec_sys = exp_.recourse_systems[fold]
139-
sys_ids = collect(exp_.system_identifiers)
140-
M = length(rec_sys)
141-
for m in 1:M
142-
model_name, generator_name = sys_ids[m]
143-
M = rec_sys[m].model
144-
data = rec_sys[m].data
145-
score = round(model_evaluation(M, data),digits=2)
146-
plt = plot(M, data, title="$name;\n $model_name ($score)")
147-
# Errors:
148-
ids = findall(vec(round.(probs(M, data.X)) .!= data.y))
149-
x_wrongly_labelled = data.X[:,ids]
150-
scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
151-
plot_dict[name][generator_name] = vcat(plot_dict[name][generator_name], plt)
152-
end
153-
end
154-
plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
155-
for (name, plts) in plot_dict
156-
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
157-
savefig(plt, joinpath(www_path,"models_train_after_$(name).png"))
158-
end
159-
```
66+
### Plots
16067

16168
```{julia}
16269
using Serialization
@@ -177,8 +84,58 @@ for (data_name, res) in results
17784
end
17885
```
17986

87+
#### Line Charts
88+
89+
@fig-mit-line shows the evolution of the evaluation metrics over the course of the experiment.
90+
91+
```{julia}
92+
#| eval: true
93+
#| fig-cap: "Line Charts"
94+
#| fig-subcap:
95+
#| - "California Housing"
96+
#| - "Circles"
97+
#| - "Credit Default"
98+
#| - "GMSC"
99+
#| - "Linearly Separable"
100+
#| - "Moons"
101+
#| - "Overlapping"
102+
#| layout-ncol: 1
103+
#| label: fig-mit-line
104+
img_files = readdir(www_path)[contains.(readdir(www_path),"line_chart") .&& .!contains.(readdir(www_path),"latent")]
105+
img_files = joinpath.(www_path,img_files)
106+
for img in img_files
107+
display(load(img))
108+
end
109+
```
110+
111+
#### Error Bar Charts
112+
113+
@fig-mit-error shows the evaluation metrics at the end of the experiments.
114+
115+
```{julia}
116+
#| eval: true
117+
#| fig-cap: "Error Bar Charts"
118+
#| fig-subcap:
119+
#| - "California Housing"
120+
#| - "Circles"
121+
#| - "Credit Default"
122+
#| - "GMSC"
123+
#| - "Linearly Separable"
124+
#| - "Moons"
125+
#| - "Overlapping"
126+
#| layout-ncol: 1
127+
#| label: fig-mit-error
128+
img_files = readdir(www_path)[contains.(readdir(www_path),"errorbar_chart") .&& .!contains.(readdir(www_path),"latent")]
129+
img_files = joinpath.(www_path,img_files)
130+
for img in img_files
131+
display(load(img))
132+
end
133+
```
134+
180135
### Chart in paper
181136

137+
@fig-mit-paper shows the chart that went into the paper.
138+
182139
```{julia}
183140
using DataFrames, Statistics
184141
df = results[:overlapping].output
@@ -234,6 +191,13 @@ img = Images.load(rcopy(R"temp_path"))
234191
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
235192
```
236193

194+
```{julia}
195+
#| label: fig-mit-paper
196+
#| fig-cap: "Chart in paper"
197+
#| eval: true
198+
Images.load(joinpath(www_path,"paper_synthetic_results.png"))
199+
```
200+
237201

238202
### Latent Space Search
239203

@@ -284,8 +248,54 @@ for (data_name, res) in results
284248
end
285249
```
286250

251+
### Plots
252+
253+
#### Line Charts
254+
255+
@fig-mit-line-latent shows the evolution of the evaluation metrics over the course of the experiment.
256+
257+
```{julia}
258+
#| eval: true
259+
#| fig-cap: "Line Charts"
260+
#| fig-subcap:
261+
#| - "Circles"
262+
#| - "Linearly Separable"
263+
#| - "Moons"
264+
#| - "Overlapping"
265+
#| layout-ncol: 1
266+
#| label: fig-mit-line-latent
267+
img_files = readdir(www_path)[contains.(readdir(www_path),"line_chart") .&& contains.(readdir(www_path),"latent")]
268+
img_files = joinpath.(www_path,img_files)
269+
for img in img_files
270+
display(load(img))
271+
end
272+
```
273+
274+
#### Error Bar Charts
275+
276+
@fig-mit-error-latent shows the evaluation metrics at the end of the experiments.
277+
278+
```{julia}
279+
#| eval: true
280+
#| fig-cap: "Error Bar Charts"
281+
#| fig-subcap:
282+
#| - "Circles"
283+
#| - "Linearly Separable"
284+
#| - "Moons"
285+
#| - "Overlapping"
286+
#| layout-ncol: 1
287+
#| label: fig-mit-error-latent
288+
img_files = readdir(www_path)[contains.(readdir(www_path),"errorbar_chart") .&& contains.(readdir(www_path),"latent")]
289+
img_files = joinpath.(www_path,img_files)
290+
for img in img_files
291+
display(load(img))
292+
end
293+
```
294+
287295
### Chart in paper
288296

297+
@fig-mit-latent-paper shows the chart that went into the paper.
298+
289299
```{julia}
290300
using DataFrames, Statistics
291301
df = results[:overlapping].output
@@ -341,6 +351,13 @@ img = Images.load(rcopy(R"temp_path"))
341351
Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
342352
```
343353

354+
```{julia}
355+
#| label: fig-mit-latent-paper
356+
#| fig-cap: "Chart in paper"
357+
#| eval: true
358+
Images.load(joinpath(www_path,"paper_synthetic_latent_results.png"))
359+
```
360+
344361
## Real World
345362

346363
```{julia}
@@ -419,6 +436,8 @@ end
419436

420437
### Chart in paper
421438

439+
@fig-mit-latent-paper shows the chart that went into the paper.
440+
422441
```{julia}
423442
using DataFrames, Statistics
424443
model_ = :FluxEnsemble
@@ -476,3 +495,10 @@ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85)
476495
img = Images.load(rcopy(R"temp_path"))
477496
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
478497
```
498+
499+
```{julia}
500+
#| label: fig-mit-real-paper
501+
#| fig-cap: "Chart in paper"
502+
#| eval: true
503+
Images.load(joinpath(www_path,"paper_real_world_results.png"))
504+
```

dev/notebooks/experiments/synthetic.qmd

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ end
221221
```{julia}
222222
#| eval: true
223223
using Serialization
224-
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
224+
results = Serialization.deserialize(joinpath(output_path,"results.jls"));
225225
```
226226

227227
```{julia}
@@ -280,17 +280,42 @@ for img in img_files
280280
end
281281
```
282282

283-
### Tables
283+
### Bootstrap
284284

285-
@tbl-results shows a summary of all results.
285+
```{julia}
286+
n_bootstrap = 1
287+
using AlgorithmicRecourseDynamics.Evaluation: evaluate_system
288+
using DataFrames
289+
df = DataFrame()
290+
for (key, val) in results
291+
n_folds = length(val.experiment.recourse_systems)
292+
for fold in 1:n_folds
293+
for i in length(val.experiment.system_identifiers)
294+
rec_sys = val.experiment.recourse_systems[fold][i]
295+
model_name, gen_name = collect(val.experiment.system_identifiers)[i]
296+
df_ = evaluate_system(rec_sys, val.experiment; n=n_bootstrap)
297+
df_.model .= model_name
298+
df_.generator .= gen_name
299+
df_.fold .= fold
300+
df = vcat(df, df_)
301+
end
302+
end
303+
end
304+
df = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df)
305+
using RCall
306+
save_path = joinpath(output_path, "bootstrap.html")
307+
R"""
308+
dt <- DT::datatable($df) |>
309+
DT::formatRound(columns=c("value"), digits=3)
310+
DT::saveWidget(dt, $save_path)
311+
"""
312+
```
286313

287314
```{julia}
288-
#| label: tbl-results
289-
#| tbl-cap: "Summary of all results"
290315
#| eval: true
291-
#| results: asis
292-
using AlgorithmicRecourseDynamics: kable
293-
kable(results, [50]; format="html")
316+
using Gumbo
317+
save_path = joinpath(output_path, "bootstrap.html")
318+
parsehtml(read(save_path, String))
294319
```
295320

296321
### Chart in paper {#sec-app-synthetic-paper}

0 commit comments

Comments
 (0)