 include("ppo_trajectory.jl")
 
 using Random
+using Distributions: Categorical, Normal, logpdf
+using StructArrays
 
-export PPOLearner
+export PPOPolicy
 
 """
-    PPOLearner(;kwargs)
+    PPOPolicy(;kwargs)
 
 # Keyword arguments
 
@@ -19,9 +21,13 @@ export PPOLearner
 - `actor_loss_weight = 1.0f0`,
 - `critic_loss_weight = 0.5f0`,
 - `entropy_loss_weight = 0.01f0`,
+- `dist = Categorical`,
 - `rng = Random.GLOBAL_RNG`,
+
+By default, `dist` is set to `Categorical`, which means it will only work
+on environments with discrete actions. To work with environments of
+continuous actions, set `dist = Normal` and use a `GaussianNetwork` as the
+`actor` of the `approximator` (see the construction sketch below).
 """
-mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
+mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
     approximator::A
     γ::Float32
     λ::Float32
@@ -41,7 +47,7 @@ mutable struct PPOLearner{A<:ActorCritic,R} <: AbstractLearner
     loss::Matrix{Float32}
 end
 
-function PPOLearner(;
+function PPOPolicy(;
     approximator,
     γ = 0.99f0,
     λ = 0.95f0,
@@ -52,9 +58,10 @@ function PPOLearner(;
     actor_loss_weight = 1.0f0,
     critic_loss_weight = 0.5f0,
     entropy_loss_weight = 0.01f0,
+    dist = Categorical,
     rng = Random.GLOBAL_RNG,
 )
-    PPOLearner(
+    PPOPolicy{typeof(approximator),dist,typeof(rng)}(
         approximator,
         γ,
         λ,
@@ -74,21 +81,33 @@ function PPOLearner(;
     )
 end
 
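# A minimal construction sketch for the docstring above. It assumes the usual
# keyword constructors of `NeuralNetworkApproximator` and `ActorCritic` from
# ReinforcementLearningCore; the layer sizes (4-dimensional state, 2 actions)
# and the ADAM learning rate are illustrative, not taken from this commit.
using Flux: Chain, Dense, relu, ADAM

policy = PPOPolicy(
    approximator = ActorCritic(
        actor = NeuralNetworkApproximator(
            model = Chain(Dense(4, 64, relu), Dense(64, 2)),   # logits over 2 actions
            optimizer = ADAM(3f-4),
        ),
        critic = NeuralNetworkApproximator(
            model = Chain(Dense(4, 64, relu), Dense(64, 1)),   # state value
            optimizer = ADAM(3f-4),
        ),
    ),
    dist = Categorical,   # default; use `Normal` with a `GaussianNetwork` actor for continuous actions
    rng = Random.GLOBAL_RNG,
)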
-function (learner::PPOLearner)(env::MultiThreadEnv)
-    learner.approximator.actor(send_to_device(
-        device(learner.approximator),
-        get_state(env),
-    )) |> send_to_host
+function RLBase.get_prob(
+    p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}},Normal},
+    state::AbstractArray,
+)
+    p.approximator.actor(send_to_device(
+        device(p.approximator),
+        state,
+    )) |> send_to_host |> StructArray{Normal}
+end
+
+function RLBase.get_prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
+    probs = p.approximator.actor(send_to_device(
+        device(p.approximator),
+        state,
+    )) |> softmax |> send_to_host
+    [Categorical(x; check_args = false) for x in eachcol(probs)]
 end
 
-function (learner::PPOLearner)(env)
+RLBase.get_prob(p::PPOPolicy, env::MultiThreadEnv) = get_prob(p, get_state(env))
+
+function RLBase.get_prob(p::PPOPolicy, env::AbstractEnv)
     s = get_state(env)
     s = Flux.unsqueeze(s, ndims(s) + 1)
-    s = send_to_device(device(learner.approximator), s)
-    learner.approximator.actor(s) |> vec |> send_to_host
+    get_prob(p, s)[1]
 end
 
-function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
+(p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, get_prob(p, env))
+(p::PPOPolicy)(env::AbstractEnv) = rand(p.rng, get_prob(p, env))
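# A minimal sketch of the sampling path above, using made-up probabilities and
# made-up (μ, σ) pairs rather than real actor output: `get_prob` yields one
# distribution per parallel environment, and one action is drawn from each.
using Distributions: Categorical, Normal
using StructArrays
using Random

rng = Random.MersenneTwister(123)

# Discrete case: one Categorical per column, as in the method above after softmax.
probs = [0.9 0.2 0.5;
         0.1 0.8 0.5]                                   # 2 actions × 3 environments
dists = [Categorical(x; check_args = false) for x in eachcol(probs)]
actions = [rand(rng, d) for d in dists]                 # one action index per env

# Continuous case: a StructArray{Normal} built from (μ, σ), mirroring what the
# GaussianNetwork branch returns; sampling works the same way.
gdists = StructArray{Normal}(([0.0, 1.0], [0.5, 0.5]))
gactions = [rand(rng, d) for d in gdists]               # one Float64 sample per env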
+
+function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
     isfull(t) || return
 
     states = t[:state]
@@ -98,16 +117,16 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
     terminals = t[:terminal]
     states_plus = t[:full_state]
 
-    rng = learner.rng
-    AC = learner.approximator
-    γ = learner.γ
-    λ = learner.λ
-    n_epochs = learner.n_epochs
-    n_microbatches = learner.n_microbatches
-    clip_range = learner.clip_range
-    w₁ = learner.actor_loss_weight
-    w₂ = learner.critic_loss_weight
-    w₃ = learner.entropy_loss_weight
+    rng = p.rng
+    AC = p.approximator
+    γ = p.γ
+    λ = p.λ
+    n_epochs = p.n_epochs
+    n_microbatches = p.n_microbatches
+    clip_range = p.clip_range
+    w₁ = p.actor_loss_weight
+    w₂ = p.critic_loss_weight
+    w₃ = p.entropy_loss_weight
     D = device(AC)
 
     n_envs, n_rollout = size(terminals)
@@ -142,60 +161,63 @@ function RLBase.update!(learner::PPOLearner, t::PPOTrajectory)
             ps = Flux.params(AC)
             gs = gradient(ps) do
                 v′ = AC.critic(s) |> vec
-                logit′ = AC.actor(s)
-                p′ = softmax(logit′)
-                log_p′ = logsoftmax(logit′)
-                log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                if AC.actor isa NeuralNetworkApproximator{<:GaussianNetwork}
+                    μ, σ = AC.actor(s)
+                    log_p′ₐ = normlogpdf(μ, σ, a)
+                    entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
+                else
+                    # the actor is assumed to return logits over discrete actions
+                    logit′ = AC.actor(s)
+                    p′ = softmax(logit′)
+                    log_p′ = logsoftmax(logit′)
+                    log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
+                    entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
+                end
 
                 ratio = exp.(log_p′ₐ .- log_p)
                 surr1 = ratio .* adv
                 surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv
 
                 actor_loss = -mean(min.(surr1, surr2))
                 critic_loss = mean((r .- v′) .^ 2)
-                entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
                 loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
 
                 ignore() do
-                    learner.actor_loss[i, epoch] = actor_loss
-                    learner.critic_loss[i, epoch] = critic_loss
-                    learner.entropy_loss[i, epoch] = entropy_loss
-                    learner.loss[i, epoch] = loss
+                    p.actor_loss[i, epoch] = actor_loss
+                    p.critic_loss[i, epoch] = critic_loss
+                    p.entropy_loss[i, epoch] = entropy_loss
+                    p.loss[i, epoch] = loss
                 end
 
                 loss
             end
 
-            learner.norm[i, epoch] = clip_by_global_norm!(gs, ps, learner.max_grad_norm)
+            p.norm[i, epoch] = clip_by_global_norm!(gs, ps, p.max_grad_norm)
             update!(AC, gs)
         end
     end
 end
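# A tiny numeric illustration (made-up values) of the clipped surrogate used in the
# gradient block above: once the probability ratio leaves [1 - clip_range, 1 + clip_range],
# the clipped term caps the objective, and `min` keeps the pessimistic estimate.
let clip_range = 0.2f0
    log_p_old, log_p_new = -1.0f0, -0.5f0
    adv = 1.0f0                             # positive advantage
    ratio = exp(log_p_new - log_p_old)      # ≈ 1.65, above 1 + clip_range
    surr1 = ratio * adv                     # ≈ 1.65 (unclipped)
    surr2 = clamp(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) * adv   # = 1.2
    @assert min(surr1, surr2) == surr2      # the clipped value is what enters the loss
end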
 
-function (π::QBasedPolicy{<:PPOLearner})(env::MultiThreadEnv)
-    action_values = π.learner(env)
-    logits = logsoftmax(action_values)
-    actions = π.explorer(action_values)
-    actions_log_prob = logits[CartesianIndex.(actions, 1:size(action_values, 2))]
-    actions, actions_log_prob
-end
+function (agent::Agent{<:Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}}})(
+    ::Training{PreActStage},
+    env::MultiThreadEnv,
+)
+    state = get_state(env)
+    dist = get_prob(agent.policy, env)
 
-(π::QBasedPolicy{<:PPOLearner})(env) = env |> π.learner |> π.explorer
+    # currently RandomPolicy returns a Matrix instead of a vector of distributions
+    if dist isa Matrix{<:Number}
+        dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
+    elseif dist isa Vector{<:Vector{<:Number}}
+        dist = [Categorical(x; check_args = false) for x in dist]
+    end
 
-function (p::RandomStartPolicy{<:QBasedPolicy{<:PPOLearner}})(env::MultiThreadEnv)
-    p.num_rand_start -= 1
-    if p.num_rand_start < 0
-        p.policy(env)
-    else
-        a = p.random_policy(env)
-        log_p = log.(get_prob(p.random_policy, env, a))
-        a, log_p
+    # !!! a little ugly
+    rng = if agent.policy isa PPOPolicy
+        agent.policy.rng
+    elseif agent.policy isa RandomStartPolicy
+        agent.policy.policy.rng
     end
-end
 
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage}, env)
-    action, action_log_prob = agent.policy(env)
-    state = get_state(env)
+    action = [rand(rng, d) for d in dist]
+    action_log_prob = [logpdf(d, a) for (d, a) in zip(dist, action)]
     push!(
         agent.trajectory;
         state = state,
@@ -217,12 +239,3 @@ function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PreActStage
 
     action
 end
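# Sketch of the log-probabilities recorded above, with made-up distributions:
# `logpdf` from Distributions covers both the discrete and the continuous case,
# which is why a single comprehension handles `Categorical` and `Normal` alike.
using Distributions: Categorical, Normal, logpdf

logpdf(Categorical([0.25, 0.75]; check_args = false), 2)   # log(0.75) ≈ -0.2877
logpdf(Normal(0.0f0, 0.5f0), 0.1f0)                         # Gaussian log-density at 0.1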
-
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Training{PostActStage}, env)
-    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
-    nothing
-end
-
-function (agent::Agent{<:AbstractPolicy,<:PPOTrajectory})(::Testing{PreActStage}, env)
-    agent.policy(env)[1] # ignore the log_prob of action
-end