Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 3663eee

Browse files
authored
Implemented double DQN (#156)
* Implemented double DQN: Double DQN with an optional argument to disable it. * Implemented in matrix format: Fixed the naming issue and changed the code to implement double DQN in matrix format.
1 parent 5e1db96 commit 3663eee

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

src/algorithms/dqns/dqn.jl

Lines changed: 20 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -17,6 +17,7 @@ mutable struct DQNLearner{
1717
rng::R
1818
# for logging
1919
loss::Float32
20+
is_enable_double_DQN::Bool
2021
end
2122

2223
"""
@@ -36,8 +37,9 @@ See paper: [Human-level control through deep reinforcement learning](https://www
3637
- `update_freq::Int=4`: the frequency of updating the `approximator`.
3738
- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
3839
- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
39-
- `traces = SARTS`, set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
40+
- `traces = SARTS`: set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
4041
- `rng = Random.GLOBAL_RNG`
42+
- `is_enable_double_DQN::Bool = true`: Enable double DQN, enabled by default.
4143
"""
4244
function DQNLearner(;
4345
approximator::Tq,
@@ -53,6 +55,7 @@ function DQNLearner(;
5355
traces = SARTS,
5456
update_step = 0,
5557
rng = Random.GLOBAL_RNG,
58+
is_enable_double_DQN::Bool = true
5659
) where {Tq,Tt,Tf}
5760
copyto!(approximator, target_approximator)
5861
sampler = NStepBatchSampler{traces}(;
@@ -72,6 +75,7 @@ function DQNLearner(;
7275
sampler,
7376
rng,
7477
0.0f0,
78+
is_enable_double_DQN
7579
)
7680
end
7781

@@ -102,18 +106,30 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
102106
loss_func = learner.loss_func
103107
n = learner.sampler.n
104108
batch_size = learner.sampler.batch_size
109+
is_enable_double_DQN = learner.is_enable_double_DQN
105110
D = device(Q)
106111

107112
s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
108113
a = CartesianIndex.(a, 1:batch_size)
109114

110-
target_q = Qₜ(s′)
115+
if is_enable_double_DQN
116+
q_values = Q(s′)
117+
else
118+
q_values = Qₜ(s′)
119+
end
120+
111121
if haskey(batch, :next_legal_actions_mask)
112122
l′ = send_to_device(D, batch[:next_legal_actions_mask])
113-
target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
123+
q_values .+= ifelse.(l′, 0.0f0, typemin(Float32))
124+
end
125+
126+
if is_enable_double_DQN
127+
selected_actions = dropdims(argmax(q_values, dims=1), dims=1)
128+
q′ = Qₜ(s′)[selected_actions]
129+
else
130+
q′ = dropdims(maximum(q_values; dims = 1), dims = 1)
114131
end
115132

116-
q′ = dropdims(maximum(target_q; dims = 1), dims = 1)
117133
G = r .+ γ^n .* (1 .- t) .* q′
118134

119135
gs = gradient(params(Q)) do

0 commit comments

Comments
 (0)