Skip to content

Commit 38a04c4

Browse files
committed
uh
1 parent 5cd86d0 commit 38a04c4

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

test/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
[deps]
22
CompatHelperLocal = "5224ae11-6099-4aaa-941d-3aab004bd678"
3+
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
6+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
7+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
38
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,41 @@ import CompatHelperLocal as CHL
22
CHL.@check()
33

44
using AlgorithmicRecourseDynamics
5+
using AlgorithmicRecourseDynamics.Data
6+
using AlgorithmicRecourseDynamics.Experiments
7+
using AlgorithmicRecourseDynamics.Models
8+
using AlgorithmicRecourseDynamics: run!
9+
using CounterfactualExplanations
10+
using Flux
11+
using MLJBase
12+
using Plots
13+
using Random
514
using Test
615

716
@testset "AlgorithmicRecourseDynamics.jl" begin
8-
# Write your tests here.
17+
18+
N = 1000
19+
xmax = 2
20+
X, ys = make_blobs(
21+
N, 2;
22+
centers=2, as_table=false, center_box=(-xmax => xmax), cluster_std=0.1
23+
)
24+
ys .= ys .== 2
25+
X = X'
26+
counterfactual_data = CounterfactualData(X, ys')
27+
28+
n_epochs = 100
29+
model = Chain(Dense(2, 1))
30+
mod = FluxModel(model)
31+
generator = GenericGenerator()
32+
33+
data_train, data_test = Data.train_test_split(counterfactual_data)
34+
Models.train(mod, data_train; n_epochs=n_epochs)
35+
36+
models = Dict(:mymodel => mod)
37+
generators = Dict(:wachter => generator)
38+
experiment = set_up_experiment(data_train, data_test, models, generators)
39+
40+
run!(experiment)
41+
942
end

0 commit comments

Comments
 (0)