Skip to content

Commit f904091

Browse files
committed
working on models for real-world data
1 parent b9e32e3 commit f904091

9 files changed

Lines changed: 150 additions & 61 deletions

File tree

docs/src/paper/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
77
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
88
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
99
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
10+
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
1011
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1112
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1213
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"

docs/src/paper/data_preprocessing/_real_world_data.qmd

Lines changed: 103 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,70 +27,103 @@ if not os.path.isdir(os.path.join(data_path,"raw")):
2727
df.to_csv(os.path.join(data_path,"raw/cal_housing.csv"), index=False)
2828
```
2929

30-
Loading the data into Julia session:
30+
Loading the data into a Julia session.
3131

3232
```{julia}
3333
df = CSV.read(joinpath(data_path, "raw/cal_housing.csv"), DataFrame)
34-
# Features:
35-
X = Matrix(df[:,Not(:target)])
36-
dt = StatsBase.fit(ZScoreTransform, X, dims=1)
37-
StatsBase.transform!(dt, X)
34+
# # Features:
35+
# X = Matrix(df[:,Not(:target)])
36+
# dt = StatsBase.fit(ZScoreTransform, X, dims=1)
37+
# StatsBase.transform!(dt, X)
38+
# df = DataFrame(X,:auto)
3839
# Target:
3940
y = df.target
40-
y = Float64.(y .>= median(y)); # binary target
41+
y = Float64.(y .>= median(y)); # binary target (positive outcome)
4142
# Data:
42-
df = DataFrame(X,:auto)
4343
df.target = y
4444
```
4545

46+
Random undersampling to balance the data:
47+
4648
```{julia}
47-
using MLUtils: undersample
48-
# Make DataFrames.jl work
49-
MLUtils.getobs(data::DataFrame, i) = data[i,:]
50-
MLUtils.numobs(data::DataFrame) = nrow(data)
5149
df_balanced = getobs(undersample(df, df.target;shuffle=true))[1]
5250
```
5351

52+
All features are continuous:
53+
54+
```{julia}
55+
schema(df_balanced)
56+
```
57+
58+
Turning the data into `CounterfactualData`:
59+
5460
```{julia}
55-
CSV.write(joinpath(data_path, "cal_housing.csv"), df_balanced)
61+
X = Matrix(df_balanced[:,Not(:target)])
62+
X = permutedims(X)
63+
y = permutedims(df_balanced.target)
64+
data = CounterfactualData(X,y)
65+
```
66+
67+
Saving the data:
68+
69+
```{julia}
70+
CSV.write(joinpath(data_path, "cal_housing.csv"), df_balanced) # balanced data with binary target, as plain-text CSV
71+
Serialization.serialize(joinpath(data_path,"cal_housing.jls"), data) # CounterfactualData
5672
```
5773

5874

5975
## Give Me Some Credit
6076

77+
Loading and basic preprocessing:
78+
6179
```{julia}
6280
df = CSV.read(joinpath(data_path, "raw/cs-training.csv"), DataFrame)
6381
select!(df, Not([:Column1]))
6482
rename!(df, :SeriousDlqin2yrs => :target)
6583
mapcols!(x -> [ifelse(x_=="NA", missing, x_) for x_ in x], df)
6684
dropmissing!(df)
6785
mapcols!(x -> eltype(x) <: AbstractString ? parse.(Int, x) : x, df)
68-
# Features:
69-
X = Matrix(df[:,Not(:target)])
70-
dt = StatsBase.fit(ZScoreTransform, X, dims=1)
71-
StatsBase.transform!(dt, X)
86+
# # Features:
87+
# X = Matrix(df[:,Not(:target)])
88+
# dt = StatsBase.fit(ZScoreTransform, X, dims=1)
89+
# StatsBase.transform!(dt, X)
90+
# df = DataFrame(X,:auto)
7291
# Target:
73-
y = df.target
74-
# Data:
75-
df = DataFrame(X,:auto)
76-
df.target = y
92+
df.target .= map(y -> y == 0 ? 1 : 0, df.target) # positive outcome = no delinquency
7793
```
7894

95+
Balancing:
96+
7997
```{julia}
80-
using MLUtils
81-
using MLUtils: undersample
82-
# Make DataFrames.jl work
83-
MLUtils.getobs(data::DataFrame, i) = data[i,:]
84-
MLUtils.numobs(data::DataFrame) = nrow(data)
8598
df_balanced = getobs(undersample(df, df.target;shuffle=true))[1]
8699
```
87100

101+
All features are continuous:
102+
88103
```{julia}
89-
CSV.write(joinpath(data_path, "gmsc.csv"), df_balanced)
104+
schema(df_balanced)
105+
```
106+
107+
Turning the data into `CounterfactualData`:
108+
109+
```{julia}
110+
X = Matrix(df_balanced[:,Not(:target)])
111+
X = permutedims(X)
112+
y = permutedims(df_balanced.target)
113+
data = CounterfactualData(X,y)
114+
```
115+
116+
Saving:
117+
118+
```{julia}
119+
CSV.write(joinpath(data_path, "gmsc.csv"), df_balanced) # balanced data with binary target, as plain-text CSV
120+
Serialization.serialize(joinpath(data_path,"gmsc.jls"), data) # CounterfactualData
90121
```
91122

92123
## UCI Credit Card Default
93124

125+
Loading and basic preprocessing:
126+
94127
```{julia}
95128
df = CSV.read(joinpath(data_path, "raw/UCI_Credit_Card.csv"), DataFrame)
96129
select!(df, Not([:ID]))
@@ -100,17 +133,57 @@ df.SEX = categorical(df.SEX)
100133
df.EDUCATION = categorical(df.EDUCATION)
101134
df.MARRIAGE = categorical(df.MARRIAGE)
102135
mapcols!(x -> eltype(x) <: AbstractString ? parse.(Int, x) : x, df)
136+
df.target .= map(y -> y == 0 ? 1 : 0, df.target) # positive outcome = no default
103137
```
104138

139+
Balancing:
140+
105141
```{julia}
106-
# Make DataFrames.jl work
107-
MLUtils.getobs(data::DataFrame, i) = data[i,:]
108-
MLUtils.numobs(data::DataFrame) = nrow(data)
109142
df_balanced = getobs(undersample(df, df.target;shuffle=true))[1]
110143
```
111144

145+
**Not** all features are continuous:
146+
147+
```{julia}
148+
schema(df_balanced)
149+
```
150+
151+
One-hot encoding:
152+
153+
```{julia}
154+
hot = OneHotEncoder()
155+
mach = MLJBase.fit!(machine(hot, df_balanced))
156+
df_balanced = MLJBase.transform(mach, df_balanced)
157+
schema(df_balanced)
158+
```
159+
160+
Indices of the one-hot encoded columns, grouped by categorical feature:
161+
162+
```{julia}
163+
features_categorical = [
164+
[2,3],
165+
collect(4:10),
166+
collect(11:14)
167+
]
168+
```
169+
170+
Preparing for use with `CounterfactualExplanations.jl`:
171+
172+
```{julia}
173+
X = Matrix(df_balanced[:,Not(:target)])
174+
X = permutedims(X)
175+
y = permutedims(df_balanced.target)
176+
data = CounterfactualData(
177+
X, y;
178+
features_categorical = features_categorical
179+
)
180+
```
181+
182+
Saving:
183+
112184
```{julia}
113-
CSV.write(joinpath(data_path, "credit_default.csv"), df_balanced)
185+
CSV.write(joinpath(data_path, "credit_default.csv"), df_balanced) # balanced data with binary target, as plain-text CSV
186+
Serialization.serialize(joinpath(data_path,"credit_default.jls"), data) # CounterfactualData
114187
```
115188

116189

docs/src/paper/experiments/_real_world.qmd

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
include("docs/src/paper/setup.jl")
77
eval(setup)
88
output_path = output_dir("real_world")
9-
www_path = www_dir("real_world");
9+
www_path = www_dir("real_world")
10+
data_path = data_dir("real_world")
1011
```
1112

1213
```{julia}
1314
max_obs = 2500
14-
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs)
15+
data_sets = AlgorithmicRecourseDynamics.Data.load_real_world(max_obs; data_dir=data_path)
1516
choices = [
1617
:cal_housing,
1718
:credit_default,
@@ -22,13 +23,13 @@ data_sets = filter(p -> p[1] in choices, data_sets)
2223

2324
```{julia}
2425
using CounterfactualExplanations.DataPreprocessing: unpack
25-
bs = 50
26+
bs = 500
2627
function data_loader(data::CounterfactualData)
2728
X, y = unpack(data)
2829
data = Flux.DataLoader((X,y),batchsize=bs)
2930
return data
3031
end
31-
model_params = (batch_norm=false,n_hidden=32,n_layers=3,dropout=true,p_dropout=0.25)
32+
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.5)
3233
```
3334

3435

@@ -62,9 +63,8 @@ n_rounds = 50
6263
evaluate_every = Int(round(n_rounds/n_evals))
6364
n_folds = 5
6465
n_samples = 10000
65-
T = 250
66+
T = 100
6667
generative_model_params = (epochs=250, latent_dim=8)
67-
using Serialization
6868
results = run_experiments(
6969
experiments;
7070
save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T, n_samples=n_samples,
@@ -76,7 +76,6 @@ Serialization.serialize(joinpath(output_path,"results.jls"),results)
7676
#### Plots
7777

7878
```{julia}
79-
using Serialization
8079
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
8180
```
8281

docs/src/paper/setup.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ setup = quote
1717
using LaplaceRedux
1818
using Markdown
1919
using MLJBase
20+
using MLJModels: OneHotEncoder
2021
using MLUtils
22+
using MLUtils: undersample
2123
using Plots
2224
using Random
2325
using RCall
@@ -29,4 +31,8 @@ setup = quote
2931
theme(:wong)
3032
include("docs/src/utils.jl") # some helper functions
3133

34+
# Let MLUtils treat each row of a DataFrame as one observation
35+
MLUtils.getobs(data::DataFrame, i) = data[i, :]
36+
MLUtils.numobs(data::DataFrame) = nrow(data)
37+
3238
end

src/base.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ function set_up_experiment(
175175

176176
# Pretrain:
177177
if !isnothing(pre_train_models)
178-
map!(model -> Models.train(model, data_train; n_epochs=pre_train_models, kwargs...), values(models))
178+
for (key, model) in models
179+
@info "Training $key"
180+
Models.train(model, data_train; n_epochs=pre_train_models, kwargs...)
181+
end
179182
end
180183

181184
experiment = Experiment(data_train, data_test, target, models, deepcopy(generators), num_counterfactuals)
@@ -218,7 +221,11 @@ function set_up_experiments(
218221
kwargs...
219222
)
220223

221-
experiments = Dict(key => set_up_single(data) for (key, data) in catalogue)
224+
experiments = Dict{Symbol, Experiment}()
225+
for (key, data) in catalogue
226+
@info "Setting up $(key)"
227+
experiments[key] = set_up_single(data)
228+
end
222229

223230
return experiments
224231
end

src/data/functions.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,41 @@ using LazyArtifacts
22
using CounterfactualExplanations.DataPreprocessing: CounterfactualData
33
using CSV
44
using DataFrames
5+
using Serialization
56
using StatsBase
67

7-
function load_synthetic(max_obs::Union{Nothing, Int}=nothing)
8-
data_dir = joinpath(artifact"data","data/synthetic")
8+
function load_synthetic(max_obs::Union{Nothing,Int}=nothing)
9+
data_dir = joinpath(artifact"data", "data/synthetic")
910
files = readdir(data_dir)
10-
files = files[contains.(files,".csv")]
11+
files = files[contains.(files, ".csv")]
1112
data = map(files) do file
1213
df = CSV.read(joinpath(data_dir, file), DataFrame)
13-
X = convert(Matrix, hcat(df.x1,df.x2)')
14+
X = convert(Matrix, hcat(df.x1, df.x2)')
1415
y = convert(Matrix, df.target')
15-
data = CounterfactualData(X,y)
16+
data = CounterfactualData(X, y)
1617
if !isnothing(max_obs)
1718
n_classes = length(unique(y))
18-
data = undersample(data, Int(round(max_obs/n_classes)))
19+
data = undersample(data, Int(round(max_obs / n_classes)))
1920
end
2021
(Symbol(replace(file, ".csv" => "")) => data)
2122
end
2223
data = Dict(data...)
2324
return data
2425
end
2526

26-
function load_real_world(max_obs::Union{Nothing, Int}=nothing)
27-
data_dir = joinpath(artifact"data","data/real_world")
27+
function load_real_world(max_obs::Union{Nothing,Int}=nothing; data_dir::Union{Nothing, String}=nothing)
28+
if isnothing(data_dir)
29+
data_dir = joinpath(artifact"data", "data/real_world")
30+
end
2831
files = readdir(data_dir)
29-
files = files[contains.(files,".csv")]
32+
files = files[contains.(files, ".jls")]
3033
data = map(files) do file
31-
df = CSV.read(joinpath(data_dir, file), DataFrame)
32-
X = Matrix(df[:,Not(:target)])
33-
X = permutedims(X)
34-
y = convert(Matrix, df.target')
35-
data = CounterfactualData(X,y)
34+
counterfactual_data = Serialization.deserialize(joinpath(data_dir, file))
3635
if !isnothing(max_obs)
37-
n_classes = length(unique(y))
38-
data = undersample(data, Int(round(max_obs/n_classes)))
36+
n_classes = length(unique(counterfactual_data.y))
37+
counterfactual_data = undersample(counterfactual_data, Int(round(max_obs / n_classes)))
3938
end
40-
(Symbol(replace(file, ".csv" => "")) => data)
39+
(Symbol(replace(file, ".jls" => "")) => counterfactual_data)
4140
end
4241
data = Dict(data...)
4342
return data

src/data/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ function undersample(data::CounterfactualData, n_per_class::Int)
3131
classes_ = sort(unique(y_cls))
3232

3333
idx = sort(reduce(vcat,[sample(findall(vec(y_cls.==cls)), n_per_class,replace=false) for cls in classes_]))
34-
data = CounterfactualData(X[:,idx], y[:,idx])
34+
data.X = X[:, idx]
35+
data.y = y[:,idx]
3536

3637
return data
3738

src/experiments/functions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
170170

171171
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
172172
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 1, indices_)))
173+
println(y′)
173174

174175
X[:, chosen_individuals] = X′
175176
y[:, chosen_individuals] = y′

0 commit comments

Comments
 (0)