Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 5d0780b

Browse files
authored
improve basicdqn (#117)
1 parent ff95729 commit 5d0780b

1 file changed

Lines changed: 21 additions & 7 deletions

File tree

src/algorithms/dqns/basic_dqn.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,33 @@ end
6565
"""
    RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)

Draw one random minibatch of transitions from the trajectory `T` and
train `learner` on it by delegating to the batch method of `update!`.

Returns early (no update) while `T` holds fewer than
`learner.min_replay_history` transitions. Indices are sampled with
`learner.rng`, so sampling is reproducible given a seeded RNG.
"""
function RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)
    # Warm-up guard: skip training until enough experience is collected.
    length(T[:terminal]) < learner.min_replay_history && return

    # Uniformly sample `batch_size` transition indices from the replay data.
    sampled = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)

    # Gather each trace at the sampled indices; `consecutive_view` is a
    # project helper — presumably a non-copying view, TODO confirm.
    minibatch = (
        state = consecutive_view(T[:state], sampled),
        action = consecutive_view(T[:action], sampled),
        reward = consecutive_view(T[:reward], sampled),
        terminal = consecutive_view(T[:terminal], sampled),
        next_state = consecutive_view(T[:next_state], sampled),
    )

    # The actual gradient step happens in the NamedTuple method.
    update!(learner, minibatch)
end
80+
81+
function RLBase.update!(learner::BasicDQNLearner, batch::NamedTuple)
82+
6883
Q = learner.approximator
6984
D = device(Q)
7085
γ = learner.γ
7186
loss_func = learner.loss_func
72-
batch_size = learner.batch_size
7387

74-
inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)
88+
batch_size = nframes(batch.terminal)
7589

76-
s = send_to_device(D, consecutive_view(T[:state], inds))
77-
a = consecutive_view(T[:action], inds)
78-
r = send_to_device(D, consecutive_view(T[:reward], inds))
79-
t = send_to_device(D, consecutive_view(T[:terminal], inds))
80-
s′ = send_to_device(D, consecutive_view(T[:next_state], inds))
90+
s = send_to_device(D, batch.state)
91+
a = batch.action
92+
r = send_to_device(D, batch.reward)
93+
t = send_to_device(D, batch.terminal)
94+
s′ = send_to_device(D, batch.next_state)
8195

8296
a = CartesianIndex.(a, 1:batch_size)
8397

0 commit comments

Comments (0)