Skip to content

Commit f318301

Browse files
committed
2 parents 9ebef92 + c1f236b commit f318301

8 files changed

Lines changed: 256 additions & 94 deletions

File tree

dev/notebooks/mitigation_strategies.qmd

Lines changed: 91 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -194,28 +194,41 @@ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X
194194
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
195195
transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
196196
transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
197+
transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
198+
transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
199+
transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)
197200
198201
ncol = length(unique(df_plot.model))
199202
nrow = length(unique(df_plot.name))
200203
201204
using RCall
202-
scale_ = 2.2
205+
scale_ = 2.0
203206
R"""
207+
library(data.table)
208+
df_plot <- data.table($df_plot)
209+
model_order <- c("Linear", "MLP", "Deep Ensemble")
210+
df_plot[,model:=factor(model, levels=model_order)]
204211
library(ggplot2)
205212
plt <- ggplot($df_plot) +
206213
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
207214
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
208-
facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
215+
facet_grid(
216+
rows = vars(name),
217+
cols = vars(model),
218+
scales = "free_y"
219+
) +
209220
labs(y = "Value") +
210221
scale_fill_discrete(name="Generator:") +
211222
scale_colour_discrete(name="Generator:") +
212223
theme(
213224
axis.title.x=element_blank(),
214225
axis.text.x=element_blank(),
215-
axis.ticks.x=element_blank()
216-
)
226+
axis.ticks.x=element_blank(),
227+
legend.position="bottom"
228+
) +
229+
guides(fill=guide_legend(ncol=3))
217230
temp_path <- file.path(tempdir(), "plot.png")
218-
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
231+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8)
219232
"""
220233
221234
img = Images.load(rcopy(R"temp_path"))
@@ -288,28 +301,41 @@ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X
288301
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
289302
transform!(df_plot, :generator => (X -> [x=="Latent" ? "Latent (γ=0.5)" : x for x in X]) => :generator)
290303
transform!(df_plot, :generator => (X -> [x=="Latent_conservative" ? "Latent (γ=0.9)" : x for x in X]) => :generator)
304+
transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
305+
transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
306+
transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)
291307
292308
ncol = length(unique(df_plot.model))
293309
nrow = length(unique(df_plot.name))
294310
295311
using RCall
296-
scale_ = 2.2
312+
scale_ = 1.9
297313
R"""
314+
library(data.table)
315+
df_plot <- data.table($df_plot)
316+
model_order <- c("Linear", "MLP", "Deep Ensemble")
317+
df_plot[,model:=factor(model, levels=model_order)]
298318
library(ggplot2)
299319
plt <- ggplot($df_plot) +
300320
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
301321
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
302-
facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
322+
facet_grid(
323+
rows = vars(name),
324+
cols = vars(model),
325+
scales = "free_y"
326+
) +
303327
labs(y = "Value") +
304328
scale_fill_discrete(name="Generator:") +
305329
scale_colour_discrete(name="Generator:") +
306330
theme(
307331
axis.title.x=element_blank(),
308332
axis.text.x=element_blank(),
309-
axis.ticks.x=element_blank()
310-
)
333+
axis.ticks.x=element_blank(),
334+
legend.position="bottom"
335+
) +
336+
guides(fill=guide_legend(ncol=4))
311337
temp_path <- file.path(tempdir(), "plot.png")
312-
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
338+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8)
313339
"""
314340
315341
img = Images.load(rcopy(R"temp_path"))
@@ -329,7 +355,7 @@ generators = Dict(
329355
```
330356

331357
```{julia}
332-
max_obs = 1000
358+
max_obs = 2500
333359
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
334360
```
335361

@@ -358,6 +384,7 @@ n_rounds = 50
358384
evaluate_every = Int(round(n_rounds/n_evals))
359385
n_folds = 5
360386
n_bootstrap = 1
387+
n_samples = 10000
361388
T = 250
362389
using Serialization
363390
results = run_experiments(
@@ -391,68 +418,62 @@ for (data_name, res) in results
391418
end
392419
```
393420

394-
### Latent Space Search
395-
396-
```{julia}
397-
generators = Dict(
398-
:REVISE=>GenericGenerator(decision_threshold=0.5),
399-
:REVISE_conservative=>GenericGenerator(decision_threshold=0.9),
400-
:Gravitational=>GravitationalGenerator(),
401-
:ClapROAR=>ClapROARGenerator()
402-
)
403-
```
404-
405-
406-
```{julia}
407-
generative_model_params = (epochs=250, latent_dim=8)
408-
experiments = set_up_experiments(
409-
data_sets,models,generators;
410-
pre_train_models=100, model_params=model_params,
411-
data_loader=data_loader
412-
)
413-
```
421+
### Chart in paper
414422

415423
```{julia}
416-
n_evals = 5
417-
n_rounds = 50
418-
evaluate_every = Int(round(n_rounds/n_evals))
419-
n_folds = 5
420-
n_bootstrap = 1
421-
T = 200
422-
using Serialization
423-
results = run_experiments(
424-
experiments;
425-
save_path=output_path,
426-
save_name_suffix="latent",
427-
evaluate_every=evaluate_every,
428-
n_rounds=n_rounds,
429-
n_folds=n_folds,
430-
n_bootstrap=n_bootstrap,
431-
T=T,
432-
latent_space=true,
433-
generative_model_params=generative_model_params
434-
)
435-
Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
436-
```
424+
using DataFrames, Statistics
425+
model_ = :FluxEnsemble
426+
df = DataFrame()
427+
for (key, val) in results
428+
df_ = deepcopy(val.output)
429+
df_.dataset .= key
430+
df = vcat(df,df_)
431+
end
432+
df = df[df.n .== maximum(df.n),:]
433+
df = df[df.model .== model_,:]
434+
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)
435+
gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
436+
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
437+
df_plot = df_plot[[name in [:mmd, :model_performance] for name in df_plot.name],:]
438+
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.!=:model),:]
439+
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
440+
transform!(df_plot, :dataset => (X -> [x=="cal_housing" ? "California Housing" : x for x in X]) => :dataset)
441+
transform!(df_plot, :dataset => (X -> [x=="credit_default" ? "Credit Default" : x for x in X]) => :dataset)
442+
transform!(df_plot, :dataset => (X -> [x=="gmsc" ? "GMSC" : x for x in X]) => :dataset)
443+
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (model)" : x for x in X]) => :name)
444+
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
445+
transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
446+
transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
437447
438-
### Plots
448+
ncol = length(unique(df_plot.dataset))
449+
nrow = length(unique(df_plot.name))
439450
440-
```{julia}
441-
using Serialization
442-
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
443-
```
451+
using RCall
452+
scale_ = 2.0
453+
R"""
454+
library(ggplot2)
455+
plt <- ggplot($df_plot) +
456+
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
457+
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
458+
facet_grid(
459+
rows = vars(name),
460+
cols = vars(dataset),
461+
scales = "free_y"
462+
) +
463+
labs(y = "Value") +
464+
scale_fill_discrete(name="Generator:") +
465+
scale_colour_discrete(name="Generator:") +
466+
theme(
467+
axis.title.x=element_blank(),
468+
axis.text.x=element_blank(),
469+
axis.ticks.x=element_blank(),
470+
legend.position="bottom"
471+
) +
472+
guides(fill=guide_legend(ncol=3))
473+
temp_path <- file.path(tempdir(), "plot.png")
474+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85)
475+
"""
444476
445-
```{julia}
446-
using Images
447-
line_charts = Dict()
448-
errorbar_charts = Dict()
449-
for (data_name, res) in results
450-
plt = plot(res)
451-
Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
452-
line_charts[data_name] = plt
453-
plt = plot(res,maximum(res.output.n))
454-
Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
455-
errorbar_charts[data_name] = plt
456-
end
477+
img = Images.load(rcopy(R"temp_path"))
478+
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
457479
```
458-

dev/notebooks/real_world.qmd

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@ www_path = www_dir("real_world")
1717
```
1818

1919
```{julia}
20-
max_obs = 1000
20+
max_obs = 2500
2121
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
22+
choices = [
23+
:cal_housing,
24+
:credit_default,
25+
:gmsc,
26+
]
27+
data_sets = filter(p -> p[1] in choices, data_sets)
2228
```
2329

2430
```{julia}
@@ -63,12 +69,13 @@ n_rounds = 50
6369
evaluate_every = Int(round(n_rounds/n_evals))
6470
n_folds = 5
6571
n_bootstrap = 1
72+
n_samples = 10000
6673
T = 250
6774
generative_model_params = (epochs=250, latent_dim=8)
6875
using Serialization
6976
results = run_experiments(
7077
experiments;
71-
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T,
78+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T, n_samples=n_samples,
7279
generative_model_params=generative_model_params
7380
)
7481
Serialization.serialize(joinpath(output_path,"results.jls"),results)
@@ -93,11 +100,63 @@ for (data_name, res) in results
93100
end
94101
```
95102

96-
### Table in paper
103+
### Chart in paper
97104

98105
```{julia}
99-
using AlgorithmicRecourseDynamics: kable
100-
kable(results, [50])
106+
using DataFrames, Statistics
107+
model_ = :FluxEnsemble
108+
df = DataFrame()
109+
for (key, val) in results
110+
df_ = deepcopy(val.output)
111+
df_.dataset .= key
112+
df = vcat(df,df_)
113+
end
114+
df = df[df.n .== maximum(df.n),:]
115+
df = df[df.model .== model_,:]
116+
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)
117+
gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
118+
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
119+
df_plot = df_plot[[name in [:mmd, :model_performance] for name in df_plot.name],:]
120+
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
121+
df_plot.name .= [r[:name] == "mmd" ? "$(r[:name])_$(r[:scope])" : r[:name] for r in eachrow(df_plot)]
122+
transform!(df_plot, :dataset => (X -> [x=="cal_housing" ? "California Housing" : x for x in X]) => :dataset)
123+
transform!(df_plot, :dataset => (X -> [x=="credit_default" ? "Credit Default" : x for x in X]) => :dataset)
124+
transform!(df_plot, :dataset => (X -> [x=="gmsc" ? "GMSC" : x for x in X]) => :dataset)
125+
transform!(df_plot, :name => (X -> [x=="mmd_domain" ? "MMD (domain)" : x for x in X]) => :name)
126+
transform!(df_plot, :name => (X -> [x=="mmd_model" ? "MMD (model)" : x for x in X]) => :name)
127+
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
128+
transform!(df_plot, :generator => (X -> [x=="REVISE" ? "Latent" : x for x in X]) => :generator)
129+
130+
ncol = length(unique(df_plot.dataset))
131+
nrow = length(unique(df_plot.name))
132+
133+
using RCall
134+
scale_ = 1.75
135+
R"""
136+
library(ggplot2)
137+
plt <- ggplot($df_plot) +
138+
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
139+
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
140+
facet_grid(
141+
rows = vars(name),
142+
cols = vars(dataset),
143+
scales = "free_y"
144+
) +
145+
labs(y = "Value") +
146+
scale_fill_discrete(name="Generator:") +
147+
scale_colour_discrete(name="Generator:") +
148+
theme(
149+
axis.title.x=element_blank(),
150+
axis.text.x=element_blank(),
151+
axis.ticks.x=element_blank(),
152+
legend.position="bottom"
153+
)
154+
temp_path <- file.path(tempdir(), "plot.png")
155+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8)
156+
"""
157+
158+
img = Images.load(rcopy(R"temp_path"))
159+
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
101160
```
102161

103162

0 commit comments

Comments
 (0)