Skip to content

Commit 50f8cf0

Browse files
committed
uh all running
1 parent ff9ef2c commit 50f8cf0

6 files changed

Lines changed: 89 additions & 11 deletions

File tree

docs/src/paper/experiments/_mitigation_strategies.qmd

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ end
128128
#### Bootstrap
129129

130130
```{julia}
131-
n_bootstrap = 1000
131+
n_bootstrap = 10
132132
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_synthetic.csv"))
133133
```
134134

@@ -294,7 +294,7 @@ end
294294
#### Bootstrap
295295

296296
```{julia}
297-
n_bootstrap = 1000
297+
n_bootstrap = 10
298298
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_latent.csv"))
299299
```
300300

@@ -364,10 +364,88 @@ Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
364364
Images.load(joinpath(www_path,"paper_synthetic_latent_results.png"))
365365
```
366366

367+
## Real World
368+
369+
```{julia}
370+
generators = Dict(
371+
:Generic=>GenericGenerator(decision_threshold=0.5),
372+
:Latent=>REVISEGenerator(),
373+
:Generic_conservative=>GenericGenerator(decision_threshold=0.9),
374+
:Gravitational=>GravitationalGenerator(),
375+
:ClapROAR=>ClapROARGenerator()
376+
)
377+
```
378+
379+
```{julia}
380+
max_obs = 2500
381+
data_path = data_dir("real_world")
382+
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs; data_dir=data_path)
383+
choices = [
384+
:cal_housing,
385+
:credit_default,
386+
:gmsc,
387+
]
388+
data_sets = filter(p -> p[1] in choices, data_sets)
389+
```
390+
391+
```{julia}
392+
using CounterfactualExplanations.DataPreprocessing: unpack
393+
bs = 500
394+
function data_loader(data::CounterfactualData)
395+
X, y = unpack(data)
396+
data = Flux.DataLoader((X,y),batchsize=bs)
397+
return data
398+
end
399+
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.1)
400+
```
401+
402+
```{julia}
403+
experiments = set_up_experiments(
404+
data_sets,models,generators;
405+
pre_train_models=100, model_params=model_params,
406+
data_loader=data_loader
407+
)
408+
```
409+
410+
```{julia}
411+
n_evals = 5
412+
n_rounds = 50
413+
evaluate_every = Int(round(n_rounds/n_evals))
414+
n_folds = 5
415+
n_samples = 10000
416+
T = 100
417+
generative_model_params = (epochs=250, latent_dim=8)
418+
results = run_experiments(
419+
experiments;
420+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T, n_samples=n_samples,
421+
generative_model_params=generative_model_params
422+
)
423+
Serialization.serialize(joinpath(output_path,"results_real_world.jls"),results)
424+
```
425+
426+
```{julia}
427+
using Serialization
428+
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
429+
```
430+
431+
```{julia}
432+
using Images
433+
line_charts = Dict()
434+
errorbar_charts = Dict()
435+
for (data_name, res) in results
436+
plt = plot(res)
437+
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
438+
line_charts[data_name] = plt
439+
plt = plot(res,maximum(res.output.n))
440+
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
441+
errorbar_charts[data_name] = plt
442+
end
443+
```
444+
367445
#### Bootstrap
368446

369447
```{julia}
370-
n_bootstrap = 1000
448+
n_bootstrap = 10
371449
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_real_world.csv"))
372450
```
373451

docs/src/paper/experiments/_real_world.qmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ max_obs = 2500
1515
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs; data_dir=data_path)
1616
choices = [
1717
:cal_housing,
18-
# :credit_default,
18+
:credit_default,
1919
:gmsc,
2020
]
2121
data_sets = filter(p -> p[1] in choices, data_sets)
@@ -136,7 +136,7 @@ end
136136
#### Bootstrap
137137

138138
```{julia}
139-
n_bootstrap = 1000
139+
n_bootstrap = 10
140140
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap.csv"))
141141
```
142142

docs/src/paper/experiments/_synthetic.qmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ end
269269
#### Bootstrap
270270

271271
```{julia}
272-
n_bootstrap = 100
272+
n_bootstrap = 10
273273
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap.csv"))
274274
```
275275

src/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function run!(
6666
# Pre-allocate memory:
6767
output = [DataFrame() for i in 1:M]
6868

69-
p_fold = Progress(K; desc="Total Progress:", showspeed=true, enabled=show_progress, output=stderr, color=:yellow)
69+
p_fold = Progress(K; desc="Progress on folds:", showspeed=true, enabled=show_progress, output=stderr, color=:yellow)
7070
@info "Running experiment ..."
7171
for k in 1:K
7272
recourse_systems = experiment.recourse_systems[k]
@@ -96,7 +96,7 @@ function run!(
9696
output[m] = vcat(output[m], output_checkpoint, cols=:union)
9797
end
9898
end
99-
next!(p_round, showvalues=[(:Fold, k//K), (:Round, n//N)])
99+
next!(p_round, showvalues=[(:Fold, "$k/$K"), (:Round, "$n/$N")])
100100
end
101101
next!(p_fold)
102102
end

src/experiments/functions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
168168

169169
indices_ = rand(1:experiment.num_counterfactuals, length(results)) # randomly draw from generated counterfactuals
170170
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
171-
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 1, indices_)))
171+
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results),3,indices_)))
172172

173173
X[:, chosen_individuals] = X′
174174
y[:, chosen_individuals] = y′

src/post_processing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ function run_bootstrap(
2929
df_.generator .= gen_name
3030
df_.fold .= fold
3131
df = vcat(df, df_)
32-
next!(p_sys, showvalues = [(:Model, model_name), (:Generator, gen_name), (:System, i//N)])
32+
next!(p_sys, showvalues = [(:Model, model_name), (:Generator, gen_name), (:System, "$i/$N")])
3333
end
34-
next!(p_fold, showvalues = [(:Fold, fold//n_folds)])
34+
next!(p_fold, showvalues = [(:Fold, "$fold/$n_folds")])
3535
end
3636
next!(p_total)
3737
end

0 commit comments

Comments
 (0)