@@ -9,10 +9,10 @@ using Pkg; Pkg.activate("dev")
99#| eval: true
1010include("dev/utils.jl")
1111using AlgorithmicRecourseDynamics
12- using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra
12+ using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra, Images
1313theme(:wong)
1414output_path = output_dir("mitigation_strategies")
15- www_path = www_dir("mitigation_strategies")
15+ www_path = www_dir("mitigation_strategies");
1616```
1717
1818``` {julia}
@@ -48,42 +48,6 @@ data_sets = filter(p -> p[1] in choices, catalogue)
4848experiments = set_up_experiments(data_sets,models,generators)
4949```
5050
51- ``` {julia}
52- using AlgorithmicRecourseDynamics.Models: model_evaluation
53- plts = []
54- for (exp_name, exp_) in experiments
55- for (M_name, M) in exp_.models
56- score = round(model_evaluation(M, exp_.test_data),digits=2)
57- plt = plot(M, exp_.test_data, title="$exp_name;\n $M_name ($score)")
58- # Errors:
59- ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
60- x_wrongly_labelled = exp_.test_data.X[:,ids]
61- scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
62- plts = vcat(plts..., plt)
63- end
64- end
65- plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
66- savefig(plt, joinpath(www_path,"models_test_before.png"))
67- ```
68-
69- ``` {julia}
70- using AlgorithmicRecourseDynamics.Models: model_evaluation
71- plts = []
72- for (exp_name, exp_) in experiments
73- for (M_name, M) in exp_.models
74- score = round(model_evaluation(M, exp_.train_data),digits=2)
75- plt = plot(M, exp_.train_data, title="$exp_name;\n $M_name ($score)")
76- # Errors:
77- ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
78- x_wrongly_labelled = exp_.train_data.X[:,ids]
79- scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
80- plts = vcat(plts..., plt)
81- end
82- end
83- plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
84- savefig(plt, joinpath(www_path,"models_train_before.png"))
85- ```
86-
8751``` {julia}
8852n_evals = 5
8953n_rounds = 50
@@ -99,64 +63,7 @@ results = run_experiments(
9963Serialization.serialize(joinpath(output_path,"results_synthetic.jls"),results)
10064```
10165
102- ``` {julia}
103- using AlgorithmicRecourseDynamics.Models: model_evaluation
104- plot_dict = Dict(key => Dict() for (key,val) in results)
105- fold = 1
106- for (name, res) in results
107- exp_ = res.experiment
108- plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
109- rec_sys = exp_.recourse_systems[fold]
110- sys_ids = collect(exp_.system_identifiers)
111- M = length(rec_sys)
112- for m in 1:M
113- model_name, generator_name = sys_ids[m]
114- M = rec_sys[m].model
115- score = round(model_evaluation(M, exp_.test_data),digits=2)
116- plt = plot(M, exp_.test_data, title="$name;\n $model_name ($score)")
117- # Errors:
118- ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
119- x_wrongly_labelled = exp_.test_data.X[:,ids]
120- scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
121- plot_dict[name][generator_name] = vcat(plot_dict[name][generator_name], plt)
122- end
123- end
124- plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
125- for (name, plts) in plot_dict
126- plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
127- savefig(plt, joinpath(www_path,"models_test_after_$(name).png"))
128- end
129- ```
130-
131- ``` {julia}
132- using AlgorithmicRecourseDynamics.Models: model_evaluation
133- plot_dict = Dict(key => Dict() for (key,val) in results)
134- fold = 1
135- for (name, res) in results
136- exp_ = res.experiment
137- plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
138- rec_sys = exp_.recourse_systems[fold]
139- sys_ids = collect(exp_.system_identifiers)
140- M = length(rec_sys)
141- for m in 1:M
142- model_name, generator_name = sys_ids[m]
143- M = rec_sys[m].model
144- data = rec_sys[m].data
145- score = round(model_evaluation(M, data),digits=2)
146- plt = plot(M, data, title="$name;\n $model_name ($score)")
147- # Errors:
148- ids = findall(vec(round.(probs(M, data.X)) .!= data.y))
149- x_wrongly_labelled = data.X[:,ids]
150- scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
151- plot_dict[name][generator_name] = vcat(plot_dict[name][generator_name], plt)
152- end
153- end
154- plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
155- for (name, plts) in plot_dict
156- plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
157- savefig(plt, joinpath(www_path,"models_train_after_$(name).png"))
158- end
159- ```
66+ ### Plots
16067
16168``` {julia}
16269using Serialization
@@ -177,8 +84,58 @@ for (data_name, res) in results
17784end
17885```
17986
87+ #### Line Charts
88+
89+ @fig-mit-line shows the evolution of the evaluation metrics over the course of the experiment.
90+
91+ ``` {julia}
92+ #| eval: true
93+ #| fig-cap: "Line Charts"
94+ #| fig-subcap:
95+ #| - "California Housing"
96+ #| - "Circles"
97+ #| - "Credit Default"
98+ #| - "GMSC"
99+ #| - "Linearly Separable"
100+ #| - "Moons"
101+ #| - "Overlapping"
102+ #| layout-ncol: 1
103+ #| label: fig-mit-line
104+ img_files = readdir(www_path)[contains.(readdir(www_path),"line_chart") .&& .!contains.(readdir(www_path),"latent")]
105+ img_files = joinpath.(www_path,img_files)
106+ for img in img_files
107+ display(load(img))
108+ end
109+ ```
110+
111+ #### Error Bar Charts
112+
113+ @fig-mit-error shows the evaluation metrics at the end of the experiments.
114+
115+ ``` {julia}
116+ #| eval: true
117+ #| fig-cap: "Error Bar Charts"
118+ #| fig-subcap:
119+ #| - "California Housing"
120+ #| - "Circles"
121+ #| - "Credit Default"
122+ #| - "GMSC"
123+ #| - "Linearly Separable"
124+ #| - "Moons"
125+ #| - "Overlapping"
126+ #| layout-ncol: 1
127+ #| label: fig-mit-error
128+ img_files = readdir(www_path)[contains.(readdir(www_path),"errorbar_chart") .&& .!contains.(readdir(www_path),"latent")]
129+ img_files = joinpath.(www_path,img_files)
130+ for img in img_files
131+ display(load(img))
132+ end
133+ ```
134+
180135### Chart in paper
181136
137+ @fig-mit-paper shows the chart that went into the paper.
138+
182139``` {julia}
183140using DataFrames, Statistics
184141df = results[:overlapping].output
@@ -234,6 +191,13 @@ img = Images.load(rcopy(R"temp_path"))
234191Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
235192```
236193
194+ ``` {julia}
195+ #| label: fig-mit-paper
196+ #| fig-cap: "Chart in paper"
197+ #| eval: true
198+ Images.load(joinpath(www_path,"paper_synthetic_results.png"))
199+ ```
200+
237201
238202### Latent Space Search
239203
@@ -284,8 +248,54 @@ for (data_name, res) in results
284248end
285249```
286250
251+ ### Plots
252+
253+ #### Line Charts
254+
255+ @fig-mit-line-latent shows the evolution of the evaluation metrics over the course of the experiment.
256+
257+ ``` {julia}
258+ #| eval: true
259+ #| fig-cap: "Line Charts"
260+ #| fig-subcap:
261+ #| - "Circles"
262+ #| - "Linearly Separable"
263+ #| - "Moons"
264+ #| - "Overlapping"
265+ #| layout-ncol: 1
266+ #| label: fig-mit-line-latent
267+ img_files = readdir(www_path)[contains.(readdir(www_path),"line_chart") .&& contains.(readdir(www_path),"latent")]
268+ img_files = joinpath.(www_path,img_files)
269+ for img in img_files
270+ display(load(img))
271+ end
272+ ```
273+
274+ #### Error Bar Charts
275+
276+ @fig-mit-error-latent shows the evaluation metrics at the end of the experiments.
277+
278+ ``` {julia}
279+ #| eval: true
280+ #| fig-cap: "Error Bar Charts"
281+ #| fig-subcap:
282+ #| - "Circles"
283+ #| - "Linearly Separable"
284+ #| - "Moons"
285+ #| - "Overlapping"
286+ #| layout-ncol: 1
287+ #| label: fig-mit-error-latent
288+ img_files = readdir(www_path)[contains.(readdir(www_path),"errorbar_chart") .&& contains.(readdir(www_path),"latent")]
289+ img_files = joinpath.(www_path,img_files)
290+ for img in img_files
291+ display(load(img))
292+ end
293+ ```
294+
287295### Chart in paper
288296
297+ @fig-mit-latent-paper shows the chart that went into the paper.
298+
289299``` {julia}
290300using DataFrames, Statistics
291301df = results[:overlapping].output
@@ -341,6 +351,13 @@ img = Images.load(rcopy(R"temp_path"))
341351Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
342352```
343353
354+ ``` {julia}
355+ #| label: fig-mit-latent-paper
356+ #| fig-cap: "Chart in paper"
357+ #| eval: true
358+ Images.load(joinpath(www_path,"paper_synthetic_latent_results.png"))
359+ ```
360+
344361## Real World
345362
346363``` {julia}
419436
420437### Chart in paper
421438
439+ @fig-mit-latent-paper shows the chart that went into the paper.
440+
422441``` {julia}
423442using DataFrames, Statistics
424443model_ = :FluxEnsemble
@@ -476,3 +495,10 @@ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85)
476495img = Images.load(rcopy(R"temp_path"))
477496Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
478497```
498+
499+ ``` {julia}
500+ #| label: fig-mit-real-paper
501+ #| fig-cap: "Chart in paper"
502+ #| eval: true
503+ Images.load(joinpath(www_path,"paper_real_world_results.png"))
504+ ```
0 commit comments