11using LinearAlgebra, CounterfactualExplanations
22
3- mutable struct EndoROARGenerator <: AbstractGradientBasedGenerator
3+ mutable struct ClapROARGenerator <: AbstractGradientBasedGenerator
44 loss:: Union{Nothing,Symbol} # loss function
55 complexity:: Function # complexity function
66 λ:: Union{AbstractFloat,AbstractVector} # strength of penalty
1111
1212# API streamlining:
1313using Parameters, Flux
14- @with_kw struct EndoROARGeneratorParams
14+ @with_kw struct ClapROARGeneratorParams
1515 opt:: Any = Flux. Optimise. Descent ()
1616 τ:: AbstractFloat = 1e-5
1717end
1818
1919"""
20- EndoROARGenerator (
20+ ClapROARGenerator (
2121 ;
2222 loss::Symbol=:logitbinarycrossentropy,
2323 complexity::Function=norm,
@@ -30,23 +30,23 @@ An outer constructor method that instantiates a generic generator.
3030
3131# Examples
3232```julia-repl
33- generator = EndoROARGenerator ()
33+ generator = ClapROARGenerator ()
3434```
3535"""
36- function EndoROARGenerator (
36+ function ClapROARGenerator (
3737 ;
3838 loss:: Union{Nothing,Symbol} = nothing ,
3939 complexity:: Function = norm,
4040 λ:: Union{AbstractFloat,AbstractVector} = [0.1 ,5.0 ],
4141 decision_threshold= nothing ,
4242 kwargs...
4343)
44- params = EndoROARGeneratorParams (;kwargs... )
45- EndoROARGenerator (loss, complexity, λ, decision_threshold, params. opt, params. τ)
44+ params = ClapROARGeneratorParams (;kwargs... )
45+ ClapROARGenerator (loss, complexity, λ, decision_threshold, params. opt, params. τ)
4646end
4747
4848using Flux
49- function gradient_penalty (generator:: EndoROARGenerator , counterfactual_state:: CounterfactualState.State )
49+ function gradient_penalty (generator:: ClapROARGenerator , counterfactual_state:: CounterfactualState.State )
5050
5151 x_ = counterfactual_state. f (counterfactual_state. s′)
5252 M = counterfactual_state. M
@@ -68,7 +68,7 @@ import CounterfactualExplanations.Generators: h
6868
6969The default method to apply the generator complexity penalty to the current counterfactual state for any generator.
7070"""
71- function h (generator:: EndoROARGenerator , counterfactual_state:: CounterfactualState.State )
71+ function h (generator:: ClapROARGenerator , counterfactual_state:: CounterfactualState.State )
7272
7373 # Distance from factual:
7474 dist_ = generator. complexity (counterfactual_state. x .- counterfactual_state. f (counterfactual_state. s′))
0 commit comments