
Commit 5e1db96: add some explanations (#155)

1 parent c5fa628

1 file changed, 2 additions and 1 deletion:

src/algorithms/dqns/rainbow.jl
@@ -9,7 +9,7 @@ See paper: [Rainbow: Combining Improvements in Deep Reinforcement Learning](http
 - `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
 - `target_approximator`::[`AbstractApproximator`](@ref): similar to `approximator`, but used to estimate the target (the next state).
-- `loss_func`: the loss function.
+- `loss_func`: the loss function. It is recommended to use `Flux.Losses.logitcrossentropy`; `Flux.Losses.crossentropy` runs into problems with negative numbers in the raw logits.
 - `Vₘₐₓ::Float32`: the maximum value of distribution.
 - `Vₘᵢₙ::Float32`: the minimum value of distribution.
 - `n_actions::Int`: number of possible actions.
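
To see why the docstring now recommends `logitcrossentropy`, here is a minimal sketch, assuming only Flux's standard loss API; the shapes and values are made up for illustration and are not from the commit:

```julia
using Flux  # re-exports softmax from NNlib

logits = randn(Float32, 51, 32)           # n_atoms × batch_size raw network output; entries can be negative
target = softmax(randn(Float32, 51, 32))  # a valid probability distribution over atoms

# Stable: applies logsoftmax internally, so raw logits are fine as input.
Flux.Losses.logitcrossentropy(logits, target)

# crossentropy takes the log of its first argument, so it expects
# probabilities; feeding raw (possibly negative) logits is invalid.
# Normalizing first works, but is less numerically stable:
Flux.Losses.crossentropy(softmax(logits), target)
```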
@@ -176,6 +176,7 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
     gs = gradient(Flux.params(Q)) do
         logits = reshape(Q(states), n_atoms, n_actions, :)
         select_logits = logits[:, actions]
+        # The original paper normalizes the logits, but normalization followed by Flux.Losses.crossentropy is not as stable as using Flux.Losses.logitcrossentropy.
         batch_losses = loss_func(select_logits, target_distribution)
         loss =
             is_use_PER ? dot(vec(weights), vec(batch_losses)) * 1 // batch_size :
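
For context, a self-contained sketch of the weighted-loss step in this hunk, with hypothetical shapes standing in for the learner's real fields; `agg = identity` is what keeps one loss per sample so the prioritized-replay (PER) weights can be applied:

```julia
using Flux, LinearAlgebra

n_atoms, batch_size = 51, 32
select_logits = randn(Float32, n_atoms, batch_size)                # logits of the selected actions
target_distribution = softmax(randn(Float32, n_atoms, batch_size)) # projected target distribution
weights = rand(Float32, 1, batch_size)                             # PER importance-sampling weights

# agg = identity returns one cross-entropy value per column (per sample)
# instead of reducing to a scalar, so each sample can be weighted.
batch_losses = Flux.Losses.logitcrossentropy(
    select_logits, target_distribution; agg = identity)

# The 1 // batch_size factor mirrors the source: a Rational scale keeps
# the result type-stable under Float32 arithmetic.
loss = dot(vec(weights), vec(batch_losses)) * 1 // batch_size
```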
