
Commit 8668f3c

Add REM-DQN (Random Ensemble Mixture) method (#160)

* Add some explanations
* Add REM DQN
* Add docs
* Modified implementation
* Some modifications
* Fix conflict

Co-authored-by: Jun Tian <find_my_way@foxmail.com>

1 parent: 2908338

5 files changed, 243 additions & 3 deletions
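In REM, the Q-network has `ensemble_num` output heads, and every training step mixes the heads' Q-value estimates with a freshly drawn random convex combination; the same weights are applied to both the online and the target network. A minimal sketch of that mixing step, with illustrative sizes that are not taken from the commit:

# Minimal REM mixing sketch (illustrative sizes, not part of the commit).
na, K, batch_size = 2, 6, 32
q = rand(Float32, na * K, batch_size)      # stand-in for the network output

α = rand(Float32, 1, K)                    # random mixture weights...
α ./= sum(α)                               # ...normalized onto the simplex

q_heads = reshape(q, na, K, batch_size)    # na × K × batch_size
q_mix = dropdims(sum(α .* q_heads, dims = 2), dims = 2)  # na × batch_size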


src/algorithms/dqns/common.jl

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@

 const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}

-function RLBase.update!(learner::Union{DQNLearner,PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(learner::Union{DQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
     length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return

     learner.update_step += 1

src/algorithms/dqns/dqns.jl

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 include("basic_dqn.jl")
 include("dqn.jl")
 include("prioritized_dqn.jl")
+include("rem_dqn.jl")
 include("rainbow.jl")
 include("iqn.jl")
-include("common.jl")
+include("common.jl")

src/algorithms/dqns/rem_dqn.jl

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
export REMDQNLearner

mutable struct REMDQNLearner{
    Tq<:AbstractApproximator,
    Tt<:AbstractApproximator,
    Tf,
    R<:AbstractRNG,
} <: AbstractLearner
    approximator::Tq
    target_approximator::Tt
    loss_func::Tf
    min_replay_history::Int
    update_freq::Int
    update_step::Int
    target_update_freq::Int
    sampler::NStepBatchSampler
    ensemble_num::Int
    ensemble_method::Symbol
    rng::R
    # for logging
    loss::Float32
end

"""
25+
REMDQNLearner(;kwargs...)
26+
27+
See paper: [An Optimistic Perspective on Offline Reinforcement Learning](https://arxiv.org/abs/1907.04543)
28+
29+
# Keywords
30+
31+
- `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
32+
- `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the target (the next state).
33+
- `loss_func`: the loss function.
34+
- `γ::Float32=0.99f0`: discount rate.
35+
- `batch_size::Int=32`
36+
- `update_horizon::Int=1`: length of update ('n' in n-step update).
37+
- `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
38+
- `update_freq::Int=4`: the frequency of updating the `approximator`.
39+
- `ensemble_num::Int=1`: the number of ensemble approximators.
40+
- `ensemble_method::Symbol=:rand`: the method of combining Q values. ':rand' represents random ensemble mixture, and ':mean' is the average.
41+
- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
42+
- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
43+
- `traces = SARTS`, set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
44+
- `rng = Random.GLOBAL_RNG`
45+
"""
function REMDQNLearner(;
    approximator::Tq,
    target_approximator::Tt,
    loss_func::Tf,
    stack_size::Union{Int,Nothing} = nothing,
    γ::Float32 = 0.99f0,
    batch_size::Int = 32,
    update_horizon::Int = 1,
    min_replay_history::Int = 32,
    update_freq::Int = 1,
    ensemble_num::Int = 1,
    ensemble_method::Symbol = :rand,
    target_update_freq::Int = 100,
    traces = SARTS,
    update_step = 0,
    rng = Random.GLOBAL_RNG,
) where {Tq,Tt,Tf}
    copyto!(approximator, target_approximator)
    sampler = NStepBatchSampler{traces}(;
        γ = γ,
        n = update_horizon,
        stack_size = stack_size,
        batch_size = batch_size,
    )
    REMDQNLearner(
        approximator,
        target_approximator,
        loss_func,
        min_replay_history,
        update_freq,
        update_step,
        target_update_freq,
        sampler,
        ensemble_num,
        ensemble_method,
        rng,
        0.0f0,
    )
end

Flux.functor(x::REMDQNLearner) = (Q = x.approximator, Qₜ = x.target_approximator),
y -> begin
    x = @set x.approximator = y.Q
    x = @set x.target_approximator = y.Qₜ
    x
end

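The `Flux.functor` overload above exposes only the two approximators to Flux, so utilities built on `Flux.fmap` (such as `gpu`/`cpu` and `Flux.params`) rebuild the learner with transformed parameters while the bookkeeping fields pass through untouched. A toy analogue of the pattern, with a hypothetical struct, written against the same Flux API the file uses:

using Flux

mutable struct ToyLearner{T}
    net::T
    update_step::Int   # bookkeeping; deliberately hidden from Flux
end

# Expose only `net`; the reconstructor carries `update_step` through unchanged.
Flux.functor(x::ToyLearner) = (net = x.net,), y -> ToyLearner(y.net, x.update_step)

t = ToyLearner(Dense(2, 2), 0)
ps = Flux.params(t)   # collects only the Dense layer's weight and bias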
function (learner::REMDQNLearner)(env)
    s = send_to_device(device(learner.approximator), state(env))
    s = Flux.unsqueeze(s, ndims(s) + 1)
    # When acting, average the Q-values across all ensemble heads.
    q = reshape(learner.approximator(s), :, learner.ensemble_num)
    vec(mean(q, dims = 2)) |> send_to_host
end

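There is no random mixing at action-selection time; the function above simply averages the heads. A standalone check of that reduction, with illustrative sizes not taken from the commit:

using Statistics

na, K = 2, 6
q = rand(Float32, na * K)                        # single-state network output
q_avg = vec(mean(reshape(q, na, K), dims = 2))   # length-na vector of averaged Q-values
best_action = argmax(q_avg)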
function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
    Q = learner.approximator
    Qₜ = learner.target_approximator
    γ = learner.sampler.γ
    loss_func = learner.loss_func
    n = learner.sampler.n
    batch_size = learner.sampler.batch_size
    ensemble_num = learner.ensemble_num
    D = device(Q)
    # Sample mixture weights on the simplex (a random point in the convex hull
    # of the ensemble members) to combine the multiple Q-value estimates into
    # a single estimate; :mean uses uniform weights instead.
    if learner.ensemble_method == :rand
        convex_polygon = rand(Float32, (1, ensemble_num))
    else
        convex_polygon = ones(Float32, (1, ensemble_num))
    end
    convex_polygon ./= sum(convex_polygon)
    convex_polygon = send_to_device(D, convex_polygon)

    s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
    a = CartesianIndex.(a, 1:batch_size)

    target_q = Qₜ(s′)
    target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size)
    target_q = dropdims(sum(target_q, dims = 2), dims = 2)

    if haskey(batch, :next_legal_actions_mask)
        l′ = send_to_device(D, batch[:next_legal_actions_mask])
        target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
    end

    q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
    G = r .+ γ^n .* (1 .- t) .* q′

    gs = gradient(params(Q)) do
        q = Q(s)
        q = convex_polygon .* reshape(q, :, ensemble_num, batch_size)
        q = dropdims(sum(q, dims = 2), dims = 2)[a]

        loss = loss_func(G, q)
        ignore() do
            learner.loss = loss
        end
        loss
    end

    update!(Q, gs)
end
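The target `G = r .+ γ^n .* (1 .- t) .* q′` is the usual n-step bootstrap, with `t` zeroing out the bootstrap term at terminal states. A tiny worked check with made-up numbers:

γ, n = 0.99f0, 1
r  = Float32[1, 1]     # rewards for a batch of two transitions
t  = Float32[0, 1]     # the second transition is terminal
q′ = Float32[10, 10]   # maxima of the mixed target Q-values at s′
G  = r .+ γ^n .* (1 .- t) .* q′
# G ≈ Float32[10.9, 1.0]: only the non-terminal state bootstraps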
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:REMDQN},
    ::Val{:CartPole},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_REMDQN_CartPole_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)
    rng = StableRNG(seed)

    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))
    ensemble_num = 6

    agent = Agent(
        policy = QBasedPolicy(
            learner = REMDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        # Multi-head architecture; see
                        # https://github.com/google-research/batch_rl/tree/b55ba35ebd2381199125dd77bfac9e9c59a64d74/batch_rl/multi_head
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na * ensemble_num; initW = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                target_approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na * ensemble_num; initW = glorot_uniform(rng)),
                    ) |> cpu,
                ),
                loss_func = huber_loss,
                stack_size = nothing,
                batch_size = 32,
                update_horizon = 1,
                min_replay_history = 100,
                update_freq = 1,
                target_update_freq = 100,
                ensemble_num = ensemble_num,
                ensemble_method = :rand,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )

    stop_condition = StopAfterStep(10_000)

    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        DoEveryNStep() do t, agent, env
            if agent.policy.learner.update_step % agent.policy.learner.update_freq == 0
                with_logger(lg) do
                    @info "training" loss = agent.policy.learner.loss
                end
            end
        end,
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" reward = total_reward_per_episode.rewards[end] log_step_increment = 0
            end
        end,
    )

    description = """
    This experiment uses the `REMDQNLearner` method with three dense layers to approximate the Q values.
    The testing environment is CartPoleEnv.
    """

    Experiment(agent, env, stop_condition, hook, description)
end
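Given the dispatch signature defined above, the experiment can be constructed and run like other experiments in this repository; a hypothetical invocation (the keyword values shown are the defaults from the definition above):

ex = RLCore.Experiment(Val(:JuliaRL), Val(:REMDQN), Val(:CartPole), nothing; seed = 123)
run(ex)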

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ end

 @testset "training" begin
     mktempdir() do dir
-        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
+        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :REMDQN, :IQN, :VPG)
             res = run(
                 Experiment(
                     Val(:JuliaRL),
