Skip to content

Commit abea5f6

Browse files
committed
updated kable function to produce latex table
1 parent 03e77c4 commit abea5f6

8 files changed

Lines changed: 279 additions & 172 deletions

File tree

dev/notebooks/mitigation_strategies.qmd

Lines changed: 176 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ models = [
2424
]
2525
generators = Dict(
2626
:Generic=>GenericGenerator(decision_threshold=0.5),
27-
:REVISE=>REVISEGenerator(),
27+
:Latent=>REVISEGenerator(),
2828
:Generic_conservative=>GenericGenerator(decision_threshold=0.9),
2929
:Gravitational=>GravitationalGenerator(),
3030
:ClapROAR=>ClapROARGenerator()
@@ -164,19 +164,170 @@ using Serialization
164164
results = Serialization.deserialize(joinpath(output_path,"results_synthetic.jls"))
165165
```
166166

167-
### Plots
167+
```{julia}
168+
using Images
169+
line_charts = Dict()
170+
errorbar_charts = Dict()
171+
for (data_name, res) in results
172+
plt = plot(res)
173+
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
174+
line_charts[data_name] = plt
175+
plt = plot(res,maximum(res.output.n))
176+
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
177+
errorbar_charts[data_name] = plt
178+
end
179+
```
180+
181+
### Chart in paper
182+
183+
```{julia}
184+
using DataFrames, Statistics
185+
df = results[:overlapping].output
186+
df = df[df.n .== maximum(df.n),:]
187+
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
188+
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
189+
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
190+
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
191+
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
192+
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
193+
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
194+
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
195+
transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
196+
transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
197+
198+
ncol = length(unique(df_plot.model))
199+
nrow = length(unique(df_plot.name))
200+
201+
using RCall
202+
scale_ = 2.2
203+
R"""
204+
library(ggplot2)
205+
plt <- ggplot($df_plot) +
206+
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
207+
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
208+
facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
209+
labs(y = "Value") +
210+
scale_fill_discrete(name="Generator:") +
211+
scale_colour_discrete(name="Generator:") +
212+
theme(
213+
axis.title.x=element_blank(),
214+
axis.text.x=element_blank(),
215+
axis.ticks.x=element_blank()
216+
)
217+
temp_path <- file.path(tempdir(), "plot.png")
218+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
219+
"""
220+
221+
img = Images.load(rcopy(R"temp_path"))
222+
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
223+
```
224+
225+
226+
### Latent Space Search
168227

169228
```{julia}
229+
generators = Dict(
230+
:Latent=>GenericGenerator(decision_threshold=0.5),
231+
:Latent_conservative=>GenericGenerator(decision_threshold=0.9),
232+
:Gravitational=>GravitationalGenerator(),
233+
:ClapROAR=>ClapROARGenerator()
234+
)
235+
```
236+
237+
```{julia}
238+
experiments = set_up_experiments(data_sets,models,generators)
239+
```
240+
241+
```{julia}
242+
n_evals = 5
243+
n_rounds = 50
244+
evaluate_every = Int(round(n_rounds/n_evals))
245+
n_folds = 5
246+
n_bootstrap = 1
247+
T = 100
248+
using Serialization
249+
results = run_experiments(
250+
experiments;
251+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T
252+
)
253+
Serialization.serialize(joinpath(output_path,"results_synthetic_latent.jls"),results)
254+
```
255+
256+
```{julia}
257+
using Serialization
258+
results = Serialization.deserialize(joinpath(output_path,"results_synthetic_latent.jls"))
259+
```
260+
261+
```{julia}
262+
using Images
170263
line_charts = Dict()
171264
errorbar_charts = Dict()
172265
for (data_name, res) in results
173-
line_charts[data_name] = plot(res)
174-
errorbar_charts[data_name] = plot(res,50)
266+
plt = plot(res)
267+
Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
268+
line_charts[data_name] = plt
269+
plt = plot(res,maximum(res.output.n))
270+
Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
271+
errorbar_charts[data_name] = plt
175272
end
176273
```
177274

275+
### Chart in paper
276+
277+
```{julia}
278+
using DataFrames, Statistics
279+
df = results[:overlapping].output
280+
df = df[df.n .== maximum(df.n),:]
281+
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
282+
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
283+
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
284+
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
285+
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
286+
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
287+
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
288+
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
289+
transform!(df_plot, :generator => (X -> [x=="Latent" ? "Latent (γ=0.5)" : x for x in X]) => :generator)
290+
transform!(df_plot, :generator => (X -> [x=="Latent_conservative" ? "Latent (γ=0.9)" : x for x in X]) => :generator)
291+
292+
ncol = length(unique(df_plot.model))
293+
nrow = length(unique(df_plot.name))
294+
295+
using RCall
296+
scale_ = 2.2
297+
R"""
298+
library(ggplot2)
299+
plt <- ggplot($df_plot) +
300+
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
301+
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
302+
facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
303+
labs(y = "Value") +
304+
scale_fill_discrete(name="Generator:") +
305+
scale_colour_discrete(name="Generator:") +
306+
theme(
307+
axis.title.x=element_blank(),
308+
axis.text.x=element_blank(),
309+
axis.ticks.x=element_blank()
310+
)
311+
temp_path <- file.path(tempdir(), "plot.png")
312+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.75)
313+
"""
314+
315+
img = Images.load(rcopy(R"temp_path"))
316+
Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
317+
```
318+
178319
## Real World
179320

321+
```{julia}
322+
generators = Dict(
323+
:Generic=>GenericGenerator(decision_threshold=0.5),
324+
:Latent=>REVISEGenerator(),
325+
:Generic_conservative=>GenericGenerator(decision_threshold=0.9),
326+
:Gravitational=>GravitationalGenerator(),
327+
:ClapROAR=>ClapROARGenerator()
328+
)
329+
```
330+
180331
```{julia}
181332
max_obs = 1000
182333
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
@@ -207,7 +358,7 @@ n_rounds = 50
207358
evaluate_every = Int(round(n_rounds/n_evals))
208359
n_folds = 5
209360
n_bootstrap = 1
210-
T = 100
361+
T = 250
211362
using Serialization
212363
results = run_experiments(
213364
experiments;
@@ -227,11 +378,16 @@ results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls
227378
```
228379

229380
```{julia}
381+
using Images
230382
line_charts = Dict()
231383
errorbar_charts = Dict()
232384
for (data_name, res) in results
233-
line_charts[data_name] = plot(res)
234-
errorbar_charts[data_name] = plot(res,50)
385+
plt = plot(res)
386+
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
387+
line_charts[data_name] = plt
388+
plt = plot(res,maximum(res.output.n))
389+
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
390+
errorbar_charts[data_name] = plt
235391
end
236392
```
237393

@@ -262,7 +418,7 @@ n_rounds = 50
262418
evaluate_every = Int(round(n_rounds/n_evals))
263419
n_folds = 5
264420
n_bootstrap = 1
265-
T = 100
421+
T = 200
266422
using Serialization
267423
results = run_experiments(
268424
experiments;
@@ -282,11 +438,21 @@ Serialization.serialize(joinpath(output_path,"results_latent.jls"),results)
282438
### Plots
283439

284440
```{julia}
441+
using Serialization
442+
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
443+
```
444+
445+
```{julia}
446+
using Images
285447
line_charts = Dict()
286448
errorbar_charts = Dict()
287449
for (data_name, res) in results
288-
line_charts[data_name] = plot(res)
289-
errorbar_charts[data_name] = plot(res,50)
450+
plt = plot(res)
451+
Images.save(joinpath(www_path, "line_chart_latent_$(data_name).png"), plt)
452+
line_charts[data_name] = plt
453+
plt = plot(res,maximum(res.output.n))
454+
Images.save(joinpath(www_path, "errorbar_chart_latent_$(data_name).png"), plt)
455+
errorbar_charts[data_name] = plt
290456
end
291457
```
292458

dev/notebooks/real_world.qmd

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,15 @@ experiments = set_up_experiments(
5555
)
5656
```
5757

58-
## Strict convergence
58+
## Running Experiment
5959

6060
```{julia}
6161
n_evals = 5
6262
n_rounds = 50
6363
evaluate_every = Int(round(n_rounds/n_evals))
6464
n_folds = 5
6565
n_bootstrap = 1
66-
T = 100
66+
T = 250
6767
generative_model_params = (epochs=250, latent_dim=8)
6868
using Serialization
6969
results = run_experiments(
@@ -80,13 +80,26 @@ results = Serialization.deserialize(joinpath(output_path,"results.jls"))
8080
```
8181

8282
```{julia}
83+
using Images
8384
line_charts = Dict()
8485
errorbar_charts = Dict()
8586
for (data_name, res) in results
86-
line_charts[data_name] = plot(res)
87-
errorbar_charts[data_name] = plot(res,50)
87+
plt = plot(res)
88+
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
89+
line_charts[data_name] = plt
90+
plt = plot(res,maximum(res.output.n))
91+
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
92+
errorbar_charts[data_name] = plt
8893
end
8994
```
9095

96+
### Table in paper
97+
98+
```{julia}
99+
using AlgorithmicRecourseDynamics: kable
100+
kable(results, [50])
101+
```
102+
103+
91104

92105

dev/notebooks/results.qmd

Whitespace-only changes.

dev/notebooks/synthetic.qmd

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,9 @@ T = 100
9494
using Serialization
9595
results = run_experiments(
9696
experiments;
97-
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T,
98-
convergence=:strict, save_name_suffix="strict"
97+
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, n_bootstrap=n_bootstrap, T=T
9998
)
100-
Serialization.serialize(joinpath(output_path,"results_strict.jls"),results)
99+
Serialization.serialize(joinpath(output_path,"results.jls"),results)
101100
```
102101

103102
```{julia}
@@ -155,26 +154,71 @@ end
155154
plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
156155
for (name, plts) in plot_dict
157156
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(choices)*300,length(models)*300))
158-
savefig(plt, joinpath(www_path,"models_train_after_$(name)_strict.png"))
157+
savefig(plt, joinpath(www_path,"models_train_after_$(name).png"))
159158
end
160159
```
161160

162161
```{julia}
163162
using Serialization
164-
results = Serialization.deserialize(joinpath(output_path,"results_strict.jls"))
163+
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
165164
```
166165

167166
### Plots
168167

169168
```{julia}
169+
using Images
170170
line_charts = Dict()
171171
errorbar_charts = Dict()
172172
for (data_name, res) in results
173173
plt = plot(res)
174-
savefig(plt, joinpath(www_path, "line_chart_$(data_name).png"))
174+
Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
175175
line_charts[data_name] = plt
176176
plt = plot(res,maximum(res.output.n))
177-
savefig(plt, joinpath(www_path, "errorbar_chart_$(data_name).png"))
177+
Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
178178
errorbar_charts[data_name] = plt
179179
end
180180
```
181+
182+
### Chart in paper
183+
184+
```{julia}
185+
using DataFrames, Statistics
186+
df = results[:overlapping].output
187+
df = df[df.n .== maximum(df.n),:]
188+
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
189+
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
190+
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
191+
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
192+
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
193+
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
194+
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
195+
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
196+
197+
ncol = length(unique(df_plot.model))
198+
nrow = length(unique(df_plot.name))
199+
200+
using RCall
201+
scale_ = 2
202+
R"""
203+
library(ggplot2)
204+
plt <- ggplot($df_plot) +
205+
geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
206+
geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
207+
facet_wrap(name ~ model, scale="free_y", ncol=$ncol) +
208+
labs(y = "Value") +
209+
scale_fill_discrete(name="Generator:") +
210+
scale_colour_discrete(name="Generator:") +
211+
theme(
212+
axis.title.x=element_blank(),
213+
axis.text.x=element_blank(),
214+
axis.ticks.x=element_blank()
215+
)
216+
temp_path <- file.path(tempdir(), "plot.png")
217+
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8)
218+
"""
219+
220+
img = Images.load(rcopy(R"temp_path"))
221+
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
222+
```
223+
224+

0 commit comments

Comments
 (0)