|
128 | 128 | #### Bootstrap |
129 | 129 |
|
130 | 130 | ```{julia} |
131 | | -n_bootstrap = 1000 |
| 131 | +n_bootstrap = 10 |
132 | 132 | df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_synthetic.csv")) |
133 | 133 | ``` |
134 | 134 |
|
|
294 | 294 | #### Bootstrap |
295 | 295 |
|
296 | 296 | ```{julia} |
297 | | -n_bootstrap = 1000 |
| 297 | +n_bootstrap = 10 |
298 | 298 | df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_latent.csv")) |
299 | 299 | ``` |
300 | 300 |
|
@@ -364,10 +364,88 @@ Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img) |
364 | 364 | Images.load(joinpath(www_path,"paper_synthetic_latent_results.png")) |
365 | 365 | ``` |
366 | 366 |
|
| 367 | +## Real World |
| 368 | + |
| 369 | +```{julia} |
| 370 | +generators = Dict( |
| 371 | + :Generic=>GenericGenerator(decision_threshold=0.5), |
| 372 | + :Latent=>REVISEGenerator(), |
| 373 | + :Generic_conservative=>GenericGenerator(decision_threshold=0.9), |
| 374 | + :Gravitational=>GravitationalGenerator(), |
| 375 | + :ClapROAR=>ClapROARGenerator() |
| 376 | +) |
| 377 | +``` |
| 378 | + |
| 379 | +```{julia} |
| 380 | +max_obs = 2500 |
| 381 | +data_path = data_dir("real_world") |
| 382 | +data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs; data_dir=data_path) |
| 383 | +choices = [ |
| 384 | + :cal_housing, |
| 385 | + :credit_default, |
| 386 | + :gmsc, |
| 387 | +] |
| 388 | +data_sets = filter(p -> p[1] in choices, data_sets) |
| 389 | +``` |
| 390 | + |
| 391 | +```{julia} |
| 392 | +using CounterfactualExplanations.DataPreprocessing: unpack |
| 393 | +bs = 500 |
| 394 | +function data_loader(data::CounterfactualData) |
| 395 | + X, y = unpack(data) |
| 396 | + data = Flux.DataLoader((X,y),batchsize=bs) |
| 397 | + return data |
| 398 | +end |
| 399 | +model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.1) |
| 400 | +``` |
| 401 | + |
| 402 | +```{julia} |
| 403 | +experiments = set_up_experiments( |
| 404 | + data_sets,models,generators; |
| 405 | + pre_train_models=100, model_params=model_params, |
| 406 | + data_loader=data_loader |
| 407 | +) |
| 408 | +``` |
| 409 | + |
| 410 | +```{julia} |
| 411 | +n_evals = 5 |
| 412 | +n_rounds = 50 |
| 413 | +evaluate_every = Int(round(n_rounds/n_evals)) |
| 414 | +n_folds = 5 |
| 415 | +n_samples = 10000 |
| 416 | +T = 100 |
| 417 | +generative_model_params = (epochs=250, latent_dim=8) |
| 418 | +results = run_experiments( |
| 419 | + experiments; |
| 420 | + save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T, n_samples=n_samples, |
| 421 | + generative_model_params=generative_model_params |
| 422 | +) |
| 423 | +Serialization.serialize(joinpath(output_path,"results_real_world.jls"),results) |
| 424 | +``` |
| 425 | + |
| 426 | +```{julia} |
| 427 | +using Serialization |
| 428 | +results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls")) |
| 429 | +``` |
| 430 | + |
| 431 | +```{julia} |
| 432 | +using Images |
| 433 | +line_charts = Dict() |
| 434 | +errorbar_charts = Dict() |
| 435 | +for (data_name, res) in results |
| 436 | + plt = plot(res) |
| 437 | + Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt) |
| 438 | + line_charts[data_name] = plt |
| 439 | + plt = plot(res,maximum(res.output.n)) |
| 440 | + Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt) |
| 441 | + errorbar_charts[data_name] = plt |
| 442 | +end |
| 443 | +``` |
| 444 | + |
367 | 445 | #### Bootstrap |
368 | 446 |
|
369 | 447 | ```{julia} |
370 | | -n_bootstrap = 1000 |
| 448 | +n_bootstrap = 10 |
371 | 449 | df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_real_world.csv")) |
372 | 450 | ``` |
373 | 451 |
|
|
0 commit comments