@@ -17,6 +17,7 @@ mutable struct DQNLearner{
1717 rng:: R
1818 # for logging
1919 loss:: Float32
20+ is_enable_double_DQN:: Bool
2021end
2122
2223"""
@@ -36,8 +37,9 @@ See paper: [Human-level control through deep reinforcement learning](https://www
3637- `update_freq::Int=4`: the frequency of updating the `approximator`.
3738- `target_update_freq::Int=100`: the frequency of syncing `target_approximator`.
3839- `stack_size::Union{Int, Nothing}=4`: use the recent `stack_size` frames to form a stacked state.
39- - `traces = SARTS`, set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
40+ - `traces = SARTS`: set to `SLARTSL` if you are to apply to an environment of `FULL_ACTION_SET`.
4041- `rng = Random.GLOBAL_RNG`
42+ - `is_enable_double_DQN::Bool = true`: enable Double DQN (decouples action selection from evaluation); enabled by default.
4143"""
4244function DQNLearner (;
4345 approximator:: Tq ,
@@ -53,6 +55,7 @@ function DQNLearner(;
5355 traces = SARTS,
5456 update_step = 0 ,
5557 rng = Random. GLOBAL_RNG,
58+ is_enable_double_DQN:: Bool = true
5659) where {Tq,Tt,Tf}
5760 copyto! (approximator, target_approximator)
5861 sampler = NStepBatchSampler {traces} (;
@@ -72,6 +75,7 @@ function DQNLearner(;
7275 sampler,
7376 rng,
7477 0.0f0 ,
78+ is_enable_double_DQN
7579 )
7680end
7781
@@ -102,18 +106,30 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
102106 loss_func = learner. loss_func
103107 n = learner. sampler. n
104108 batch_size = learner. sampler. batch_size
109+ is_enable_double_DQN = learner. is_enable_double_DQN
105110 D = device (Q)
106111
107112 s, a, r, t, s′ = (send_to_device (D, batch[x]) for x in SARTS)
108113 a = CartesianIndex .(a, 1 : batch_size)
109114
110- target_q = Qₜ (s′)
115+ if is_enable_double_DQN
116+ q_values = Q (s′)
117+ else
118+ q_values = Qₜ (s′)
119+ end
120+
111121 if haskey (batch, :next_legal_actions_mask )
112122 l′ = send_to_device (D, batch[:next_legal_actions_mask ])
113- target_q .+ = ifelse .(l′, 0.0f0 , typemin (Float32))
123+ q_values .+ = ifelse .(l′, 0.0f0 , typemin (Float32))
124+ end
125+
126+ if is_enable_double_DQN
127+ selected_actions = dropdims (argmax (q_values, dims= 1 ), dims= 1 )
128+ q′ = Qₜ (s′)[selected_actions]
129+ else
130+ q′ = dropdims (maximum (q_values; dims = 1 ), dims = 1 )
114131 end
115132
116- q′ = dropdims (maximum (target_q; dims = 1 ), dims = 1 )
117133 G = r .+ γ^ n .* (1 .- t) .* q′
118134
119135 gs = gradient (params (Q)) do
0 commit comments