Skip to content

Commit 0057490

Browse files
committed
I just don't know why this isn't working for categorical data
1 parent f904091 commit 0057490

5 files changed

Lines changed: 60 additions & 30 deletions

File tree

docs/src/paper/data_preprocessing/_real_world_data.qmd

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ Loading the data into Julia session.
3131

3232
```{julia}
3333
df = CSV.read(joinpath(data_path, "raw/cal_housing.csv"), DataFrame)
34-
# # Features:
35-
# X = Matrix(df[:,Not(:target)])
36-
# dt = StatsBase.fit(ZScoreTransform, X, dims=1)
37-
# StatsBase.transform!(dt, X)
38-
# df = DataFrame(X,:auto)
3934
# Target:
4035
y = df.target
4136
y = Float64.(y .>= median(y)); # binary target (positive outcome)
@@ -55,10 +50,19 @@ All features are continuous:
5550
schema(df_balanced)
5651
```
5752

53+
Feature transformation:
54+
55+
```{julia}
56+
transformer = Standardizer(count=true)
57+
mach = MLJBase.fit!(machine(transformer, df_balanced[:,Not(:target)]))
58+
X = MLJBase.transform(mach, df_balanced[:,Not(:target)])
59+
schema(X)
60+
```
61+
5862
Turning the data into `CounterfactualData`:
5963

6064
```{julia}
61-
X = Matrix(df_balanced[:,Not(:target)])
65+
X = Matrix(X)
6266
X = permutedims(X)
6367
y = permutedims(df_balanced.target)
6468
data = CounterfactualData(X,y)
@@ -83,12 +87,6 @@ rename!(df, :SeriousDlqin2yrs => :target)
8387
mapcols!(x -> [ifelse(x_=="NA", missing, x_) for x_ in x], df)
8488
dropmissing!(df)
8589
mapcols!(x -> eltype(x) <: AbstractString ? parse.(Int, x) : x, df)
86-
# # Features:
87-
# X = Matrix(df[:,Not(:target)])
88-
# dt = StatsBase.fit(ZScoreTransform, X, dims=1)
89-
# StatsBase.transform!(dt, X)
90-
# df = DataFrame(X,:auto)
91-
# Target:
9290
df.target .= map(y -> y == 0 ? 1 : 0, df.target) # positive outcome = no delinquency
9391
```
9492

@@ -104,10 +102,19 @@ All features are continuous:
104102
schema(df_balanced)
105103
```
106104

105+
Feature transformation:
106+
107+
```{julia}
108+
transformer = Standardizer(count=true)
109+
mach = MLJBase.fit!(machine(transformer, df_balanced[:,Not(:target)]))
110+
X = MLJBase.transform(mach, df_balanced[:,Not(:target)])
111+
schema(X)
112+
```
113+
107114
Turning the data into `CounterfactualData`:
108115

109116
```{julia}
110-
X = Matrix(df_balanced[:,Not(:target)])
117+
X = Matrix(X)
111118
X = permutedims(X)
112119
y = permutedims(df_balanced.target)
113120
data = CounterfactualData(X,y)
@@ -148,13 +155,13 @@ df_balanced = getobs(undersample(df, df.target;shuffle=true))[1]
148155
schema(df_balanced)
149156
```
150157

151-
One-hot encoding:
158+
Feature transformation:
152159

153160
```{julia}
154-
hot = OneHotEncoder()
155-
mach = MLJBase.fit!(machine(hot, df_balanced))
156-
df_balanced = MLJBase.transform(mach, df_balanced)
157-
schema(df_balanced)
161+
transformer = Standardizer(count=true) |> ContinuousEncoder()
162+
mach = MLJBase.fit!(machine(transformer, df_balanced[:,Not(:target)]))
163+
X = MLJBase.transform(mach, df_balanced[:,Not(:target)])
164+
schema(X)
158165
```
159166

160167
Categorical indices:
@@ -170,7 +177,7 @@ features_categorical = [
170177
Preparing for use with `CounterfactualExplanations.jl`:
171178

172179
```{julia}
173-
X = Matrix(df_balanced[:,Not(:target)])
180+
X = Matrix(X)
174181
X = permutedims(X)
175182
y = permutedims(df_balanced.target)
176183
data = CounterfactualData(

docs/src/paper/experiments/_real_world.qmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function data_loader(data::CounterfactualData)
2929
data = Flux.DataLoader((X,y),batchsize=bs)
3030
return data
3131
end
32-
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.5)
32+
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.1)
3333
```
3434

3535

docs/src/paper/setup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ setup = quote
1717
using LaplaceRedux
1818
using Markdown
1919
using MLJBase
20-
using MLJModels: OneHotEncoder
20+
using MLJModels: ContinuousEncoder, OneHotEncoder, Standardizer
2121
using MLUtils
2222
using MLUtils: undersample
2323
using Plots

src/data/utils.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,43 @@
11
using CounterfactualExplanations.DataPreprocessing: CounterfactualData
2+
using DataFrames
23
using Flux
34
using StatsBase
45

6+
function Base.hcat(data::CounterfactualData, more_data::CounterfactualData)
7+
8+
data = deepcopy(data)
9+
more_data = deepcopy(more_data)
10+
11+
@assert all(data.features_categorical .== more_data.features_categorical) "Datasets have different categorical indices."
12+
@assert all(data.features_continuous .== more_data.features_continuous) "Datasets have different continous indices."
13+
14+
data.X = hcat(data.X, more_data.X)
15+
data.y = hcat(data.y, more_data.y)
16+
17+
return data
18+
end
19+
20+
function DataFrames.subset(data::CounterfactualData, idx::Vector{Int})
21+
dsub = deepcopy(data)
22+
dsub.X = dsub.X[:,idx]
23+
dsub.y = dsub.y[:,idx]
24+
return dsub
25+
end
26+
527
"""
628
train_test_split(data::CounterfactualData;test_size=0.2)
729
830
Splits data into train and test split.
931
"""
1032
function train_test_split(data::CounterfactualData;test_size=0.2)
11-
X,y = CounterfactualExplanations.DataPreprocessing.unpack(data)
33+
X, y = CounterfactualExplanations.DataPreprocessing.unpack(data)
1234
N = size(y,2)
1335
classes_ = sort(unique(y))
1436
n_per_class = round(N/length(classes_))
1537
test_idx = sort(reduce(vcat,[sample(findall(vec(y.==cls)), Int(floor(test_size * n_per_class)),replace=false) for cls in classes_]))
1638
train_idx = setdiff(1:N, test_idx)
17-
train_data = CounterfactualData(X[:,train_idx], y[:,train_idx])
18-
test_data = CounterfactualData(X[:,test_idx], y[:,test_idx])
39+
train_data = subset(data, train_idx)
40+
test_data = subset(data, test_idx)
1941
return train_data, test_data
2042
end
2143

src/experiments/functions.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ function Experiment(
4646
system_identifiers = Base.Iterators.product(keys(models), keys(generators))
4747

4848
# Full data:
49-
X_train, y_train = DataPreprocessing.unpack(train_data)
50-
X_test, y_test = DataPreprocessing.unpack(train_data)
51-
data = CounterfactualData(hcat(X_train, X_test), hcat(y_train, y_test))
49+
data = hcat(train_data, test_data)
5250

5351
# Initial scores:
5452
initial_model_scores = [(name, Models.model_evaluation(model, test_data)) for (name, model) in pairs(models)]
@@ -167,10 +165,8 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
167165
)
168166

169167
indices_ = rand(1:experiment.num_counterfactuals, length(results)) # randomly draw from generated counterfactuals
170-
171168
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
172169
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 1, indices_)))
173-
println(y′)
174170

175171
X[:, chosen_individuals] = X′
176172
y[:, chosen_individuals] = y′
@@ -182,7 +178,12 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
182178
end
183179

184180
# Update data, classifier and benchmark:
185-
recourse_system.data = CounterfactualData(X, y; generative_model=gen_mod)
181+
recourse_system.data = CounterfactualData(
182+
X, y;
183+
generative_model = gen_mod,
184+
features_categorical = counterfactual_data.features_categorical,
185+
features_continuous = counterfactual_data.features_continuous,
186+
)
186187
recourse_system.model = Models.train(M, counterfactual_data)
187188
recourse_system.benchmark = vcat(recourse_system.benchmark, CounterfactualExplanations.Benchmark.benchmark(results))
188189

0 commit comments

Comments
 (0)