@@ -60,7 +60,12 @@ function PrioritizedDQNLearner(;
6060 rng = Random. GLOBAL_RNG,
6161) where {Tq,Tt,Tf}
6262 copyto! (approximator, target_approximator)
63- sampler = NStepBatchSampler {traces} (;γ= γ, n= update_horizon,stack_size= stack_size,batch_size= batch_size)
63+ sampler = NStepBatchSampler {traces} (;
64+ γ = γ,
65+ n = update_horizon,
66+ stack_size = stack_size,
67+ batch_size = batch_size,
68+ )
6469 PrioritizedDQNLearner (
6570 approximator,
6671 target_approximator,
9499function (learner:: PrioritizedDQNLearner )(env)
95100 env |>
96101 state |>
97- x -> Flux. unsqueeze (x, ndims (x) + 1 ) |>
98- x -> send_to_device (device (learner), x) |>
99- learner. approximator |>
100- vec |>
101- send_to_host
102+ x ->
103+ Flux. unsqueeze (x, ndims (x) + 1 ) |>
104+ x ->
105+ send_to_device (device (learner), x) |>
106+ learner. approximator |>
107+ vec |>
108+ send_to_host
102109end
103110
104111function RLBase. update! (learner:: PrioritizedDQNLearner , batch:: NamedTuple )
@@ -111,7 +118,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
111118 batch_size = learner. sampler. batch_size
112119
113120 D = device (Q)
114- s, a, r, t, s′ = (send_to_device (D,batch[x]) for x in SARTS)
121+ s, a, r, t, s′ = (send_to_device (D, batch[x]) for x in SARTS)
115122 a = CartesianIndex .(a, 1 : batch_size)
116123
117124 updated_priorities = Vector {Float32} (undef, batch_size)
@@ -122,7 +129,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
122129 target_q = Qₜ (s′)
123130 if haskey (batch, :next_legal_actions_mask )
124131 l′ = send_to_device (D, batch[:next_legal_actions_mask ])
125- target_q .+ = ifelse .(l′, 0.f0 , typemin (Float32))
132+ target_q .+ = ifelse .(l′, 0.0f0 , typemin (Float32))
126133 end
127134
128135 q′ = dropdims (maximum (target_q; dims = 1 ), dims = 1 )
0 commit comments