Skip to content

Commit 158dcf2

Browse files
committed
helper function for bootstrap with caching and multi-threading
1 parent 19c2efc commit 158dcf2

21 files changed

Lines changed: 279 additions & 412 deletions

docs/src/_intro.qmd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ In this work we investigate what happens if Algorithmic Recourse is actually imp
1818

1919
![](paper/www/poc.png)
2020

21+
## Reproduce Results
22+
23+
The most natural and interactive way to reproduce the results in the paper is to use the relevant notebooks that go through the process step by step. Alternatively, you can rerun of the relevant notebooks as a background job:
24+
25+
```{shell}
26+
quarto render docs/src/paper/appendix.qmd
27+
```
28+
2129
## Paper Abstract
2230

2331
Existing work on Counterfactual Explanations (CE) and Algorithmic Recourse (AR) has largely focused on single individuals in a static environment: given some estimated model, the goal is to find valid counterfactuals for an individual instance that fulfill various desiderata. The ability of such counterfactuals to handle dynamics like data and model drift remains a largely unexplored research challenge. There has also been surprisingly little work on the related question of how the actual implementation of recourse by one individual may affect other individuals. Through this work we aim to close that gap. We first show that many of the existing methodologies can be collectively described by a generalized framework. We then argue that the existing framework does not account for a hidden external cost of recourse, that only reveals itself when studying the endogenous dynamics of recourse at the group level. Through simulation experiments involving various state-of-the-art counterfactual generators and several benchmark datasets, we generate large numbers of counterfactuals and study the resulting domain and model shifts. We find that the induced shifts are substantial enough to likely impede the applicability of Algorithmic Recourse in some situations. Fortunately, we find various strategies to mitigate these concerns. Our simulation framework for studying recourse dynamics is fast and open-sourced.

docs/src/paper/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
77
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
88
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
99
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
10+
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

docs/src/paper/experiments/_mitigation_strategies.qmd

Lines changed: 24 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
## Mitigation Strategies
1+
### Mitigation Strategies
22

33
```{julia}
4-
#| eval: true
5-
using Pkg; Pkg.activate("dev")
6-
```
4+
#| echo: false
75
8-
```{julia}
9-
#| eval: true
10-
include("dev/utils.jl")
11-
using AlgorithmicRecourseDynamics
12-
using CounterfactualExplanations, Flux, Plots, PlotThemes, Random, LaplaceRedux, LinearAlgebra, Images
13-
theme(:wong)
6+
include("docs/src/paper/setup.jl")
7+
eval(setup)
148
output_path = output_dir("mitigation_strategies")
15-
www_path = www_dir("mitigation_strategies");
9+
www_path = www_dir("mitigation_strategies")
1610
```
1711

1812
```{julia}
@@ -30,7 +24,7 @@ generators = Dict(
3024
)
3125
```
3226

33-
### Synthetic
27+
#### Synthetic
3428

3529
```{julia}
3630
max_obs = 1000
@@ -53,17 +47,16 @@ n_evals = 5
5347
n_rounds = 50
5448
evaluate_every = Int(round(n_rounds/n_evals))
5549
n_folds = 5
56-
n_bootstrap = 1
5750
T = 100
5851
using Serialization
5952
results = run_experiments(
6053
experiments;
61-
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T
54+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T
6255
)
6356
Serialization.serialize(joinpath(output_path,"results_synthetic.jls"),results)
6457
```
6558

66-
### Plots
59+
#### Plots
6760

6861
```{julia}
6962
using Serialization
@@ -84,7 +77,7 @@ for (data_name, res) in results
8477
end
8578
```
8679

87-
#### Line Charts
80+
##### Line Charts
8881

8982
@fig-mit-line shows the evolution of the evaluation metrics over the course of the experiment.
9083

@@ -108,7 +101,7 @@ for img in img_files
108101
end
109102
```
110103

111-
#### Error Bar Charts
104+
##### Error Bar Charts
112105

113106
@fig-mit-error shows the evaluation metrics at the end of the experiments.
114107

@@ -132,35 +125,14 @@ for img in img_files
132125
end
133126
```
134127

135-
### Bootstrap
128+
#### Bootstrap
136129

137130
```{julia}
138131
n_bootstrap = 1000
139-
using AlgorithmicRecourseDynamics.Evaluation: evaluate_system
140-
using DataFrames
141-
df = DataFrame()
142-
for (key, val) in results
143-
n_folds = length(val.experiment.recourse_systems)
144-
for fold in 1:n_folds
145-
for i in length(val.experiment.system_identifiers)
146-
rec_sys = val.experiment.recourse_systems[fold][i]
147-
model_name, gen_name = collect(val.experiment.system_identifiers)[i]
148-
df_ = evaluate_system(rec_sys, val.experiment; n=n_bootstrap)
149-
df_.model .= model_name
150-
df_.generator .= gen_name
151-
df_.fold .= fold
152-
df = vcat(df, df_)
153-
end
154-
end
155-
end
156-
df = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df)
157-
using RCall
158-
save_path = joinpath(output_path, "bootstrap_synthetic.csv")
159-
using CSV
160-
CSV.write(save_path)
132+
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_synthetic.csv"))
161133
```
162134

163-
### Chart in paper
135+
#### Chart in paper
164136

165137
@fig-mit-paper shows the chart that went into the paper.
166138

@@ -227,7 +199,7 @@ Images.load(joinpath(www_path,"paper_synthetic_results.png"))
227199
```
228200

229201

230-
### Latent Space Search
202+
#### Latent Space Search
231203

232204
```{julia}
233205
generators = Dict(
@@ -247,12 +219,11 @@ n_evals = 5
247219
n_rounds = 50
248220
evaluate_every = Int(round(n_rounds/n_evals))
249221
n_folds = 5
250-
n_bootstrap = 1
251222
T = 100
252223
using Serialization
253224
results = run_experiments(
254225
experiments;
255-
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T
226+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T
256227
)
257228
Serialization.serialize(joinpath(output_path,"results_synthetic_latent.jls"),results)
258229
```
@@ -276,9 +247,9 @@ for (data_name, res) in results
276247
end
277248
```
278249

279-
### Plots
250+
#### Plots
280251

281-
#### Line Charts
252+
##### Line Charts
282253

283254
@fig-mit-line-latent shows the evolution of the evaluation metrics over the course of the experiment.
284255

@@ -299,7 +270,7 @@ for img in img_files
299270
end
300271
```
301272

302-
#### Error Bar Charts
273+
##### Error Bar Charts
303274

304275
@fig-mit-error-latent shows the evaluation metrics at the end of the experiments.
305276

@@ -320,35 +291,14 @@ for img in img_files
320291
end
321292
```
322293

323-
### Bootstrap
294+
#### Bootstrap
324295

325296
```{julia}
326297
n_bootstrap = 1000
327-
using AlgorithmicRecourseDynamics.Evaluation: evaluate_system
328-
using DataFrames
329-
df = DataFrame()
330-
for (key, val) in results
331-
n_folds = length(val.experiment.recourse_systems)
332-
for fold in 1:n_folds
333-
for i in length(val.experiment.system_identifiers)
334-
rec_sys = val.experiment.recourse_systems[fold][i]
335-
model_name, gen_name = collect(val.experiment.system_identifiers)[i]
336-
df_ = evaluate_system(rec_sys, val.experiment; n=n_bootstrap)
337-
df_.model .= model_name
338-
df_.generator .= gen_name
339-
df_.fold .= fold
340-
df = vcat(df, df_)
341-
end
342-
end
343-
end
344-
df = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df)
345-
using RCall
346-
save_path = joinpath(output_path, "bootstrap_latent.csv")
347-
using CSV
348-
CSV.write(save_path)
298+
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_latent.csv"))
349299
```
350300

351-
### Chart in paper
301+
#### Chart in paper
352302

353303
@fig-mit-latent-paper shows the chart that went into the paper.
354304

@@ -414,111 +364,14 @@ Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
414364
Images.load(joinpath(www_path,"paper_synthetic_latent_results.png"))
415365
```
416366

417-
## Real World
418-
419-
```{julia}
420-
generators = Dict(
421-
:Generic=>GenericGenerator(decision_threshold=0.5),
422-
:Latent=>REVISEGenerator(),
423-
:Generic_conservative=>GenericGenerator(decision_threshold=0.9),
424-
:Gravitational=>GravitationalGenerator(),
425-
:ClapROAR=>ClapROARGenerator()
426-
)
427-
```
428-
429-
```{julia}
430-
max_obs = 2500
431-
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
432-
```
433-
434-
```{julia}
435-
using CounterfactualExplanations.DataPreprocessing: unpack
436-
bs = 50
437-
function data_loader(data::CounterfactualData)
438-
X, y = unpack(data)
439-
data = Flux.DataLoader((X,y),batchsize=bs)
440-
return data
441-
end
442-
model_params = (batch_norm=false,n_hidden=32,n_layers=3,dropout=true,p_dropout=0.25)
443-
```
444-
445-
```{julia}
446-
experiments = set_up_experiments(
447-
data_sets,models,generators;
448-
pre_train_models=100, model_params=model_params,
449-
data_loader=data_loader
450-
)
451-
```
452-
453-
```{julia}
454-
n_evals = 5
455-
n_rounds = 50
456-
evaluate_every = Int(round(n_rounds/n_evals))
457-
n_folds = 5
458-
n_bootstrap = 1
459-
n_samples = 10000
460-
T = 250
461-
using Serialization
462-
results = run_experiments(
463-
experiments;
464-
save_path=output_path,
465-
evaluate_every=evaluate_every,
466-
n_rounds=n_rounds,
467-
n_folds=n_folds,
468-
n_bootstrap=n_bootstrap,
469-
T=T
470-
)
471-
Serialization.serialize(joinpath(output_path,"results_real_world.jls"),results)
472-
```
473-
474-
```{julia}
475-
using Serialization
476-
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
477-
```
478-
479-
```{julia}
480-
using Images
481-
line_charts = Dict()
482-
errorbar_charts = Dict()
483-
for (data_name, res) in results
484-
plt = plot(res)
485-
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
486-
line_charts[data_name] = plt
487-
plt = plot(res,maximum(res.output.n))
488-
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
489-
errorbar_charts[data_name] = plt
490-
end
491-
```
492-
493-
### Bootstrap
367+
#### Bootstrap
494368

495369
```{julia}
496370
n_bootstrap = 1000
497-
using AlgorithmicRecourseDynamics.Evaluation: evaluate_system
498-
using DataFrames
499-
df = DataFrame()
500-
for (key, val) in results
501-
n_folds = length(val.experiment.recourse_systems)
502-
for fold in 1:n_folds
503-
for i in length(val.experiment.system_identifiers)
504-
rec_sys = val.experiment.recourse_systems[fold][i]
505-
model_name, gen_name = collect(val.experiment.system_identifiers)[i]
506-
df_ = evaluate_system(rec_sys, val.experiment; n=n_bootstrap)
507-
df_.model .= model_name
508-
df_.generator .= gen_name
509-
df_.fold .= fold
510-
df = vcat(df, df_)
511-
end
512-
end
513-
end
514-
df = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df)
515-
using RCall
516-
save_path = joinpath(output_path, "bootstrap_real_world.csv")
517-
using CSV
518-
CSV.write(save_path)
371+
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_real_world.csv"))
519372
```
520373

521-
### Chart in paper
374+
#### Chart in paper
522375

523376
@fig-mit-latent-paper shows the chart that went into the paper.
524377

0 commit comments

Comments
 (0)