Skip to content

Commit 913acfd

Browse files
committed
renamed EndoROAR to ClapROAR as in Classifier-preserving ROAR
1 parent 0e42386 commit 913acfd

6 files changed

Lines changed: 21 additions & 20 deletions

File tree

dev/notebooks/endo_roar_generator.qmd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ output_path = output_dir("generator")
1616
www_path = www_dir("generator")
1717
```
1818

19-
# `EndoROARGenerator`
19+
# `ClapROARGenerator`
2020

2121
```{julia}
2222
using MLJ
@@ -94,7 +94,7 @@ counterfactuals_strict = []
9494
generators = []
9595
for λ₂ ∈ Λ₂
9696
λ = [0.1, λ₂]
97-
generator = EndoROARGenerator(λ=λ)
97+
generator = ClapROARGenerator(λ=λ)
9898
generators = vcat(generators..., generator)
9999
counterfactuals_strict = vcat(
100100
counterfactuals_strict...,
@@ -120,7 +120,7 @@ plt = plot(plts..., size=(1200,300), layout=(1,3))
120120
```{julia}
121121
generators = Dict(
122122
:Generic => GenericGenerator(),
123-
:ROAR => EndoROARGenerator()
123+
:ROAR => ClapROARGenerator()
124124
)
125125
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; T=T) for (name, gen) in generators])
126126
```

dev/notebooks/mitigation_strategies.qmd

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ models = [
2424
]
2525
generators = Dict(
2626
:Generic=>GenericGenerator(decision_threshold=0.5),
27+
:REVISE=>REVISEGenerator(),
2728
:Generic_conservative=>GenericGenerator(decision_threshold=0.9),
2829
:Gravitational=>GravitationalGenerator(),
29-
:ROAR=>EndoROARGenerator()
30+
:ClapROAR=>ClapROARGenerator()
3031
)
3132
```
3233

@@ -231,14 +232,12 @@ end
231232

232233
### Latent Space Search
233234

234-
235-
236235
```{julia}
237236
generators = Dict(
238237
:REVISE=>GenericGenerator(decision_threshold=0.5),
239238
:REVISE_conservative=>GenericGenerator(decision_threshold=0.9),
240239
:Gravitational=>GravitationalGenerator(),
241-
:ROAR=>EndoROARGenerator()
240+
:ROAR=>ClapROARGenerator()
242241
)
243242
```
244243

src/AlgorithmicRecourseDynamics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using .Models
1010

1111
include("generators/Generators.jl")
1212
using .Generators
13-
export GravitationalGenerator, EndoROARGenerator
13+
export GravitationalGenerator, ClapROARGenerator
1414

1515
include("experiments/Experiments.jl")
1616
using .Experiments

src/experiments/functions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Parameters
1010
intersect_::Bool = true
1111
convergence::Symbol = :threshold_only
1212
generative_model_params::NamedTuple = (;)
13+
latent_space::Union{Nothing, Bool} = nothing
1314
end
1415

1516
mutable struct Experiment
@@ -162,7 +163,8 @@ function update!(experiment::Experiment, recourse_system::RecourseSystem, chosen
162163
factuals = select_factual(counterfactual_data,chosen_individuals)
163164
results = generate_counterfactual(
164165
factuals, target, counterfactual_data, M, generator;
165-
T=T, num_counterfactuals=experiment.num_counterfactuals, generative_model_params=args.generative_model_params
166+
T=T, num_counterfactuals=experiment.num_counterfactuals, generative_model_params=args.generative_model_params,
167+
latent_space=args.latent_space
166168
);
167169

168170
indices_ = rand(1:experiment.num_counterfactuals,length(results)) # randomly draw from generated counterfactuals

src/generators/EndoROARGenerator.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra, CounterfactualExplanations
22

3-
mutable struct EndoROARGenerator <: AbstractGradientBasedGenerator
3+
mutable struct ClapROARGenerator <: AbstractGradientBasedGenerator
44
loss::Union{Nothing,Symbol} # loss function
55
complexity::Function # complexity function
66
λ::Union{AbstractFloat,AbstractVector} # strength of penalty
@@ -11,13 +11,13 @@ end
1111

1212
# API streamlining:
1313
using Parameters, Flux
14-
@with_kw struct EndoROARGeneratorParams
14+
@with_kw struct ClapROARGeneratorParams
1515
opt::Any=Flux.Optimise.Descent()
1616
τ::AbstractFloat=1e-5
1717
end
1818

1919
"""
20-
EndoROARGenerator(
20+
ClapROARGenerator(
2121
;
2222
loss::Symbol=:logitbinarycrossentropy,
2323
complexity::Function=norm,
@@ -30,23 +30,23 @@ An outer constructor method that instantiates a generic generator.
3030
3131
# Examples
3232
```julia-repl
33-
generator = EndoROARGenerator()
33+
generator = ClapROARGenerator()
3434
```
3535
"""
36-
function EndoROARGenerator(
36+
function ClapROARGenerator(
3737
;
3838
loss::Union{Nothing,Symbol}=nothing,
3939
complexity::Function=norm,
4040
λ::Union{AbstractFloat,AbstractVector}=[0.1,5.0],
4141
decision_threshold=nothing,
4242
kwargs...
4343
)
44-
params = EndoROARGeneratorParams(;kwargs...)
45-
EndoROARGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
44+
params = ClapROARGeneratorParams(;kwargs...)
45+
ClapROARGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
4646
end
4747

4848
using Flux
49-
function gradient_penalty(generator::EndoROARGenerator, counterfactual_state::CounterfactualState.State)
49+
function gradient_penalty(generator::ClapROARGenerator, counterfactual_state::CounterfactualState.State)
5050

5151
x_ = counterfactual_state.f(counterfactual_state.s′)
5252
M = counterfactual_state.M
@@ -68,7 +68,7 @@ import CounterfactualExplanations.Generators: h
6868
6969
The default method to apply the generator complexity penalty to the current counterfactual state for any generator.
7070
"""
71-
function h(generator::EndoROARGenerator, counterfactual_state::CounterfactualState.State)
71+
function h(generator::ClapROARGenerator, counterfactual_state::CounterfactualState.State)
7272

7373
# Distance from factual:
7474
dist_ = generator.complexity(counterfactual_state.x .- counterfactual_state.f(counterfactual_state.s′))

src/generators/Generators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module Generators
33
include("GravitationalGenerator.jl")
44
export GravitationalGenerator
55

6-
include("EndoROARGenerator.jl")
7-
export EndoROARGenerator
6+
include("ClapROARGenerator.jl")
7+
export ClapROARGenerator
88

99
end

0 commit comments

Comments
 (0)