This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 52424bf

Enable cfr tests (#141)
* decouple PreActStage
* fix cfr related tests
* rename
* enable all tests
* resolve test error
* bugfix with A2C related code
* update dependency
1 parent d7077e8 commit 52424bf

20 files changed

Lines changed: 395 additions & 385 deletions

Project.toml

Lines changed: 9 additions & 1 deletion
@@ -38,7 +38,7 @@ Flux = "0.11"
 IntervalSets = "0.5"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.9"
-ReinforcementLearningCore = "0.6.1"
+ReinforcementLearningCore = "0.6.3"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StableRNGs = "1.0"
@@ -47,3 +47,11 @@ StructArrays = "0.4"
 TensorBoardLogger = "0.1"
 Zygote = "0.5"
 julia = "1.4"
+
+[extras]
+OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
+ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]

src/ReinforcementLearningZoo.jl

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ function __init__()
     @require ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" begin
         include("experiments/rl_envs/rl_envs.jl")
         @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari/atari.jl")
-        # @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
+        @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
     end
 end
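Re-enabling the commented-out @require line makes the OpenSpiel experiments load lazily once the user imports both ReinforcementLearningEnvironments and OpenSpiel. This is the usual Requires.jl pattern; a minimal sketch of the behavior, written as it would appear inside some module's __init__ (hypothetical layout, not the package's actual file structure):

    using Requires

    function __init__()
        @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" begin
            # Runs only after the user does `import OpenSpiel` in the same
            # session; until then the file is never included at all.
            include("experiments/open_spiel/open_spiel.jl")
        end
    end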

src/algorithms/cfr/abstract_cfr_policy.jl

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ function Base.run(
     @assert DynamicStyle(env) === SEQUENTIAL
     @assert RewardStyle(env) === TERMINAL_REWARD
     @assert ChanceStyle(env) === EXPLICIT_STOCHASTIC
-    @assert DefaultStateStyle(env) isa Information
+    @assert DefaultStateStyle(env) isa InformationSet

     RLBase.reset!(env)
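The fix swaps Information for InformationSet, the state-style trait RLBase actually exports, so the assertion can succeed. These four checks restrict tabular CFR to sequential, terminal-reward games with explicitly enumerable chance outcomes. A quick way to see whether an environment qualifies (sketch; assumes the KuhnPokerEnv shipped with ReinforcementLearningEnvironments):

    using ReinforcementLearningBase
    using ReinforcementLearningEnvironments

    env = KuhnPokerEnv()
    DynamicStyle(env)       # expected: SEQUENTIAL
    RewardStyle(env)        # expected: TERMINAL_REWARD
    ChanceStyle(env)        # expected: EXPLICIT_STOCHASTIC
    DefaultStateStyle(env)  # expected: an InformationSet, so the new assert passes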

src/algorithms/cfr/best_response_policy.jl

Lines changed: 12 additions & 10 deletions
@@ -19,13 +19,9 @@ function BestResponsePolicy(
     policy,
     env,
     best_responder;
-    state_type = String,
-    action_type = Int,
 )
-    # S = typeof(state(env)) # TODO: currently it will break the OpenSpielEnv. Can not get information set for chance player
-    # A = eltype(action_space(env)) # TODO: for chance players it will return ActionProbPair
-    S = state_type
-    A = action_type
+    S = eltype(state_space(env))
+    A = eltype(action_space(env))
     E = typeof(env)

     p = BestResponsePolicy(
@@ -61,8 +57,10 @@ function init_cfr_reach_prob!(p, env, reach_prob = 1.0)
             init_cfr_reach_prob!(p, child(env, a), reach_prob)
         end
     elseif current_player(env) == chance_player(env)
-        for a::ActionProbPair in action_space(env)
-            init_cfr_reach_prob!(p, child(env, a), reach_prob * a.prob)
+        for (a, pₐ) in zip(action_space(env), prob(env))
+            if pₐ > 0
+                init_cfr_reach_prob!(p, child(env, a), reach_prob * pₐ)
+            end
         end
     else # opponents
         for a in legal_action_space(env)
@@ -81,8 +79,12 @@ function best_response_value(p, env)
         best_response_value(p, child(env, a))
     elseif current_player(env) == chance_player(env)
         v = 0.0
-        for a::ActionProbPair in action_space(env)
-            v += a.prob * best_response_value(p, child(env, a))
+        A, P = action_space(env), prob(env)
+        @assert length(A) == length(P)
+        for (a, pₐ) in zip(A, P)
+            if pₐ > 0
+                v += pₐ * best_response_value(p, child(env, a))
+            end
         end
         v
     else
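This is the recurring pattern of the commit: instead of iterating ActionProbPair values, chance nodes now zip the full action space against the environment's chance distribution (prob(env)) and skip zero-probability branches, so child(env, a) is never materialized for outcomes that cannot occur. A self-contained sketch of the same computation, with plain arrays standing in for the RLBase calls:

    # `actions`/`probs` play the roles of action_space(env) and prob(env);
    # `value_of_child` stands in for the recursive evaluation of child(env, a).
    function chance_value(value_of_child, actions, probs)
        @assert length(actions) == length(probs)
        v = 0.0
        for (a, pₐ) in zip(actions, probs)
            pₐ > 0 && (v += pₐ * value_of_child(a))  # skip impossible outcomes
        end
        v
    end

    chance_value(a -> a^2, [1, 2, 3], [0.5, 0.0, 0.5])  # 0.5*1 + 0.5*9 == 5.0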

src/algorithms/cfr/deep_cfr.jl

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ end

 "Run one iteration"
 function RLBase.update!(π::DeepCFR, env::AbstractEnv)
-    for p in get_players(env)
+    for p in players(env)
         if p != chance_player(env)
             for k in 1:π.K
                 external_sampling!(π, copy(env), p)

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 14 additions & 7 deletions
@@ -20,6 +20,9 @@ end

 RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

+RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv, action) =
+    prob(p.behavior_policy, env, action)
+
 function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG)
     ExternalSamplingMCCFRPolicy(
         Dict{state_type,InfoStateNode}(),
@@ -36,7 +39,10 @@ function RLBase.update!(p::ExternalSamplingMCCFRPolicy)
     for (k, v) in p.nodes
         s = sum(v.cumulative_strategy)
         if s != 0
-            update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
+            m = v.mask
+            strategy = zeros(length(m))
+            strategy[m] .= v.cumulative_strategy ./ s
+            update!(p.behavior_policy, k => strategy)
         else
             # The TabularLearner will return uniform distribution by default.
             # So we do nothing here.
@@ -46,30 +52,31 @@ end

 "Run one iteration"
 function RLBase.update!(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv)
-    for x in get_players(env)
+    for x in players(env)
         if x != chance_player(env)
             external_sampling(copy(env), x, p.nodes, p.rng)
         end
     end
 end

 function external_sampling(env, i, nodes, rng)
-    current_player = current_player(env)
+    player = current_player(env)

     if is_terminated(env)
         reward(env, i)
-    elseif current_player == chance_player(env)
-        env(rand(rng, action_space(env)))
+    elseif player == chance_player(env)
+        env(sample(rng, action_space(env), Weights(prob(env), 1.0)))
         external_sampling(env, i, nodes, rng)
     else
         I = state(env)
         legal_actions = legal_action_space(env)
+        M = legal_action_space_mask(env)
         n = length(legal_actions)
-        node = get!(nodes, I, InfoStateNode(n))
+        node = get!(nodes, I, InfoStateNode(M))
         regret_matching!(node; is_reset_neg_regrets = false)
         σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy

-        if i == current_player
+        if i == player
             u = zeros(n)
             uσ = 0
             for (aᵢ, a) in enumerate(legal_actions)
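The mask-based update! above works because each InfoStateNode now stores statistics only for the legal actions, together with the legal-action mask, while the tabular behavior policy expects a distribution over the full action space. The scatter step is easy to see in isolation:

    mask = Bool[1, 0, 1, 1]               # actions 1, 3, 4 legal; action 2 illegal
    cumulative_strategy = [2.0, 1.0, 1.0] # one entry per *legal* action
    strategy = zeros(length(mask))
    strategy[mask] .= cumulative_strategy ./ sum(cumulative_strategy)
    strategy                              # [0.5, 0.0, 0.25, 0.25]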

src/algorithms/cfr/nash_conv.jl

Lines changed: 8 additions & 6 deletions
@@ -2,15 +2,17 @@ export expected_policy_values, nash_conv

 function expected_policy_values(π::AbstractPolicy, env::AbstractEnv)
     if is_terminated(env)
-        [reward(env, p) for p in get_players(env) if p != chance_player(env)]
+        [reward(env, p) for p in players(env) if p != chance_player(env)]
     elseif current_player(env) == chance_player(env)
-        vals = [0.0 for p in get_players(env) if p != chance_player(env)]
-        for a::ActionProbPair in legal_action_space(env)
-            vals .+= a.prob .* expected_policy_values(π, child(env, a))
+        vals = [0.0 for p in players(env) if p != chance_player(env)]
+        for (a, pₐ) in zip(action_space(env), prob(env))
+            if pₐ > 0
+                vals .+= pₐ .* expected_policy_values(π, child(env, a))
+            end
         end
         vals
     else
-        vals = [0.0 for p in get_players(env) if p != chance_player(env)]
+        vals = [0.0 for p in players(env) if p != chance_player(env)]
         actions = action_space(env)
         probs = prob(π, env)
         @assert length(actions) == length(probs)
@@ -30,7 +32,7 @@ function nash_conv(π, env; is_reduce = true, kw...)

     σ′ = [
         best_response_value(BestResponsePolicy(π, e, i; kw...), e)
-        for i in get_players(e) if i != chance_player(e)
+        for i in players(e) if i != chance_player(e)
     ]

     σ = expected_policy_values(π, e)
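For reference, what nash_conv computes is the standard gap to a Nash equilibrium: for each non-chance player, the value of a best response against the joint policy minus the value of the policy itself, summed over players when is_reduce = true. In LaTeX:

    \mathrm{NashConv}(\pi) = \sum_i \Big( \max_{\pi_i'} u_i(\pi_i', \pi_{-i}) - u_i(\pi) \Big)

A value of zero means no player can gain by unilaterally deviating, i.e. π is a Nash equilibrium; the σ′ and σ in the code correspond to the best-response and on-policy terms.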

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 16 additions & 8 deletions
@@ -21,6 +21,9 @@ end

 RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

+RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv, action) =
+    prob(p.behavior_policy, env, action)
+
 function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
     OutcomeSamplingMCCFRPolicy(
         Dict{state_type,InfoStateNode}(),
@@ -36,7 +39,7 @@ end

 "Run one iteration"
 function RLBase.update!(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv)
-    for x in get_players(env)
+    for x in players(env)
         if x != chance_player(env)
             outcome_sampling(copy(env), x, p.nodes, p.ϵ, 1.0, 1.0, 1.0, p.rng)
         end
@@ -47,7 +50,10 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
     for (k, v) in p.nodes
         s = sum(v.cumulative_strategy)
         if s != 0
-            update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
+            m = v.mask
+            strategy = zeros(length(m))
+            strategy[m] .= v.cumulative_strategy ./ s
+            update!(p.behavior_policy, k => strategy)
         else
             # The TabularLearner will return uniform distribution by default.
             # So we do nothing here.
@@ -56,22 +62,24 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
 end

 function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
-    current_player = current_player(env)
+    player = current_player(env)

     if is_terminated(env)
         reward(env, i) / s, 1.0
-    elseif current_player == chance_player(env)
-        env(rand(rng, action_space(env)))
+    elseif player == chance_player(env)
+        x = sample(rng, action_space(env), Weights(prob(env), 1.0))
+        env(sample(rng, action_space(env), Weights(prob(env), 1.0)))
        outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
     else
         I = state(env)
         legal_actions = legal_action_space(env)
+        M = legal_action_space_mask(env)
         n = length(legal_actions)
-        node = get!(nodes, I, InfoStateNode(n))
+        node = get!(nodes, I, InfoStateNode(M))
         regret_matching!(node; is_reset_neg_regrets = false)
         σ, rI, sI = node.strategy, node.cumulative_regret, node.cumulative_strategy

-        if i == current_player
+        if i == player
             aᵢ = rand(rng) >= ϵ ? sample(rng, Weights(σ, 1.0)) : rand(rng, 1:n)
             pᵢ = σ[aᵢ] * (1 - ϵ) + ϵ / n
             πᵢ′, π₋ᵢ′, s′ = πᵢ * pᵢ, π₋ᵢ, s * pᵢ
@@ -84,7 +92,7 @@ function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
         env(legal_action_space(env)[aᵢ])
         u, πₜₐᵢₗ = outcome_sampling(env, i, nodes, ϵ, πᵢ′, π₋ᵢ′, s′, rng)

-        if i == current_player
+        if i == player
             w = u * π₋ᵢ
             rI .+= w * πₜₐᵢₗ .* ((1:n .== aᵢ) .- σ[aᵢ])
         else
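Both MCCFR variants now draw chance outcomes from the environment's true chance distribution via StatsBase's sample/Weights rather than uniformly with rand. The second argument to Weights is the precomputed total weight, which saves re-summing the probability vector. In isolation:

    using Random, StatsBase

    rng = MersenneTwister(0)
    outcomes = ["heads", "tails", "edge"]
    p = [0.5, 0.5, 0.0]                            # analogue of prob(env); sums to 1.0
    draw = sample(rng, outcomes, Weights(p, 1.0))  # "edge" can never be drawn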

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 25 additions & 10 deletions
@@ -1,12 +1,21 @@
 export TabularCFRPolicy

-struct InfoStateNode
+struct InfoStateNode{M<:AbstractVector{Bool}}
     strategy::Vector{Float64}
     cumulative_regret::Vector{Float64}
     cumulative_strategy::Vector{Float64}
+    mask::M
 end

-InfoStateNode(n) = InfoStateNode(fill(1 / n, n), zeros(n), zeros(n))
+function InfoStateNode(mask)
+    n = sum(mask)
+    InfoStateNode(
+        fill(1 / n, n),
+        zeros(n),
+        zeros(n),
+        mask
+    )
+end

 #####
 # TabularCFRPolicy
@@ -77,7 +86,10 @@ function RLBase.update!(p::TabularCFRPolicy)
     for (k, v) in p.nodes
         s = sum(v.cumulative_strategy)
         if s != 0
-            update!(p.behavior_policy, k => v.cumulative_strategy ./ s)
+            m = v.mask
+            strategy = zeros(length(m))
+            strategy[m] .= v.cumulative_strategy ./ s
+            update!(p.behavior_policy, k => strategy)
         else
             # The TabularLearner will return uniform distribution by default.
             # So we do nothing here.
@@ -89,7 +101,7 @@ end
 function RLBase.update!(p::TabularCFRPolicy, env::AbstractEnv)
     w = p.is_linear_averaging ? max(p.n_iteration - p.weighted_averaging_delay, 0) : 1
     if p.is_alternating_update
-        for x in get_players(env)
+        for x in players(env)
             if x != chance_player(env)
                 cfr!(p.nodes, env, x, w)
                 regret_matching!(p)
@@ -113,22 +125,25 @@ w: weight
 v: counterfactual value **before being weighted by the opponent's reach probability**
 V: a vector containing the `v` of each action at the current information set; used to calculate the **regret value**
 """
-function cfr!(nodes, env, p, w, π = Dict(x => 1.0 for x in get_players(env)))
+function cfr!(nodes, env, p, w, π = Dict(x => 1.0 for x in players(env)))
     if is_terminated(env)
         reward(env, p)
     else
         if current_player(env) == chance_player(env)
             v = 0.0
-            for a::ActionProbPair in legal_action_space(env)
-                π′ = copy(π)
-                π′[current_player(env)] *= a.prob
-                v += a.prob * cfr!(nodes, child(env, a), p, w, π′)
+            for (a, pₐ) in zip(action_space(env), prob(env))
+                if pₐ > 0
+                    π′ = copy(π)
+                    π′[current_player(env)] *= pₐ
+                    v += pₐ * cfr!(nodes, child(env, a), p, w, π′)
+                end
             end
             v
         else
             v = 0.0
             legal_actions = legal_action_space(env)
-            node = get!(nodes, state(env), InfoStateNode(length(legal_actions)))
+            M = legal_action_space_mask(env)
+            node = get!(nodes, state(env), InfoStateNode(M))

             is_update = isnothing(p) || p == current_player(env)
             V = is_update ? Vector{Float64}(undef, length(legal_actions)) : nothing
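The regret_matching! calls seen throughout these files implement CFR's strategy rule: play each action in proportion to its positive cumulative regret, falling back to uniform when no action has positive regret. A minimal functional sketch of that rule (the package's version mutates the node in place):

    function regret_matching(cumulative_regret::Vector{Float64})
        r⁺ = max.(cumulative_regret, 0.0)
        s = sum(r⁺)
        s > 0 ? r⁺ ./ s : fill(1 / length(r⁺), length(r⁺))
    end

    regret_matching([3.0, -1.0, 1.0])  # [0.75, 0.0, 0.25]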

src/algorithms/policy_gradient/multi_thread_env.jl

Lines changed: 23 additions & 4 deletions
@@ -20,10 +20,7 @@ struct MultiThreadEnv{E,S,R,AS,SS,L} <: AbstractEnv
 end

 function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv)
-    s = """
-    # MultiThreadEnv($(length(env)) x $(nameof(env[1])))
-    """
-    show(io, t, Markdown.parse(s))
+    print(io, "MultiThreadEnv($(length(env)) x $(nameof(env[1])))")
 end

 """
@@ -146,3 +143,25 @@ for f in RLBase.ENV_API
         @eval RLBase.$f(x::MultiThreadEnv) = $f(x[1])
     end
 end
+
+#####
+# Patches
+#####
+
+(env::MultiThreadEnv)(action::EnrichedAction) = env(action.action)
+
+function (π::QBasedPolicy)(env::MultiThreadEnv, ::MinimalActionSet, A)
+    [A[i][a] for (i, a) in enumerate(π.explorer(π.learner(env)))]
+end
+
+function (π::QBasedPolicy)(env::MultiThreadEnv, ::FullActionSet, A)
+    [A[i][a] for (i, a) in enumerate(π.explorer(π.learner(env), legal_action_space_mask(env)))]
+end
+
+function (π::QBasedPolicy)(env::MultiThreadEnv, ::MinimalActionSet, ::Space{<:Vector{<:Base.OneTo{<:Integer}}})
+    π.explorer(π.learner(env))
+end
+
+function (π::QBasedPolicy)(env::MultiThreadEnv, ::FullActionSet, ::Space{<:Vector{<:Base.OneTo{<:Integer}}})
+    π.explorer(π.learner(env), legal_action_space_mask(env))
+end
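The four QBasedPolicy methods above pick one action per sub-environment in the batch: under MinimalActionSet the learner's integer outputs index each sub-environment's action space directly (or are returned as-is when the spaces are plain Base.OneTo ranges), while under FullActionSet the legal-action mask must be applied first. A standalone illustration of masked batched selection with hypothetical Q-values (the package's explorer may also be stochastic):

    q = [0.3 0.9;
         0.8 0.1;
         0.5 0.4]                  # 3 actions × 2 parallel sub-environments
    mask = Bool[1 1; 0 1; 1 1]     # action 2 is illegal in sub-environment 1
    masked_argmax(col, m) = argmax(ifelse.(m, col, -Inf))
    actions = [masked_argmax(q[:, i], mask[:, i]) for i in 1:size(q, 2)]
    actions                        # [3, 1]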
