Skip to content

Commit a4a4d73

Browse files
committed
uh
1 parent 4008ed2 commit a4a4d73

1 file changed

Lines changed: 0 additions & 42 deletions

File tree

src/data/functions.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,6 @@
1-
using LazyArtifacts
2-
using CounterfactualExplanations.DataPreprocessing: CounterfactualData
3-
using CSV
4-
using DataFrames
51
using Serialization
62
using StatsBase
73

8-
"""
    load_synthetic(max_obs::Union{Nothing,Int}=nothing)

Load every `.csv` file found in the `data/synthetic` folder of the `data` artifact and
wrap each one in a `CounterfactualData` object.

For each file, the columns `x1` and `x2` are stacked into a `2 × n` feature matrix and the
column `target` becomes a `1 × n` matrix of outputs.

# Arguments
- `max_obs`: if not `nothing`, each dataset is passed through `undersample` with
  `round(Int, max_obs / n_classes)` as its second argument, where `n_classes` is the
  number of unique values in the target.

# Returns
A `Dict` that maps the file name without its `.csv` extension (as a `Symbol`) to the
corresponding dataset.
"""
function load_synthetic(max_obs::Union{Nothing,Int}=nothing)
    ext = ".csv"
    data_dir = joinpath(artifact"data", "data/synthetic")

    # Match on the file extension only; a plain substring test would also pick up names
    # such as "moons.csv.bak".
    files = filter(file -> endswith(file, ext), readdir(data_dir))

    # Locals inside the closure use their own names so that no variable of the enclosing
    # function is reassigned from within the closure (which would force Julia to box it).
    named_data = map(files) do file
        df = CSV.read(joinpath(data_dir, file), DataFrame)
        X = convert(Matrix, hcat(df.x1, df.x2)')
        y = convert(Matrix, df.target')
        counterfactual_data = CounterfactualData(X, y)
        if !isnothing(max_obs)
            n_classes = length(unique(y))
            counterfactual_data = undersample(
                counterfactual_data, round(Int, max_obs / n_classes)
            )
        end
        # Strip only the trailing extension to build the key.
        return Symbol(chop(file; tail=length(ext))) => counterfactual_data
    end

    # Build the dictionary from the collection of pairs directly instead of splatting it.
    return Dict(named_data)
end
26-
27-
"""
    load_real_world(max_obs::Union{Nothing,Int}=nothing; data_dir::Union{Nothing, String}=nothing)

Deserialize every `.jls` file found in `data_dir` and collect the results in a dictionary.

# Arguments
- `max_obs`: if not `nothing`, each dataset is passed through `undersample` with
  `round(Int, max_obs / n_classes)` as its second argument, where `n_classes` is the
  number of unique values in the dataset's `y` field.

# Keywords
- `data_dir`: directory to read from. When `nothing`, the `data/real_world` folder of the
  `data` artifact is used.

# Returns
A `Dict` that maps the file name without its `.jls` extension (as a `Symbol`) to the
deserialized object.

!!! note
    `Serialization.deserialize` reconstructs arbitrary Julia objects and its format is
    tied to the Julia version that wrote the file, so `data_dir` should only contain
    trusted files produced by a compatible Julia version.
"""
function load_real_world(max_obs::Union{Nothing,Int}=nothing; data_dir::Union{Nothing, String}=nothing)
    ext = ".jls"

    # Resolve the directory into a fresh, assigned-once local: reassigning `data_dir`
    # itself would make it a boxed variable once the closure below captures it. The
    # artifact is only looked up when no directory was supplied.
    dir = if isnothing(data_dir)
        joinpath(artifact"data", "data/real_world")
    else
        data_dir
    end

    # Match on the file extension only; a plain substring test would also pick up names
    # such as "adult.jls.bak".
    files = filter(file -> endswith(file, ext), readdir(dir))

    named_data = map(files) do file
        counterfactual_data = Serialization.deserialize(joinpath(dir, file))
        if !isnothing(max_obs)
            n_classes = length(unique(counterfactual_data.y))
            counterfactual_data = undersample(
                counterfactual_data, round(Int, max_obs / n_classes)
            )
        end
        # Strip only the trailing extension to build the key.
        return Symbol(chop(file; tail=length(ext))) => counterfactual_data
    end

    # Build the dictionary from the collection of pairs directly instead of splatting it.
    return Dict(named_data)
end
44-
45-
464
function scale(X, dim)
475
dt = fit(ZScoreTransform, X, dim=dim)
486
X_scaled = StatsBase.transform(dt, X)

0 commit comments

Comments
 (0)