@@ -4,73 +4,7 @@
 
 const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}
 
-function extract_experience(t::AbstractTrajectory, learner::PERLearners)
-    s = learner.stack_size
-    h = learner.update_horizon
-    n = learner.batch_size
-    γ = learner.γ
-
-    # 1. sample indices based on priority
-    valid_ind_range =
-        isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
-    if haskey(t, :priority)
-        inds = Vector{Int}(undef, n)
-        priorities = Vector{Float32}(undef, n)
-        for i in 1:n
-            ind, p = sample(learner.rng, t[:priority])
-            while ind ∉ valid_ind_range
-                ind, p = sample(learner.rng, t[:priority])
-            end
-            inds[i] = ind
-            priorities[i] = p
-        end
-    else
-        inds = rand(learner.rng, valid_ind_range, n)
-        priorities = nothing
-    end
-
-    next_inds = inds .+ h
-
-    # 2. extract SARTS
-    states = consecutive_view(t[:state], inds; n_stack = s)
-    actions = consecutive_view(t[:action], inds)
-    next_states = consecutive_view(t[:state], next_inds; n_stack = s)
-
-    if haskey(t, :legal_actions_mask)
-        legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
-        next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], inds)
-    else
-        legal_actions_mask = nothing
-        next_legal_actions_mask = nothing
-    end
-
-    consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
-    consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
-    rewards, terminals = zeros(Float32, n), fill(false, n)
-
-    rewards = discount_rewards_reduced(
-        consecutive_rewards,
-        γ;
-        terminal = consecutive_terminals,
-        dims = 1,
-    )
-    terminals = mapslices(any, consecutive_terminals; dims = 1) |> vec
-
-    inds,
-    (
-        states = states,
-        legal_actions_mask = legal_actions_mask,
-        actions = actions,
-        rewards = rewards,
-        terminals = terminals,
-        next_states = next_states,
-        next_legal_actions_mask = next_legal_actions_mask,
-        priorities = priorities,
-    )
-end
-
-function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
-    learner = p.learner
+function RLBase.update!(learner::Union{DQNLearner, PERLearners}, t::AbstractTrajectory)
     length(t[:terminal]) < learner.min_replay_history && return
 
     learner.update_step += 1
@@ -81,20 +15,18 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
 
     learner.update_step % learner.update_freq == 0 || return
 
-    inds, experience = extract_experience(t, p.learner)
+    inds, batch = sample(learner.rng, t, learner.sampler)
 
-    if haskey(t, :priority)
-        priorities = update!(p.learner, experience)
+    if t isa PrioritizedTrajectory
+        priorities = update!(learner, batch)
         t[:priority][inds] .= priorities
     else
-        update!(p.learner, experience)
+        update!(learner, batch)
     end
 end
 
-function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(::RLCore.Training{PostActStage}, env)
-    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
-    if haskey(agent.trajectory, :priority)
-        push!(agent.trajectory; priority = agent.policy.learner.default_priority)
-    end
-    nothing
+function RLBase.update!(trajectory::PrioritizedTrajectory, p::QBasedPolicy{<:PERLearners}, env::AbstractEnv, ::PostActStage)
+    push!(trajectory[:reward], get_reward(env))
+    push!(trajectory[:terminal], get_terminal(env))
+    push!(trajectory[:priority], p.learner.default_priority)
 end
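
Note on this change: the commit replaces the hand-rolled `extract_experience` with the generic `sample(learner.rng, t, learner.sampler)` path and moves priority bookkeeping onto `PrioritizedTrajectory`, so `DQNLearner` and the `PERLearners` can share a single `update!` method. The deleted helper's main job was assembling n-step targets, which it delegated to `discount_rewards_reduced`. The sketch below re-derives that reduction in isolation; it is illustrative only — `nstep_return` is a hypothetical name, and it assumes the return is truncated at the first terminal inside the horizon, as the `terminal` keyword in the deleted code suggests.

```julia
# Minimal sketch of an n-step discounted return for one sampled
# transition, assuming truncate-at-first-terminal semantics.
# rewards[k] / terminals[k] are the k-th step after the sampled
# index, k = 1..h (the update horizon).
function nstep_return(rewards::AbstractVector{Float32}, terminals::AbstractVector{Bool}, γ::Float32)
    g = 0.0f0
    for k in length(rewards):-1:1
        # a terminal at step k discards everything accumulated after it
        g = rewards[k] + γ * g * (1 - terminals[k])
    end
    g
end

nstep_return(Float32[1, 1, 1], [false, false, false], 0.99f0)  # ≈ 1 + 0.99 + 0.99²
nstep_return(Float32[1, 1, 1], [false, true, false], 0.99f0)   # ≈ 1 + 0.99 (episode ends at step 2)
```

The `terminals = mapslices(any, consecutive_terminals; dims = 1) |> vec` line in the deleted code plays the matching role for the done flag: the n-step transition counts as terminal if any step within the horizon was terminal.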