@@ -194,28 +194,41 @@ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X
194194transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
195195transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
196196transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
197+ transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
198+ transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
199+ transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)
197200
198201ncol = length(unique(df_plot.model))
199202nrow = length(unique(df_plot.name))
200203
201204using RCall
202- scale_ = 2.2
205+ scale_ = 2.0
203206R"""
207+ library(data.table)
208+ df_plot <- data.table($df_plot)
209+ model_order <- c("Linear", "MLP", "Deep Ensemble")
210+ df_plot[,model:=factor(model, levels=model_order)]
204211library(ggplot2)
# Plot the R-side data.table `df_plot` (not a fresh copy of the Julia frame) so the `model` factor levels set above decide the facet column order.
205212plt <- ggplot(df_plot) +
206213 geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
207214 geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
208- facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
215+ facet_grid(
216+ rows = vars(name),
217+ cols = vars(model),
218+ scales = "free_y"
219+ ) +
209220 labs(y = "Value") +
210221 scale_fill_discrete(name="Generator:") +
211222 scale_colour_discrete(name="Generator:") +
212223 theme(
213224 axis.title.x=element_blank(),
214225 axis.text.x=element_blank(),
215- axis.ticks.x=element_blank()
216- )
226+ axis.ticks.x=element_blank(),
227+ legend.position="bottom"
228+ ) +
229+ guides(fill=guide_legend(ncol=3))
217230temp_path <- file.path(tempdir(), "plot.png")
218- ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75 )
231+ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8 )
219232"""
220233
221234img = Images.load(rcopy(R"temp_path"))
@@ -288,28 +301,41 @@ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X
288301transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
289302transform!(df_plot, :generator => (X -> [x=="Latent" ? "Latent (γ=0.5)" : x for x in X]) => :generator)
290303transform!(df_plot, :generator => (X -> [x=="Latent_conservative" ? "Latent (γ=0.9)" : x for x in X]) => :generator)
304+ transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
305+ transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
306+ transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)
291307
292308ncol = length(unique(df_plot.model))
293309nrow = length(unique(df_plot.name))
294310
295311using RCall
296- scale_ = 2.2
312+ scale_ = 1.9
297313R"""
314+ library(data.table)
315+ df_plot <- data.table($df_plot)
316+ model_order <- c("Linear", "MLP", "Deep Ensemble")
317+ df_plot[,model:=factor(model, levels=model_order)]
298318library(ggplot2)
# Plot the R-side data.table `df_plot` (not a fresh copy of the Julia frame) so the `model` factor levels set above decide the facet column order.
299319plt <- ggplot(df_plot) +
300320 geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
301321 geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
302- facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
322+ facet_grid(
323+ rows = vars(name),
324+ cols = vars(model),
325+ scales = "free_y"
326+ ) +
303327 labs(y = "Value") +
304328 scale_fill_discrete(name="Generator:") +
305329 scale_colour_discrete(name="Generator:") +
306330 theme(
307331 axis.title.x=element_blank(),
308332 axis.text.x=element_blank(),
309- axis.ticks.x=element_blank()
310- )
333+ axis.ticks.x=element_blank(),
334+ legend.position="bottom"
335+ ) +
336+ guides(fill=guide_legend(ncol=4))
311337temp_path <- file.path(tempdir(), "plot.png")
312- ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75 )
338+ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8 )
313339"""
314340
315341img = Images.load(rcopy(R"temp_path"))
@@ -329,7 +355,7 @@ generators = Dict(
329355```
330356
331357``` {julia}
332- max_obs = 1000
358+ max_obs = 2500
333359data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
334360```
335361
@@ -358,6 +384,7 @@ n_rounds = 50
358384evaluate_every = Int(round(n_rounds/n_evals))
359385n_folds = 5
360386n_bootstrap = 1
387+ n_samples = 10000
361388T = 250
362389using Serialization
363390results = run_experiments(
@@ -391,68 +418,62 @@ for (data_name, res) in results
391418end
392419```
393420
394- ### Latent Space Search
395-
396- ``` {julia}
397- generators = Dict(
398- :REVISE=>GenericGenerator(decision_threshold=0.5),
399- :REVISE_conservative=>GenericGenerator(decision_threshold=0.9),
400- :Gravitational=>GravitationalGenerator(),
401- :ClapROAR=>ClapROARGenerator()
402- )
403- ```
404-
405-
406- ``` {julia}
407- generative_model_params = (epochs=250, latent_dim=8)
408- experiments = set_up_experiments(
409- data_sets,models,generators;
410- pre_train_models=100, model_params=model_params,
411- data_loader=data_loader
412- )
413- ```
421+ ### Chart in paper
414422
415423``` {julia}
416- n_evals = 5
417- n_rounds = 50
418- evaluate_every = Int(round(n_rounds/n_evals))
419- n_folds = 5
420- n_bootstrap = 1
421- T = 200
422- using Serialization
423- results = run_experiments(
424- experiments;
425- save_path=output_path,
426- save_name_suffix="latent",
427- evaluate_every=evaluate_every,
428- n_rounds=n_rounds,
429- n_folds=n_folds,
430- n_bootstrap=n_bootstrap,
431- T=T,
432- latent_space=true,
433- generative_model_params=generative_model_params
434- )
435- Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
436- ```
424+ using DataFrames, Statistics
425+ model_ = :FluxEnsemble
426+ df = DataFrame()
427+ for (key, val) in results
428+ df_ = deepcopy(val.output)
429+ df_.dataset .= key
430+ df = vcat(df,df_)
431+ end
432+ df = df[df.n .== maximum(df.n),:]
433+ df = df[df.model .== model_,:]
434+ filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)
435+ gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
436+ df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
437+ df_plot = df_plot[[name in [:mmd, :model_performance] for name in df_plot.name],:]
438+ df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.!=:model),:]
439+ df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
440+ transform!(df_plot, :dataset => (X -> [x=="cal_housing" ? "California Housing" : x for x in X]) => :dataset)
441+ transform!(df_plot, :dataset => (X -> [x=="credit_default" ? "Credit Default" : x for x in X]) => :dataset)
442+ transform!(df_plot, :dataset => (X -> [x=="gmsc" ? "GMSC" : x for x in X]) => :dataset)
443+ transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (model)" : x for x in X]) => :name)
444+ transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
445+ transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
446+ transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
437447
438- ### Plots
448+ ncol = length(unique(df_plot.dataset))
449+ nrow = length(unique(df_plot.name))
439450
440- ``` {julia}
441- using Serialization
442- results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
443- ```
451+ using RCall
452+ scale_ = 2.0
453+ R"""
454+ library(ggplot2)
455+ plt <- ggplot($df_plot) +
456+ geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
457+ geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
458+ facet_grid(
459+ rows = vars(name),
460+ cols = vars(dataset),
461+ scales = "free_y"
462+ ) +
463+ labs(y = "Value") +
464+ scale_fill_discrete(name="Generator:") +
465+ scale_colour_discrete(name="Generator:") +
466+ theme(
467+ axis.title.x=element_blank(),
468+ axis.text.x=element_blank(),
469+ axis.ticks.x=element_blank(),
470+ legend.position="bottom"
471+ ) +
472+ guides(fill=guide_legend(ncol=3))
473+ temp_path <- file.path(tempdir(), "plot.png")
474+ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85)
475+ """
444476
445- ``` {julia}
446- using Images
447- line_charts = Dict()
448- errorbar_charts = Dict()
449- for (data_name, res) in results
450- plt = plot(res)
451- Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
452- line_charts[data_name] = plt
453- plt = plot(res,maximum(res.output.n))
454- Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
455- errorbar_charts[data_name] = plt
456- end
477+ img = Images.load(rcopy(R"temp_path"))
478+ Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
457479```
458-
0 commit comments