|
1 | | -using LazyArtifacts |
2 | | -using CounterfactualExplanations.DataPreprocessing: CounterfactualData |
3 | | -using CSV |
4 | | -using DataFrames |
5 | 1 | using Serialization |
6 | 2 | using StatsBase |
7 | 3 |
|
"""
    load_synthetic(max_obs::Union{Nothing,Int}=nothing)

Load every `.csv` file found in the `data/synthetic` directory of the `data` artifact
and return a `Dict` mapping each file's base name (as a `Symbol`) to a
`CounterfactualData` object built from that file.

Each file is expected to provide the columns `x1`, `x2` (stacked as the two rows of the
feature matrix) and `target` (used as a single-row label matrix).

# Arguments
- `max_obs`: when not `nothing`, each dataset is passed to `undersample` with
  `round(Int, max_obs / n_classes)` as its second argument, where `n_classes` is the
  number of distinct values in `target`. (`undersample` is defined outside this block;
  presumably the second argument is a per-class observation count — confirm there.)

# Returns
A `Dict` keyed by dataset name.
"""
function load_synthetic(max_obs::Union{Nothing,Int}=nothing)
    data_dir = joinpath(artifact"data", "data/synthetic")
    # Select on the actual extension: a substring match would also pick up files such
    # as "moons.csv.bak" and then try to parse them as CSV.
    files = filter(endswith(".csv"), readdir(data_dir))
    data = map(files) do file
        df = CSV.read(joinpath(data_dir, file), DataFrame)
        # Features are stored column-wise in the file; lay observations out as columns.
        X = permutedims(hcat(df.x1, df.x2))
        y = permutedims(df.target)
        counterfactual_data = CounterfactualData(X, y)
        if !isnothing(max_obs)
            n_classes = length(unique(y))
            counterfactual_data = undersample(
                counterfactual_data, round(Int, max_obs / n_classes)
            )
        end
        # Strip only the trailing extension so a ".csv" inside the stem is preserved.
        return Symbol(first(splitext(file))) => counterfactual_data
    end
    # Build the dictionary from the collection directly instead of splatting it.
    return Dict(data)
end
26 | | - |
"""
    load_real_world(max_obs::Union{Nothing,Int}=nothing; data_dir=nothing)

Deserialize every `.jls` file in `data_dir` and return a `Dict` mapping each file's
base name (as a `Symbol`) to the deserialized object.

# Arguments
- `max_obs`: when not `nothing`, each deserialized object is passed to `undersample`
  with `round(Int, max_obs / n_classes)` as its second argument, where `n_classes` is
  the number of distinct values in the object's `y` field. (`undersample` is defined
  outside this block; presumably the second argument is a per-class observation count —
  confirm there.)

# Keywords
- `data_dir`: directory to scan. When `nothing`, the `data/real_world` directory of the
  `data` artifact is used.

# Returns
A `Dict` keyed by dataset name.
"""
function load_real_world(
    max_obs::Union{Nothing,Int}=nothing;
    data_dir::Union{Nothing,AbstractString}=nothing,
)
    if isnothing(data_dir)
        data_dir = joinpath(artifact"data", "data/real_world")
    end
    # Select on the actual extension: a substring match would also pick up files such
    # as "adult.jls.bak" and then try to deserialize them.
    files = filter(endswith(".jls"), readdir(data_dir))
    data = map(files) do file
        # NOTE(review): `deserialize` can execute arbitrary code embedded in the file;
        # only point `data_dir` at trusted locations.
        counterfactual_data = Serialization.deserialize(joinpath(data_dir, file))
        if !isnothing(max_obs)
            n_classes = length(unique(counterfactual_data.y))
            counterfactual_data = undersample(
                counterfactual_data, round(Int, max_obs / n_classes)
            )
        end
        # Strip only the trailing extension so a ".jls" inside the stem is preserved.
        return Symbol(first(splitext(file))) => counterfactual_data
    end
    # Build the dictionary from the collection directly instead of splatting it.
    return Dict(data)
end
44 | | - |
45 | | - |
46 | 4 | function scale(X, dim) |
47 | 5 | dt = fit(ZScoreTransform, X, dim=dim) |
48 | 6 | X_scaled = StatsBase.transform(dt, X) |
|
0 commit comments