Skip to content

Commit 1bced06

Browse files
committed
come on then
1 parent 4194be4 commit 1bced06

3 files changed

Lines changed: 9 additions & 8 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2525

2626
[compat]
2727
CSV = "0.10"
28-
CounterfactualExplanations = "0.1 - 0"
28+
CounterfactualExplanations = "0.1 - 0.1.6"
2929
DataFrames = "1"
3030
Distances = "0.10"
3131
Flux = "0.13 - 0.14"

src/experiments/functions.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,15 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
172172
indices_ = rand(1:experiment.num_counterfactuals, length(results)) # randomly draw from generated counterfactuals
173173
X′ = reduce(hcat, @.(selectdim(counterfactual(results), 3, indices_)))
174174
y′ = reduce(hcat, @.(selectdim(counterfactual_label(results), 3, indices_)))
175+
y′ = [y[1] for y in y′]
175176

176-
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177-
valid_ces = vec(.!(isnan.(y′)))
178-
chosen_individuals = chosen_individuals[valid_ces]
177+
# If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
178+
valid_ces = vec(.!(isnan.(y′)))
179+
chosen_individuals = chosen_individuals[valid_ces]
179180

180-
# Update data:
181-
X[:, chosen_individuals] = X′[:, valid_ces]
182-
y[:, chosen_individuals] = y′[:, valid_ces]
181+
# Update data:
182+
X[:, chosen_individuals] = X′[:, valid_ces]
183+
y[:, chosen_individuals] = y′[:, valid_ces]
183184

184185
# Generative model:
185186
gen_mod = deepcopy(counterfactual_data.generative_model)

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99

1010
[compat]
11-
CounterfactualExplanations = "0.1 - 0"
11+
CounterfactualExplanations = "0.1 - 0.1.6"
1212
Flux = "0.13 - 0.14"
1313

0 commit comments

Comments
 (0)