@@ -6,10 +6,10 @@ struct InfoStateNode
     cumulative_strategy::Vector{Float64}
 end
 
-InfoStateNode(n) = InfoStateNode(fill(1/n, n), zeros(n), zeros(n))
+InfoStateNode(n) = InfoStateNode(fill(1 / n, n), zeros(n), zeros(n))
 
 function init_info_state_nodes(env::AbstractEnv)
-    nodes = Dict{String, InfoStateNode}()
+    nodes = Dict{String,InfoStateNode}()
     walk(env) do x
         if !get_terminal(x) && get_current_player(x) != get_chance_player(x)
             get!(nodes, get_state(x), InfoStateNode(length(get_legal_actions(x))))
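A note on `init_info_state_nodes`: it walks the whole game tree once and creates one `InfoStateNode` per non-terminal, non-chance information state, keyed by the `get_state` string. A minimal usage sketch (the Kuhn poker environment is an assumption, not part of this diff; any sequential `AbstractEnv` with an explicit chance player should work):

```julia
# Hypothetical usage; KuhnPokerEnv is assumed to come from
# ReinforcementLearningEnvironments.jl.
env = KuhnPokerEnv()
nodes = init_info_state_nodes(env)
length(nodes)  # one uniform-initialized node per reachable information state
```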
@@ -24,8 +24,8 @@
 See more details: [An Introduction to Counterfactual Regret Minimization](http://modelai.gettysburg.edu/2013/cfr/cfr.pdf)
 """
 struct TabularCFRPolicy{S,T,R<:AbstractRNG} <: AbstractPolicy
-    nodes::Dict{S, InfoStateNode}
-    behavior_policy::QBasedPolicy{TabularLearner{S,T}, WeightedExplorer{true,R}}
+    nodes::Dict{S,InfoStateNode}
+    behavior_policy::QBasedPolicy{TabularLearner{S,T},WeightedExplorer{true,R}}
 end
 
 (p::TabularCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)
@@ -35,7 +35,13 @@ RLBase.get_prob(p::TabularCFRPolicy, env::AbstractEnv) = get_prob(p.behavior_pol
3535"""
3636 TabularCFRPolicy(;n_iter::Int, env::AbstractEnv)
3737"""
38- function TabularCFRPolicy (;n_iter:: Int , env:: AbstractEnv , rng= Random. GLOBAL_RNG, is_reset_neg_regrets= false , is_linear_averaging= false )
38+ function TabularCFRPolicy (;
39+ n_iter:: Int ,
40+ env:: AbstractEnv ,
41+ rng = Random. GLOBAL_RNG,
42+ is_reset_neg_regrets = false ,
43+ is_linear_averaging = false ,
44+ )
3945 @assert NumAgentStyle (env) isa MultiAgent
4046 @assert DynamicStyle (env) === SEQUENTIAL
4147 @assert RewardStyle (env) === TERMINAL_REWARD
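Note that this constructor trains eagerly: building a `TabularCFRPolicy` runs the full `n_iter` CFR loop before returning. A hedged usage sketch (environment and iteration count are illustrative only):

```julia
env = KuhnPokerEnv()  # assumed environment, as above
policy = TabularCFRPolicy(; n_iter = 10_000, env = env)
policy(env)  # sampling an action delegates to the trained behavior_policy
```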
@@ -47,7 +53,8 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
     for i in 1:n_iter
         for p in get_players(env)
             if p != get_chance_player(env)
-                init_reach_prob = Dict(x=>1.0 for x in get_players(env) if x != get_chance_player(env))
+                init_reach_prob =
+                    Dict(x => 1.0 for x in get_players(env) if x != get_chance_player(env))
                 cfr!(nodes, env, p, init_reach_prob, 1.0, is_linear_averaging ? i : 1)
                 update_strategy!(nodes)
 
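The last argument to `cfr!` is the averaging weight: with `is_linear_averaging = true`, iteration `i` contributes to `cumulative_strategy` with weight `i` instead of 1, discounting early, weaker iterates as in linear CFR. A tiny sketch of the weighting (the `weight` helper is hypothetical):

```julia
# Hypothetical helper mirroring `is_linear_averaging ? i : 1` above.
weight(i; is_linear_averaging) = is_linear_averaging ? i : 1
weight(100; is_linear_averaging = true)   # => 100
weight(100; is_linear_averaging = false)  # => 1
```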
@@ -60,9 +67,12 @@ function TabularCFRPolicy(;n_iter::Int, env::AbstractEnv, rng=Random.GLOBAL_RNG,
         end
     end
 
-    behavior_policy = QBasedPolicy(;learner=TabularLearner{String}(), explorer=WeightedExplorer(;is_normalized=true, rng=rng))
+    behavior_policy = QBasedPolicy(;
+        learner = TabularLearner{String}(),
+        explorer = WeightedExplorer(; is_normalized = true, rng = rng),
+    )
 
-    for (k,v) in nodes
+    for (k, v) in nodes
         s = sum(v.cumulative_strategy)
         if s != 0
             update!(behavior_policy, k => v.cumulative_strategy ./ s)
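This loop extracts the average strategy: each node's `cumulative_strategy` is normalized into a probability vector before being written into the behavior policy, and never-visited nodes (all-zero sums) are skipped. A numeric sketch with made-up values:

```julia
cumulative_strategy = [4.0, 1.0, 0.0]
s = sum(cumulative_strategy)               # 5.0
s == 0 || (avg = cumulative_strategy ./ s) # avg == [0.8, 0.2, 0.0]
```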
@@ -77,23 +87,39 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
         get_reward(env, player)
     else
         if get_current_player(env) == get_chance_player(env)
-            v = 0.
+            v = 0.0
             for a::ActionProbPair in get_legal_actions(env)
-                v += a.prob * cfr!(nodes, child(env, a), player, reach_probs, chance_player_reach_prob * a.prob, ratio)
+                v +=
+                    a.prob * cfr!(
+                        nodes,
+                        child(env, a),
+                        player,
+                        reach_probs,
+                        chance_player_reach_prob * a.prob,
+                        ratio,
+                    )
             end
             v
         else
-            v = 0.
+            v = 0.0
             node = nodes[get_state(env)]
             legal_actions = get_legal_actions(env)
-            U = player == get_current_player(env) ? Vector{Float64}(undef, length(legal_actions)) : nothing
+            U = player == get_current_player(env) ?
+                Vector{Float64}(undef, length(legal_actions)) : nothing
 
             for (i, action) in enumerate(legal_actions)
                 prob = node.strategy[i]
                 new_reach_probs = copy(reach_probs)
                 new_reach_probs[get_current_player(env)] *= prob
 
-                u = cfr!(nodes, child(env, action), player, new_reach_probs, chance_player_reach_prob, ratio)
+                u = cfr!(
+                    nodes,
+                    child(env, action),
+                    player,
+                    new_reach_probs,
+                    chance_player_reach_prob,
+                    ratio,
+                )
                 isnothing(U) || (U[i] = u)
                 v += prob * u
             end
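Both branches compute the node value as an expectation over children: at chance nodes `v = Σ_a p(a)·u(a)` under the chance probabilities, at player nodes `v = Σ_a σ(a)·u(a)` under the current strategy, and `U` keeps the per-action utilities only for the player being updated. A worked numeric example with made-up utilities:

```julia
σ = [0.5, 0.3, 0.2]   # node.strategy
u = [1.0, -1.0, 0.0]  # hypothetical per-action counterfactual utilities
v = sum(σ .* u)       # 0.5 - 0.3 + 0.0 = 0.2
```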
@@ -102,8 +128,13 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
                 reach_prob = reach_probs[player]
                 counterfactual_reach_prob = reduce(
                     *,
-                    (reach_probs[p] for p in get_players(env) if p != player && p != get_chance_player(env));
-                    init=chance_player_reach_prob)
+                    (
+                        reach_probs[p]
+                        for
+                        p in get_players(env) if p != player && p != get_chance_player(env)
+                    );
+                    init = chance_player_reach_prob,
+                )
                 node.cumulative_regret .+= counterfactual_reach_prob .* (U .- v)
                 node.cumulative_strategy .+= ratio .* reach_prob .* node.strategy
             end
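These two broadcasts are the core CFR accumulations: regrets grow by the counterfactual reach probability (opponents' and chance's reach only) times each action's advantage `U .- v`, while the average-strategy numerator grows by the player's own reach probability times the current strategy, scaled by `ratio`. Continuing the made-up numbers from the sketch above:

```julia
U, v = [1.0, -1.0, 0.0], 0.2
counterfactual_reach_prob = 0.5  # hypothetical opponent/chance reach
cumulative_regret = zeros(3)
cumulative_regret .+= counterfactual_reach_prob .* (U .- v)  # [0.4, -0.6, -0.1]
```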
@@ -113,16 +144,16 @@ function cfr!(nodes, env, player, reach_probs, chance_player_reach_prob, ratio)
 end
 
 function regret_matching!(strategy, cumulative_regret)
-    s = mapreduce(x->max(0,x), +, cumulative_regret)
+    s = mapreduce(x -> max(0, x), +, cumulative_regret)
     if s > 0
-        strategy .= max.(0., cumulative_regret) ./ s
+        strategy .= max.(0.0, cumulative_regret) ./ s
     else
-        fill!(strategy, 1/length(strategy))
+        fill!(strategy, 1 / length(strategy))
     end
 end
 
 function update_strategy!(nodes)
     for node in values(nodes)
         regret_matching!(node.strategy, node.cumulative_regret)
     end
-end
+end
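`regret_matching!` is standard regret matching: the next strategy is proportional to the positive part of the cumulative regrets, falling back to uniform when no regret is positive. A quick check with made-up regrets:

```julia
strategy = fill(1 / 3, 3)
regret_matching!(strategy, [2.0, -1.0, 1.0])
strategy  # positive parts [2.0, 0.0, 1.0] sum to 3.0 => [2/3, 0.0, 1/3]
```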