|
65 | 65 | function RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory) |
66 | 66 |     length(T[:terminal]) < learner.min_replay_history && return
67 | 67 |
|
| 68 | +    inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)
| 69 | +
| 70 | +    batch = (
| 71 | +        state = consecutive_view(T[:state], inds),
| 72 | +        action = consecutive_view(T[:action], inds),
| 73 | +        reward = consecutive_view(T[:reward], inds),
| 74 | +        terminal = consecutive_view(T[:terminal], inds),
| 75 | +        next_state = consecutive_view(T[:next_state], inds),
| 76 | +    )
| 77 | +
| 78 | +    update!(learner, batch)
| 79 | +end
| 80 | + |
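The first method now only samples: it draws `batch_size` random indices with the learner's own RNG, wraps the corresponding trajectory slices into a `NamedTuple` via `consecutive_view`, and forwards the result. A side effect worth noting is that the learning step below no longer depends on `AbstractTrajectory` at all, so any batch exposing these five fields can be passed in directly. A minimal sketch (the toy data and the pre-built `learner` are assumptions, not part of this commit):

```julia
# Hypothetical hand-built batch; field names match those sampled above.
batch = (
    state      = rand(Float32, 4, 32),   # 4-dim observations, 32 transitions
    action     = rand(1:2, 32),          # index of the action taken
    reward     = rand(Float32, 32),
    terminal   = falses(32),
    next_state = rand(Float32, 4, 32),
)

update!(learner, batch)  # dispatches to the NamedTuple method below
```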
| 81 | +function RLBase.update!(learner::BasicDQNLearner, batch::NamedTuple) |
| 82 | + |
68 | 83 |     Q = learner.approximator
69 | 84 |     D = device(Q)
70 | 85 |     γ = learner.γ
71 | 86 |     loss_func = learner.loss_func
72 | | -    batch_size = learner.batch_size
73 | 87 |
|
74 | | -    inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)
| 88 | +    batch_size = nframes(batch.terminal)
75 | 89 |
|
76 | | -    s = send_to_device(D, consecutive_view(T[:state], inds))
77 | | -    a = consecutive_view(T[:action], inds)
78 | | -    r = send_to_device(D, consecutive_view(T[:reward], inds))
79 | | -    t = send_to_device(D, consecutive_view(T[:terminal], inds))
80 | | -    s′ = send_to_device(D, consecutive_view(T[:next_state], inds))
| 90 | +    s = send_to_device(D, batch.state)
| 91 | +    a = batch.action
| 92 | +    r = send_to_device(D, batch.reward)
| 93 | +    t = send_to_device(D, batch.terminal)
| 94 | +    s′ = send_to_device(D, batch.next_state)
81 | 95 |
|
82 | 96 |     a = CartesianIndex.(a, 1:batch_size)
83 | 97 |
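After the batch is moved to the approximator's device, `batch_size` is recovered from the data itself with `nframes` rather than read off the learner, and the `CartesianIndex` broadcast turns the vector of action indices into per-column indices into the Q-value matrix: assuming, as is standard for DQN, that `Q(s)` returns an `(n_actions, batch_size)` matrix, `q[CartesianIndex.(a, 1:batch_size)]` picks out the Q-value of the action actually taken in each transition. A self-contained toy illustration (numbers made up):

```julia
# Toy Q-value matrix: 2 actions (rows) × 3 transitions (columns).
q = [1.0 2.0 3.0;
     4.0 5.0 6.0]
a = [2, 1, 2]                   # action chosen in each transition

inds = CartesianIndex.(a, 1:3)  # CartesianIndex(2,1), CartesianIndex(1,2), CartesianIndex(2,3)
q[inds]                         # => [4.0, 2.0, 6.0]
```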