@@ -126,47 +126,43 @@ function (learner::RainbowLearner)(env)
     state = Flux.unsqueeze(state, ndims(state) + 1)
     logits = learner.approximator(state)
     q = learner.support .* softmax(reshape(logits, :, learner.n_actions))
-    # probs = vec(sum(q, dims=1)) .+ legal_action
-    vec(sum(q, dims = 1)) |> send_to_host
+    probs = vec(sum(q, dims = 1)) |> send_to_host
+    if ActionStyle(env) === FULL_ACTION_SET
+        probs .+= ifelse.(get_legal_actions_mask(env), 0f0, typemin(eltype(probs)))
+    end
+    probs
 end
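The evaluation path above is easiest to see on toy sizes: softmax over the `n_atoms` axis turns the logits into one value distribution per action, and weighting by `support` followed by a sum over atoms gives the expected Q-value per action. Below is a minimal standalone sketch (the sizes, logits, and mask are made up). It also shows the masking subtlety: `typemin(Float32) * 0` is `NaN` when the mask is integer-valued, so the mask should stay Boolean (Julia's `false` is a strong zero) or go through `ifelse.`, as in the line above.

```julia
using Flux: softmax

n_atoms, n_actions = 5, 3
support = Float32.(range(-10, 10, length = n_atoms))  # fixed grid of atom values
logits = randn(Float32, n_atoms * n_actions)          # stand-in network output

# One categorical distribution per action; the expectation over atoms gives Q.
q = support .* softmax(reshape(logits, n_atoms, n_actions))
probs = vec(sum(q, dims = 1))                         # expected Q per action

mask = [true, false, true]                            # hypothetical legal-action mask
probs .+= ifelse.(mask, 0f0, typemin(eltype(probs)))  # illegal actions -> -Inf
```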
 
 function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
-    Q,
-    Qₜ,
-    γ,
-    β,
-    loss_func,
-    n_atoms,
-    n_actions,
-    support,
-    delta_z,
-    update_horizon,
-    batch_size = learner.approximator,
-    learner.target_approximator,
-    learner.γ,
-    learner.β_priority,
-    learner.loss_func,
-    learner.n_atoms,
-    learner.n_actions,
-    learner.support,
-    learner.delta_z,
-    learner.update_horizon,
-    learner.batch_size
-
+    Q = learner.approximator
+    Qₜ = learner.target_approximator
+    γ = learner.γ
+    β = learner.β_priority
+    loss_func = learner.loss_func
+    n_atoms = learner.n_atoms
+    n_actions = learner.n_actions
+    support = learner.support
+    delta_z = learner.delta_z
+    update_horizon = learner.update_horizon
+    batch_size = learner.batch_size
     D = device(Q)
-    states, rewards, terminals, next_states = map(
-        x -> send_to_device(D, x),
-        (batch.states, batch.rewards, batch.terminals, batch.next_states),
-    )
+    states = send_to_device(D, batch.states)
+    rewards = send_to_device(D, batch.rewards)
+    terminals = send_to_device(D, batch.terminals)
+    next_states = send_to_device(D, batch.next_states)
+
     actions = CartesianIndex.(batch.actions, 1:batch_size)
+
     target_support =
         reshape(rewards, 1, :) .+
         (reshape(support, :, 1) * reshape((γ^update_horizon) .* (1 .- terminals), 1, :))
 
     next_logits = Qₜ(next_states)
     next_probs = reshape(softmax(reshape(next_logits, n_atoms, :)), n_atoms, n_actions, :)
     next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
-    # next_q_argmax = argmax(cpu(next_q .+ next_legal_actions), dims=1)
+    if !isnothing(batch.next_legal_actions_mask)
+        next_q .+= ifelse.(send_to_device(D, batch.next_legal_actions_mask), 0f0, typemin(eltype(next_q)))
+    end
     next_prob_select = select_best_probs(next_probs, next_q)
 
     target_distribution = project_distribution(
@@ -178,18 +174,23 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
         learner.Vₘₐₓ,
     )
 
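For intuition, the categorical projection computed here (the C51 step) clips each shifted atom `r + γⁿzⱼ` of `target_support` to `[Vₘᵢₙ, Vₘₐₓ]` and splits its probability mass linearly between the two nearest atoms of the fixed support. A single-sample sketch under the hypothetical name `project_onto_support` (the package's `project_distribution` does the same over whole batches on the device):

```julia
# Single-sample categorical projection (C51); for intuition only.
function project_onto_support(target_support, probs, delta_z, vmin, vmax)
    m = zeros(Float32, length(probs))
    for (Tz, p) in zip(target_support, probs)
        Tz = clamp(Tz, vmin, vmax)
        b = (Tz - vmin) / delta_z            # fractional position on the atom grid
        l, u = floor(Int, b), ceil(Int, b)
        if l == u                            # Tz falls exactly on an atom
            m[l+1] += p
        else                                 # split mass between the two neighbours
            m[l+1] += p * (u - b)
            m[u+1] += p * (b - l)
        end
    end
    m
end

support = 0f0:1f0:4f0                        # Vₘᵢₙ = 0, Vₘₐₓ = 4, delta_z = 1
target = 0.5f0 .+ 0.9f0 .* support           # r + γz, shifted off the grid
project_onto_support(target, fill(0.2f0, 5), 1f0, 0f0, 4f0)  # still sums to 1
```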
-    updated_priorities = Vector{Float32}(undef, batch_size)
-    weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
-    weights ./= maximum(weights)
-    weights = send_to_device(D, weights)
+    is_use_PER = !isnothing(batch.priorities)  # whether Prioritized Experience Replay is used
+    if is_use_PER
+        updated_priorities = Vector{Float32}(undef, batch_size)
+        weights = 1f0 ./ ((batch.priorities .+ 1f-10) .^ β)
+        weights ./= maximum(weights)
+        weights = send_to_device(D, weights)
+    end
 
     gs = gradient(Flux.params(Q)) do
         logits = reshape(Q(states), n_atoms, n_actions, :)
         select_logits = logits[:, actions]
         batch_losses = loss_func(select_logits, target_distribution)
-        loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
+        loss = is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size : mean(batch_losses)
         ignore() do
-            updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
+            if is_use_PER
+                updated_priorities .= send_to_host(vec((batch_losses .+ 1f-10) .^ β))
+            end
             learner.loss = loss
         end
         loss
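The PER branch implements the usual importance-sampling correction: weights proportional to `(1 / priority)^β`, normalized by their maximum, with the per-sample losses written back as new priorities. A toy round trip with made-up numbers:

```julia
β = 0.5f0
priorities = Float32[0.9, 0.1, 0.4, 0.2]           # sampled transition priorities

weights = 1f0 ./ ((priorities .+ 1f-10) .^ β)      # importance-sampling weights
weights ./= maximum(weights)                       # scale so the largest is 1

batch_losses = Float32[0.30, 0.05, 0.12, 0.40]     # pretend per-sample losses
loss = sum(weights .* batch_losses) / length(batch_losses)  # same as dot(...) * 1 // batch_size

updated_priorities = (batch_losses .+ 1f-10) .^ β  # fed back to the replay buffer
```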