This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 2c0e248

Update dependency (#135)
* moved imports
* removed repeated imports
* move using into RLZoo.jl
* remove other duplications
* update dependency

Co-authored-by: Rishabh Varshney <rishabhvarshney14@gmail.com>
1 parent f55c858 commit 2c0e248

49 files changed

Lines changed: 232 additions & 1167 deletions
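Almost every hunk below follows the same pattern: with ReinforcementLearningBase bumped from 0.8.4 to 0.9, the environment and policy accessors lose their get_ prefix (and a few are renamed outright), while get_players is left unchanged in these hunks. The mapping below is inferred from this diff and is meant only as a reading aid, not an exhaustive migration guide:

    # RLBase 0.8.4 (old)              # RLBase 0.9 (new), as used in this diff
    get_state(env)                    state(env)
    get_actions(env)                  action_space(env)
    get_legal_actions(env)            legal_action_space(env)
    get_legal_actions_mask(env)       legal_action_space_mask(env)
    get_current_player(env)           current_player(env)
    get_chance_player(env)            chance_player(env)
    get_terminal(env)                 is_terminated(env)
    get_reward(env, player)           reward(env, player)
    get_prob(policy, env)             prob(policy, env)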


Project.toml

Lines changed: 6 additions & 2 deletions
@@ -11,13 +11,15 @@ CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
+ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -34,9 +36,11 @@ CUDA = "1, 2.1"
 CircularArrayBuffers = "0.1"
 Distributions = "0.24"
 Flux = "0.11"
+IntervalSets = "0.5"
 MacroTools = "0.5"
-ReinforcementLearningBase = "0.8.4"
-ReinforcementLearningCore = "0.6"
+ReinforcementLearningBase = "0.9"
+ReinforcementLearningCore = "0.6.1"
+ReinforcementLearningEnvironments = "0.4"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StableRNGs = "1.0"
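To reproduce the same dependency bump in a local Julia environment, the Pkg calls below match the [compat] entries above; this is a hedged convenience sketch, not part of the commit:

    using Pkg
    Pkg.add(Pkg.PackageSpec(name = "ReinforcementLearningBase", version = "0.9"))
    Pkg.add(Pkg.PackageSpec(name = "ReinforcementLearningCore", version = "0.6.1"))
    Pkg.add(Pkg.PackageSpec(name = "ReinforcementLearningEnvironments", version = "0.4"))
    Pkg.add(Pkg.PackageSpec(name = "IntervalSets", version = "0.5"))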

src/ReinforcementLearningZoo.jl

Lines changed: 2 additions & 0 deletions
@@ -6,11 +6,13 @@ export RLZoo
 using CircularArrayBuffers
 using ReinforcementLearningBase
 using ReinforcementLearningCore
+using ReinforcementLearningEnvironments
 using Setfield: @set
 using StableRNGs
 using Logging
 using Flux.Losses
 using Dates
+using IntervalSets
 using Random
 using Random: shuffle
 using CUDA
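Loading ReinforcementLearningEnvironments and IntervalSets once at the module level is what lets individual algorithm files drop their own using statements (the "moved imports" and "removed repeated imports" bullets in the commit message). As a small illustration of why IntervalSets is wanted module-wide, interval syntax becomes available everywhere in the package; the values here are made up:

    using IntervalSets

    support = -1.0..1.0        # a ClosedInterval{Float64}
    0.3 in support             # true
    2.0 in support             # false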

src/algorithms/cfr/best_response_policy.jl

Lines changed: 26 additions & 26 deletions
@@ -22,8 +22,8 @@ function BestResponsePolicy(
     state_type = String,
     action_type = Int,
 )
-    # S = typeof(get_state(env)) # TODO: currently it will break the OpenSpielEnv. Can not get information set for chance player
-    # A = eltype(get_actions(env)) # TODO: for chance players it will return ActionProbPair
+    # S = typeof(state(env)) # TODO: currently it will break the OpenSpielEnv. Can not get information set for chance player
+    # A = eltype(action_space(env)) # TODO: for chance players it will return ActionProbPair
     S = state_type
     A = action_type
     E = typeof(env)
@@ -45,31 +45,31 @@ function BestResponsePolicy(
 end

 function (p::BestResponsePolicy)(env::AbstractEnv)
-    if get_current_player(env) == p.best_responder
+    if current_player(env) == p.best_responder
         best_response_action(p, env)
     else
         p.policy(env)
     end
 end

 function init_cfr_reach_prob!(p, env, reach_prob = 1.0)
-    if !get_terminal(env)
-        if get_current_player(env) == p.best_responder
-            push!(get!(p.cfr_reach_prob, get_state(env), []), env => reach_prob)
+    if !is_terminated(env)
+        if current_player(env) == p.best_responder
+            push!(get!(p.cfr_reach_prob, state(env), []), env => reach_prob)

-            for a in get_legal_actions(env)
+            for a in legal_action_space(env)
                 init_cfr_reach_prob!(p, child(env, a), reach_prob)
             end
-        elseif get_current_player(env) == get_chance_player(env)
-            for a::ActionProbPair in get_actions(env)
+        elseif current_player(env) == chance_player(env)
+            for a::ActionProbPair in action_space(env)
                 init_cfr_reach_prob!(p, child(env, a), reach_prob * a.prob)
             end
         else # opponents
-            for a in get_legal_actions(env)
+            for a in legal_action_space(env)
                 init_cfr_reach_prob!(
                     p,
                     child(env, a),
-                    reach_prob * get_prob(p.policy, env, a),
+                    reach_prob * prob(p.policy, env, a),
                 )
             end
         end
@@ -78,34 +78,34 @@ end

 function best_response_value(p, env)
     get!(p.best_response_value_cache, env) do
-        if get_terminal(env)
-            get_reward(env, p.best_responder)
-        elseif get_current_player(env) == p.best_responder
+        if is_terminated(env)
+            reward(env, p.best_responder)
+        elseif current_player(env) == p.best_responder
             a = best_response_action(p, env)
             best_response_value(p, child(env, a))
-        elseif get_current_player(env) == get_chance_player(env)
+        elseif current_player(env) == chance_player(env)
             v = 0.0
-            for a::ActionProbPair in get_actions(env)
+            for a::ActionProbPair in action_space(env)
                 v += a.prob * best_response_value(p, child(env, a))
             end
             v
         else
             v = 0.0
-            for a in get_legal_actions(env)
-                v += get_prob(p.policy, env, a) * best_response_value(p, child(env, a))
+            for a in legal_action_space(env)
+                v += prob(p.policy, env, a) * best_response_value(p, child(env, a))
             end
             v
         end
     end
 end

 function best_response_action(p, env)
-    get!(p.best_response_action_cache, get_state(env)) do
+    get!(p.best_response_action_cache, state(env)) do
         best_action, best_action_value = nothing, typemin(Float64)
-        for a in get_legal_actions(env)
-            # for each information set (`get_state(env)` here), we may have several paths to reach it
+        for a in legal_action_space(env)
+            # for each information set (`state(env)` here), we may have several paths to reach it
             # here we sum the cfr reach prob weighted value to find out the best action
-            v = sum(p.cfr_reach_prob[get_state(env)]) do (e, reach_prob)
+            v = sum(p.cfr_reach_prob[state(env)]) do (e, reach_prob)
                 reach_prob * best_response_value(p, child(e, a))
             end
             if v > best_action_value
@@ -118,10 +118,10 @@ end

 RLBase.update!(p::BestResponsePolicy, args...) = nothing

-function RLBase.get_prob(p::BestResponsePolicy, env::AbstractEnv)
-    if get_current_player(env) == p.best_responder
-        onehot(p(env), get_actions(env))
+function RLBase.prob(p::BestResponsePolicy, env::AbstractEnv)
+    if current_player(env) == p.best_responder
+        onehot(p(env), action_space(env))
     else
-        get_prob(p.policy, env)
+        prob(p.policy, env)
     end
 end
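To see the shape of the recursion that best_response_value implements (terminal nodes return the responder's reward, responder nodes maximize, chance and opponent nodes take probability-weighted averages), here is a self-contained toy version on a hand-rolled game tree. It deliberately ignores the information-set bookkeeping (cfr_reach_prob) that the real code needs when several histories map to the same state; treat it as an illustration, not the package's implementation:

    # Toy best-response value: the responder maximizes, everyone else averages.
    struct Node
        kind::Symbol                 # :terminal, :responder, :chance, or :opponent
        payoff::Float64              # used only at :terminal nodes
        children::Vector{Node}
        probs::Vector{Float64}       # used at :chance and :opponent nodes
    end

    leaf(x) = Node(:terminal, x, Node[], Float64[])

    function br_value(n::Node)
        if n.kind == :terminal
            n.payoff
        elseif n.kind == :responder
            maximum(br_value.(n.children))
        else                         # :chance or :opponent
            sum(p * br_value(c) for (p, c) in zip(n.probs, n.children))
        end
    end

    opp  = Node(:opponent, 0.0, [leaf(1.0), leaf(-1.0)], [0.5, 0.5])  # uniform opponent: value 0.0
    root = Node(:responder, 0.0, [opp, leaf(0.2)], Float64[])
    br_value(root)                   # max(0.0, 0.2) == 0.2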

src/algorithms/cfr/deep_cfr.jl

Lines changed: 17 additions & 17 deletions
@@ -40,21 +40,21 @@ Base.@kwdef mutable struct DeepCFR{TP,TV,TMP,TMV,I,R,P} <: AbstractCFRPolicy
         Dict(k => zeros(Float32, n_training_steps_V) for (k, _) in MV)
     end
 end

-function RLBase.get_prob(π::DeepCFR, env::AbstractEnv)
-    I = send_to_device(device(π.Π), get_state(env))
-    m = send_to_device(device(π.Π), ifelse.(get_legal_actions_mask(env), 0.0f0, -Inf32))
+function RLBase.prob(π::DeepCFR, env::AbstractEnv)
+    I = send_to_device(device(π.Π), state(env))
+    m = send_to_device(device(π.Π), ifelse.(legal_action_space_mask(env), 0.0f0, -Inf32))
     logits = π.Π(Flux.unsqueeze(I, ndims(I) + 1)) |> vec
     σ = softmax(logits .+ m)
     send_to_host(σ)
 end

 (π::DeepCFR)(env::AbstractEnv) =
-    sample(π.rng, get_actions(env), Weights(get_prob(π, env), 1.0))
+    sample(π.rng, action_space(env), Weights(prob(π, env), 1.0))

 "Run one interation"
 function RLBase.update!(π::DeepCFR, env::AbstractEnv)
     for p in get_players(env)
-        if p != get_chance_player(env)
+        if p != chance_player(env)
             for k in 1:π.K
                 external_sampling!(π, copy(env), p)
             end
@@ -135,17 +135,17 @@ end

 "CFR Traversal with External Sampling"
 function external_sampling!(π::DeepCFR, env::AbstractEnv, p)
-    if get_terminal(env)
-        get_reward(env, p)
-    elseif get_current_player(env) == get_chance_player(env)
-        env(rand(π.rng, get_actions(env)))
+    if is_terminated(env)
+        reward(env, p)
+    elseif current_player(env) == chance_player(env)
+        env(rand(π.rng, action_space(env)))
         external_sampling!(π, env, p)
-    elseif get_current_player(env) == p
+    elseif current_player(env) == p
         V = π.V[p]
-        s = get_state(env)
+        s = state(env)
         I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
-        A = get_actions(env)
-        m = get_legal_actions_mask(env)
+        A = action_space(env)
+        m = legal_action_space_mask(env)
         σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
         v = zeros(length(σ))
         v̄ = 0.0
@@ -158,11 +158,11 @@ function external_sampling!(π::DeepCFR, env::AbstractEnv, p)
         push!(π.MV[p], (I = s, t = π.t, r̃ = (v .- v̄) .* m, m = m))

     else
-        V = π.V[get_current_player(env)]
-        s = get_state(env)
+        V = π.V[current_player(env)]
+        s = state(env)
         I = send_to_device(device(V), Flux.unsqueeze(s, ndims(s) + 1))
-        A = get_actions(env)
-        m = get_legal_actions_mask(env)
+        A = action_space(env)
+        m = legal_action_space_mask(env)
         σ = masked_regret_matching(V(I) |> send_to_host |> vec, m)
         push!(π.MΠ, (I = s, t = π.t, σ = σ, m = m))
         a = sample(π.rng, A, Weights(σ, 1.0))
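One detail in the prob method above that is easy to miss: illegal actions get a logit offset of -Inf32 before the softmax, so they receive exactly zero probability while the legal actions renormalize among themselves. A self-contained sketch of just that masking step; the logits and mask are invented for illustration:

    using Flux: softmax

    logits = Float32[0.2, 1.3, -0.5]
    legal  = Bool[true, false, true]          # cf. legal_action_space_mask(env)
    m = ifelse.(legal, 0.0f0, -Inf32)
    σ = softmax(logits .+ m)                  # σ[2] == 0; the legal entries sum to 1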

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 10 additions & 10 deletions
@@ -18,8 +18,8 @@ end

 (p::ExternalSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

-RLBase.get_prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) =
-    get_prob(p.behavior_policy, env)
+RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) =
+    prob(p.behavior_policy, env)

 function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG)
     ExternalSamplingMCCFRPolicy(
@@ -48,23 +48,23 @@ end
 "Run one interation"
 function RLBase.update!(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv)
     for x in get_players(env)
-        if x != get_chance_player(env)
+        if x != chance_player(env)
             external_sampling(copy(env), x, p.nodes, p.rng)
         end
     end
 end

 function external_sampling(env, i, nodes, rng)
-    current_player = get_current_player(env)
+    current_player = current_player(env)

-    if get_terminal(env)
-        get_reward(env, i)
-    elseif current_player == get_chance_player(env)
-        env(rand(rng, get_actions(env)))
+    if is_terminated(env)
+        reward(env, i)
+    elseif current_player == chance_player(env)
+        env(rand(rng, action_space(env)))
         external_sampling(env, i, nodes, rng)
     else
-        I = get_state(env)
-        legal_actions = get_legal_actions(env)
+        I = state(env)
+        legal_actions = legal_action_space(env)
         n = length(legal_actions)
         node = get!(nodes, I, InfoStateNode(n))
         regret_matching!(node; is_reset_neg_regrets = false)
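The update!(p, env) method above performs one external-sampling traversal per non-chance player, so training is just a loop over it. A hedged sketch of such a loop: the environment is a hypothetical choice, and any averaging or finalization step the policy may need afterwards is not shown because it falls outside this diff:

    using ReinforcementLearningBase
    using ReinforcementLearningEnvironments
    using ReinforcementLearningZoo

    env = KuhnPokerEnv()                 # hypothetical environment choice
    p = ExternalSamplingMCCFRPolicy()    # constructor defaults shown in the hunk above
    for _ in 1:10_000
        RLBase.update!(p, env)           # one traversal per non-chance player
    end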

src/algorithms/cfr/nash_conv.jl

Lines changed: 9 additions & 9 deletions
@@ -1,18 +1,18 @@
 export expected_policy_values, nash_conv

 function expected_policy_values(π::AbstractPolicy, env::AbstractEnv)
-    if get_terminal(env)
-        [get_reward(env, p) for p in get_players(env) if p != get_chance_player(env)]
-    elseif get_current_player(env) == get_chance_player(env)
-        vals = [0.0 for p in get_players(env) if p != get_chance_player(env)]
-        for a::ActionProbPair in get_legal_actions(env)
+    if is_terminated(env)
+        [reward(env, p) for p in get_players(env) if p != chance_player(env)]
+    elseif current_player(env) == chance_player(env)
+        vals = [0.0 for p in get_players(env) if p != chance_player(env)]
+        for a::ActionProbPair in legal_action_space(env)
             vals .+= a.prob .* expected_policy_values(π, child(env, a))
         end
         vals
     else
-        vals = [0.0 for p in get_players(env) if p != get_chance_player(env)]
-        actions = get_actions(env)
-        probs = get_prob(π, env)
+        vals = [0.0 for p in get_players(env) if p != chance_player(env)]
+        actions = action_space(env)
+        probs = prob(π, env)
         @assert length(actions) == length(probs)

         for (a, p) in zip(actions, probs)
@@ -30,7 +30,7 @@ function nash_conv(π, env; is_reduce = true, kw...)

     σ′ = [
         best_response_value(BestResponsePolicy(π, e, i; kw...), e)
-        for i in get_players(e) if i != get_chance_player(e)
+        for i in get_players(e) if i != chance_player(e)
     ]

     σ = expected_policy_values(π, e)
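nash_conv is the exported entry point here: it compares each player's best-response value (σ′) against the expected value of the joint policy (σ), so it is 0 for a Nash equilibrium and grows with exploitability. A minimal, hedged usage sketch; the environment and policy choices are illustrative and not taken from this commit:

    using ReinforcementLearningCore
    using ReinforcementLearningEnvironments
    using ReinforcementLearningZoo

    env = KuhnPokerEnv()          # hypothetical: a small extensive-form env with a chance player
    π = TabularRandomPolicy()     # hypothetical: the joint policy to evaluate
    nash_conv(π, env)             # 0.0 for a Nash equilibrium; larger means more exploitable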

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 11 additions & 11 deletions
@@ -19,8 +19,8 @@ end

 (p::OutcomeSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

-RLBase.get_prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) =
-    get_prob(p.behavior_policy, env)
+RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) =
+    prob(p.behavior_policy, env)

 function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
     OutcomeSamplingMCCFRPolicy(
@@ -38,7 +38,7 @@ end
 "Run one interation"
 function RLBase.update!(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv)
     for x in get_players(env)
-        if x != get_chance_player(env)
+        if x != chance_player(env)
             outcome_sampling(copy(env), x, p.nodes, p.ϵ, 1.0, 1.0, 1.0, p.rng)
         end
     end
@@ -57,16 +57,16 @@ function RLBase.update!(p::OutcomeSamplingMCCFRPolicy)
 end

 function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
-    current_player = get_current_player(env)
+    current_player = current_player(env)

-    if get_terminal(env)
-        get_reward(env, i) / s, 1.0
-    elseif current_player == get_chance_player(env)
-        env(rand(rng, get_actions(env)))
+    if is_terminated(env)
+        reward(env, i) / s, 1.0
+    elseif current_player == chance_player(env)
+        env(rand(rng, action_space(env)))
         outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
     else
-        I = get_state(env)
-        legal_actions = get_legal_actions(env)
+        I = state(env)
+        legal_actions = legal_action_space(env)
         n = length(legal_actions)
         node = get!(nodes, I, InfoStateNode(n))
         regret_matching!(node; is_reset_neg_regrets = false)
@@ -82,7 +82,7 @@ function outcome_sampling(env, i, nodes, ϵ, πᵢ, π₋ᵢ, s, rng)
         πᵢ′, π₋ᵢ′, s′ = πᵢ, π₋ᵢ * pᵢ, s * pᵢ
     end

-    env(get_legal_actions(env)[aᵢ])
+    env(legal_action_space(env)[aᵢ])
     u, πₜₐᵢₗ = outcome_sampling(env, i, nodes, ϵ, πᵢ′, π₋ᵢ′, s′, rng)

     if i == current_player
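The ϵ = 0.6 default in the constructor above controls exploration when the traversing player samples its own action. The exact expression lies outside the visible hunks, but in standard outcome-sampling MCCFR the sampling distribution is an ϵ-mixture of uniform and the regret-matched strategy, roughly:

    # Generic outcome-sampling exploration mix (illustrative; not copied from this file):
    ϵ = 0.6
    σ = [0.7, 0.2, 0.1]                              # regret-matched strategy at the current infoset
    n = length(σ)
    sampling_dist = ϵ .* fill(1 / n, n) .+ (1 - ϵ) .* σ
    # sums to 1 and gives every legal action probability of at least ϵ / n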
