@@ -21,6 +21,9 @@
 
 RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)
 
+RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv, action) =
+    prob(p.behavior_policy, env, action)
+
 function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
     OutcomeSamplingMCCFRPolicy(
         Dict{state_type,InfoStateNode}(),
@@ -36,7 +39,7 @@
 
 "Run one iteration"
 function RLBase.update!(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv)
-    for x in get_players(env)
+    for x in players(env)
         if x != chance_player(env)
             outcome_sampling(copy(env), x, p.nodes, p.ϵ, 1.0, 1.0, 1.0, p.rng)
         end
@@ -47,7 +50,10 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
     for (k, v) in p.nodes
         s = sum(v.cumulative_strategy)
         if s != 0
-            update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
+            m = v.mask
+            strategy = zeros(length(m))
+            strategy[m] .= v.cumulative_strategy ./ s
+            update!(p.behavior_policy, k => strategy)
         else
             # The TabularLearner will return a uniform distribution by default,
             # so we do nothing here.
@@ -56,22 +62,24 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
 end
 
 function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
-    current_player = current_player(env)
+    player = current_player(env)
 
     if is_terminated(env)
         reward(env, i) / s, 1.0
-    elseif current_player == chance_player(env)
-        env(rand(rng, action_space(env)))
+    elseif player == chance_player(env)
+        x = sample(rng, action_space(env), Weights(prob(env), 1.0))
+        env(x)
         outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
     else
         I = state(env)
         legal_actions = legal_action_space(env)
+        M = legal_action_space_mask(env)
         n = length(legal_actions)
-        node = get!(nodes, I, InfoStateNode(n))
+        node = get!(nodes, I, InfoStateNode(M))
         regret_matching!(node; is_reset_neg_regrets = false)
         σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy
 
-        if i == current_player
+        if i == player
             aᵢ = rand(rng) >= ϵ ? sample(rng, Weights(σ, 1.0)) : rand(rng, 1:n)
             pᵢ = σ[aᵢ] * (1 - ϵ) + ϵ / n
             πᵢ′, π₋ᵢ′, s′ = πᵢ * pᵢ, π₋ᵢ, s * pᵢ
@@ -84,7 +92,7 @@ function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
         env(legal_action_space(env)[aᵢ])
         u, πₜₐᵢₗ = outcome_sampling(env, i, nodes, ϵ, πᵢ′, π₋ᵢ′, s′, rng)
 
-        if i == current_player
+        if i == player
             w = u * π₋ᵢ
             rI .+= w * πₜₐᵢₗ .* ((1:n .== aᵢ) .- σ[aᵢ])
         else
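
The masked normalization introduced in update! above can be exercised on its own. A minimal sketch with made-up numbers (the mask and tallies are illustrative assumptions, not values from this PR):

# Suppose 2 of 3 actions are legal at some information state.
mask = [true, false, true]        # stands in for legal_action_space_mask(env)
cumulative_strategy = [3.0, 1.0]  # accumulated weights over legal actions only

s = sum(cumulative_strategy)
strategy = zeros(length(mask))    # full-length vector; illegal slots stay at 0.0
strategy[mask] .= cumulative_strategy ./ s

@assert strategy == [0.75, 0.0, 0.25]

This is why InfoStateNode now takes the mask M instead of the action count n: the behavior policy is keyed on the full action space, so a strategy over legal actions has to be scattered back into a full-length vector before the update.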
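A hypothetical end-to-end usage sketch, assuming KuhnPokerEnv from ReinforcementLearningEnvironments and that update! and prob are in scope via the usual RLBase exports (the iteration count is arbitrary):

using ReinforcementLearningZoo, ReinforcementLearningEnvironments
using Random

env = KuhnPokerEnv()
p = OutcomeSamplingMCCFRPolicy(rng = MersenneTwister(123))

# Each call performs one outcome-sampling traversal per non-chance player.
for _ in 1:100_000
    update!(p, env)
end

# Normalize the accumulated cumulative strategies into the behavior policy.
update!(p)

prob(p, env)  # action probabilities at the current information state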