Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit e837463

Browse files
authored
Drop dependency on RLEnvs (#136)
* drop dependency of RLEnvs * remove dependency on RLEnvs * minor fix * fix warnings
1 parent 4632465 commit e837463

7 files changed

Lines changed: 156 additions & 3 deletions

File tree

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1919
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
2121
ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
22-
ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
2322
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2423
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2524
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -40,7 +39,6 @@ IntervalSets = "0.5"
4039
MacroTools = "0.5"
4140
ReinforcementLearningBase = "0.9"
4241
ReinforcementLearningCore = "0.6.1"
43-
ReinforcementLearningEnvironments = "0.4"
4442
Requires = "1"
4543
Setfield = "0.6, 0.7"
4644
StableRNGs = "1.0"

src/ReinforcementLearningZoo.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ export RLZoo
66
using CircularArrayBuffers
77
using ReinforcementLearningBase
88
using ReinforcementLearningCore
9-
using ReinforcementLearningEnvironments
109
using Setfield: @set
1110
using StableRNGs
1211
using Logging

src/algorithms/policy_gradient/A2C.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,5 @@ function _update!(learner::A2CLearner, t::CircularArraySARTTrajectory)
110110
end
111111
update!(AC, gs)
112112
end
113+
114+
# Disable the default policy/env compatibility check for A2C learners paired with
# a MultiThreadEnv (the generic check would otherwise emit warnings — see commit title).
RLCore.check(::QBasedPolicy{<:A2CLearner}, ::MultiThreadEnv) = nothing

src/algorithms/policy_gradient/A2CGAE.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,5 @@ function _update!(learner::A2CGAELearner, t::CircularArraySARTTrajectory)
113113

114114
update!(AC, gs)
115115
end
116+
117+
# Disable the default policy/env compatibility check for A2CGAE learners paired
# with a MultiThreadEnv (the generic check would otherwise emit warnings).
RLCore.check(::QBasedPolicy{<:A2CGAELearner}, ::MultiThreadEnv) = nothing

src/algorithms/policy_gradient/MAC.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,5 @@ function _update!(learner::MACLearner, t::CircularArraySARTTrajectory)
133133
end
134134
update!(AC.critic, gs2)
135135
end
136+
137+
# Disable the default policy/env compatibility check for MAC learners paired
# with a MultiThreadEnv (the generic check would otherwise emit warnings).
RLCore.check(::QBasedPolicy{<:MACLearner}, ::MultiThreadEnv) = nothing
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
export MultiThreadEnv
2+
3+
using Base.Threads:@spawn
4+
5+
"""
    MultiThreadEnv(envs::Vector{<:AbstractEnv})

Wrap multiple instances of the same environment type into one environment.
Each environment will run in parallel by leveraging `Threads.@spawn`.
So remember to set the environment variable `JULIA_NUM_THREADS`!
"""
struct MultiThreadEnv{E,S,R,AS,SS,L} <: AbstractEnv
    envs::Vector{E}                # the wrapped sub-environments
    states::S                      # batched state buffer; last dimension indexes sub-envs
    rewards::R                     # per-sub-env reward buffer, refreshed by `reward(env)`
    terminals::BitArray{1}         # per-sub-env termination flags, refreshed by `is_terminated(env)`
    action_space::AS               # batched action space (or `Space` of per-env spaces)
    state_space::SS                # batched state space (or `Space` of per-env spaces)
    legal_action_space_mask::L     # batched legal-action mask, or `nothing` when ActionStyle != FULL_ACTION_SET
end
21+
22+
# Render a one-line markdown summary: number of wrapped envs and their type name.
function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv)
    header = """
    # MultiThreadEnv($(length(env)) x $(nameof(env[1])))
    """
    return show(io, t, Markdown.parse(header))
end
28+
29+
"""
    MultiThreadEnv(f, n::Int)

`f` is a lambda function which creates an `AbstractEnv` by calling `f()`.
Construct `n` independent environments via `f()` and wrap them together.
"""
function MultiThreadEnv(f, n::Int)
    return MultiThreadEnv([f() for _ in 1:n])
end
35+
36+
# Build a MultiThreadEnv from concrete sub-environments by precomputing batched
# buffers: states/spaces gain a trailing batch dimension of size `n` when the
# per-env space is a `Space` (array-like); otherwise a `Space` of per-env values
# is stored instead.
function MultiThreadEnv(envs::Vector{<:AbstractEnv})
    n = length(envs)
    S = state_space(envs[1])
    s = state(envs[1])
    if S isa Space
        # Array-like space: allocate batched buffers with an extra trailing dim.
        S_batch = similar(S, size(S)..., n)
        s_batch = similar(s, size(s)..., n)
        for j in 1:n
            Sₙ = state_space(envs[j])
            sₙ = state(envs[j])
            # NOTE(review): iterates indices of envs[1]'s space — assumes all
            # sub-envs share the same state-space shape; not checked here.
            for i in CartesianIndices(size(S))
                S_batch[i, j] = Sₙ[i]
                s_batch[i, j] = sₙ[i]
            end
        end
    else
        # Non-array space: fall back to a Space of per-env spaces/states.
        S_batch = Space(state_space.(envs))
        s_batch = state.(envs)
    end

    A = action_space(envs[1])
    if A isa Space
        # Same batching strategy for the action space.
        A_batch = similar(A, size(A)..., n)
        for j in 1:n
            Aⱼ = action_space(envs[j])
            for i in CartesianIndices(size(A))
                A_batch[i, j] = Aⱼ[i]
            end
        end
    else
        A_batch = Space(action_space.(envs))
    end

    # Initial per-env reward and termination snapshots.
    r_batch = reward.(envs)
    t_batch = is_terminated.(envs)
    if ActionStyle(envs[1]) === FULL_ACTION_SET
        # Environments with restricted action sets get a batched legal-action mask.
        m_batch = BitArray(undef, size(A_batch))
        for j in 1:n
            L = legal_action_space_mask(envs[j])
            for i in CartesianIndices(size(A))
                m_batch[i, j] = L[i]
            end
        end
    else
        # MINIMAL_ACTION_SET: no mask needed.
        m_batch = nothing
    end
    MultiThreadEnv(envs, s_batch, r_batch, t_batch, A_batch, S_batch, m_batch)
end
84+
85+
# Delegate indexing, length, and iteration straight to the wrapped env vector,
# so `env[i]`, `length(env)` and `for e in env` work as expected.
MacroTools.@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.iterate
86+
87+
# Step every sub-environment with its own action, one task per env, and wait
# for all of them to finish before returning.
function (env::MultiThreadEnv)(actions)
    @sync begin
        for (i, single_env) in enumerate(env.envs)
            @spawn single_env(actions[i])
        end
    end
end
94+
95+
# Reset sub-environments. With `is_force = true`, every env is reset
# unconditionally (sequentially); otherwise only terminated envs are reset,
# each in its own task.
function RLBase.reset!(env::MultiThreadEnv; is_force = false)
    if is_force
        foreach(reset!, env.envs)
    else
        @sync for single_env in env.envs
            if is_terminated(single_env)
                @spawn reset!(single_env)
            end
        end
    end
end
110+
111+
# NOTE(review): this cache appears unused within this file view — confirm it is
# referenced elsewhere before removing.
const MULTI_THREAD_ENV_CACHE = IdDict{AbstractEnv,Dict{Symbol,Array}}()
112+
113+
# Refresh the batched state buffer in parallel: slice along the trailing
# (batch) dimension and copy each sub-env's current state into its slot.
function RLBase.state(env::MultiThreadEnv)
    batch_dim = ndims(env.states)
    @sync for (i, single_env) in enumerate(env.envs)
        @spawn selectdim(env.states, batch_dim, i) .= state(single_env)
    end
    return env.states
end
120+
121+
# Refresh and return the per-sub-env reward buffer.
function RLBase.reward(env::MultiThreadEnv)
    map!(reward, env.rewards, env.envs)
    return env.rewards
end
125+
126+
# Refresh and return the per-sub-env termination flags.
function RLBase.is_terminated(env::MultiThreadEnv)
    map!(is_terminated, env.terminals, env.envs)
    return env.terminals
end
130+
131+
# Refresh the batched legal-action mask in parallel: slice along the trailing
# (batch) dimension and copy each sub-env's mask into its slot.
#
# Bug fix: `N` was previously used without being defined in this method's scope
# (it was only a local variable inside `RLBase.state`), so any call raised an
# `UndefVarError`. Compute the batch dimension from the mask itself.
function RLBase.legal_action_space_mask(env::MultiThreadEnv)
    N = ndims(env.legal_action_space_mask)
    @sync for i in 1:length(env)
        @spawn selectdim(env.legal_action_space_mask, N, i) .=
            legal_action_space_mask(env[i])
    end
    env.legal_action_space_mask
end
138+
139+
# Spaces were precomputed in the constructor; return the cached batched versions.
RLBase.action_space(env::MultiThreadEnv) = env.action_space
RLBase.state_space(env::MultiThreadEnv) = env.state_space
# Legal action spaces can change every step, so these are gathered on demand.
RLBase.legal_action_space(env::MultiThreadEnv) = Space(legal_action_space.(env.envs))
# RLBase.current_player(env::MultiThreadEnv) = current_player.(env.envs)

# Forward every `*Style` trait query (e.g. ActionStyle, RewardStyle) to the
# first sub-environment — assumes all wrapped envs share identical traits.
for f in RLBase.ENV_API
    if endswith(String(f), "Style")
        @eval RLBase.$f(x::MultiThreadEnv) = $f(x[1])
    end
end

src/algorithms/policy_gradient/run.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
include("multi_thread_env.jl")
2+
13
"""
24
Many policy gradient based algorithms require that the `env` is a
35
`MultiThreadEnv` to increase the diversity during training. So the training

0 commit comments

Comments
 (0)