@@ -172,14 +172,15 @@ function update_experiment!(experiment::Experiment, recourse_system::RecourseSys
172172 indices_ = rand (1 : experiment. num_counterfactuals, length (results)) # randomly draw from generated counterfactuals
173173 X′ = reduce (hcat, @. (selectdim (counterfactual (results), 3 , indices_)))
174174 y′ = reduce (hcat, @. (selectdim (counterfactual_label (results), 3 , indices_)))
175+ y′ = [y[1 ] for y in y′]
175176
176- # If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
177- valid_ces = vec (.! (isnan .(y′)))
178- chosen_individuals = chosen_individuals[valid_ces]
177+ # If for any counterfactuals the returned label is NaN, this is considered as invalid and the current label is not updated:
178+ valid_ces = vec (.! (isnan .(y′)))
179+ chosen_individuals = chosen_individuals[valid_ces]
179180
180- # Update data:
181- X[:, chosen_individuals] = X′[:, valid_ces]
182- y[:, chosen_individuals] = y′[:, valid_ces]
181+ # Update data:
182+ X[:, chosen_individuals] = X′[:, valid_ces]
183+ y[:, chosen_individuals] = y′[:, valid_ces]
183184
184185 # Generative model:
185186 gen_mod = deepcopy (counterfactual_data. generative_model)
0 commit comments