@@ -24,7 +24,7 @@ models = [
2424]
2525generators = Dict(
2626 :Generic=>GenericGenerator(decision_threshold=0.5),
27- :REVISE =>REVISEGenerator(),
27+ :Latent =>REVISEGenerator(),
2828 :Generic_conservative=>GenericGenerator(decision_threshold=0.9),
2929 :Gravitational=>GravitationalGenerator(),
3030 :ClapROAR=>ClapROARGenerator()
@@ -164,19 +164,170 @@ using Serialization
164164results = Serialization.deserialize(joinpath(output_path,"results_synthetic.jls"))
165165```
166166
167- ### Plots
167+ ``` {julia}
168+ using Images
169+ line_charts = Dict()
170+ errorbar_charts = Dict()
171+ for (data_name, res) in results
172+ plt = plot(res)
173+ Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
174+ line_charts[data_name] = plt
175+ plt = plot(res,maximum(res.output.n))
176+ Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
177+ errorbar_charts[data_name] = plt
178+ end
179+ ```
180+
181+ ### Chart in paper
182+
183+ ``` {julia}
184+ using DataFrames, Statistics
185+ df = results[:overlapping].output
186+ df = df[df.n .== maximum(df.n),:]
187+ gdf = groupby(df, [:generator, :model, :n, :name, :scope])
188+ df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
189+ df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
190+ df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
191+ df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
192+ transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
193+ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
194+ transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
195+ transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
196+ transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
197+
198+ ncol = length(unique(df_plot.model))
199+ nrow = length(unique(df_plot.name))
200+
201+ using RCall
202+ scale_ = 2.2
203+ R"""
204+ library(ggplot2)
205+ plt <- ggplot($df_plot) +
206+ geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
207+ geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
208+ facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
209+ labs(y = "Value") +
210+ scale_fill_discrete(name="Generator:") +
211+ scale_colour_discrete(name="Generator:") +
212+ theme(
213+ axis.title.x=element_blank(),
214+ axis.text.x=element_blank(),
215+ axis.ticks.x=element_blank()
216+ )
217+ temp_path <- file.path(tempdir(), "plot.png")
218+ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
219+ """
220+
221+ img = Images.load(rcopy(R"temp_path"))
222+ Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
223+ ```
224+
225+
226+ ### Latent Space Search
168227
169228``` {julia}
229+ generators = Dict(
230+ :Latent=>GenericGenerator(decision_threshold=0.5),
231+ :Latent_conservative=>GenericGenerator(decision_threshold=0.9),
232+ :Gravitational=>GravitationalGenerator(),
233+ :ClapROAR=>ClapROARGenerator()
234+ )
235+ ```
236+
237+ ``` {julia}
238+ experiments = set_up_experiments(data_sets,models,generators)
239+ ```
240+
241+ ``` {julia}
242+ n_evals = 5
243+ n_rounds = 50
244+ evaluate_every = Int(round(n_rounds/n_evals))
245+ n_folds = 5
246+ n_bootstrap = 1
247+ T = 100
248+ using Serialization
249+ results = run_experiments(
250+ experiments;
251+ save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T
252+ )
253+ Serialization.serialize(joinpath(output_path,"results_synthetic_latent.jls"),results)
254+ ```
255+
256+ ``` {julia}
257+ using Serialization
258+ results = Serialization.deserialize(joinpath(output_path,"results_synthetic_latent.jls"))
259+ ```
260+
261+ ``` {julia}
262+ using Images
170263line_charts = Dict()
171264errorbar_charts = Dict()
172265for (data_name, res) in results
173- line_charts[data_name] = plot(res)
174- errorbar_charts[data_name] = plot(res,50)
266+ plt = plot(res)
267+ Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
268+ line_charts[data_name] = plt
269+ plt = plot(res,maximum(res.output.n))
270+ Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
271+ errorbar_charts[data_name] = plt
175272end
176273```
177274
275+ ### Chart in paper
276+
277+ ``` {julia}
278+ using DataFrames, Statistics
279+ df = results[:overlapping].output
280+ df = df[df.n .== maximum(df.n),:]
281+ gdf = groupby(df, [:generator, :model, :n, :name, :scope])
282+ df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
283+ df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
284+ df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
285+ df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
286+ transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
287+ transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
288+ transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
289+ transform!(df_plot, :generator => (X -> [x=="Latent" ? "Latent (γ=0.5)" : x for x in X]) => :generator)
290+ transform!(df_plot, :generator => (X -> [x=="Latent_conservative" ? "Latent (γ=0.9)" : x for x in X]) => :generator)
291+
292+ ncol = length(unique(df_plot.model))
293+ nrow = length(unique(df_plot.name))
294+
295+ using RCall
296+ scale_ = 2.2
297+ R"""
298+ library(ggplot2)
299+ plt <- ggplot($df_plot) +
300+ geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
301+ geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
302+ facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
303+ labs(y = "Value") +
304+ scale_fill_discrete(name="Generator:") +
305+ scale_colour_discrete(name="Generator:") +
306+ theme(
307+ axis.title.x=element_blank(),
308+ axis.text.x=element_blank(),
309+ axis.ticks.x=element_blank()
310+ )
311+ temp_path <- file.path(tempdir(), "plot.png")
312+ ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
313+ """
314+
315+ img = Images.load(rcopy(R"temp_path"))
316+ Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
317+ ```
318+
178319## Real World
179320
321+ ``` {julia}
322+ generators = Dict(
323+ :Generic=>GenericGenerator(decision_threshold=0.5),
324+ :Latent=>REVISEGenerator(),
325+ :Generic_conservative=>GenericGenerator(decision_threshold=0.9),
326+ :Gravitational=>GravitationalGenerator(),
327+ :ClapROAR=>ClapROARGenerator()
328+ )
329+ ```
330+
180331``` {julia}
181332max_obs = 1000
182333data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
@@ -207,7 +358,7 @@ n_rounds = 50
207358evaluate_every = Int(round(n_rounds/n_evals))
208359n_folds = 5
209360n_bootstrap = 1
210- T = 100
361+ T = 250
211362using Serialization
212363results = run_experiments(
213364 experiments;
@@ -227,11 +378,16 @@ results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls
227378```
228379
229380``` {julia}
381+ using Images
230382line_charts = Dict()
231383errorbar_charts = Dict()
232384for (data_name, res) in results
233- line_charts[data_name] = plot(res)
234- errorbar_charts[data_name] = plot(res,50)
385+ plt = plot(res)
386+ Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
387+ line_charts[data_name] = plt
388+ plt = plot(res,maximum(res.output.n))
389+ Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
390+ errorbar_charts[data_name] = plt
235391end
236392```
237393
@@ -262,7 +418,7 @@ n_rounds = 50
262418evaluate_every = Int(round(n_rounds/n_evals))
263419n_folds = 5
264420n_bootstrap = 1
265- T = 100
421+ T = 200
266422using Serialization
267423results = run_experiments(
268424 experiments;
@@ -282,11 +438,21 @@ Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
282438### Plots
283439
284440``` {julia}
441+ using Serialization
442+ results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
443+ ```
444+
445+ ``` {julia}
446+ using Images
285447line_charts = Dict()
286448errorbar_charts = Dict()
287449for (data_name, res) in results
288- line_charts[data_name] = plot(res)
289- errorbar_charts[data_name] = plot(res,50)
450+ plt = plot(res)
451+ Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
452+ line_charts[data_name] = plt
453+ plt = plot(res,maximum(res.output.n))
454+ Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
455+ errorbar_charts[data_name] = plt
290456end
291457```
292458
0 commit comments