 - `rng = Random.GLOBAL_RNG`,
 
 By default, `dist` is set to `Categorical`, which means it will only work
-on environments of discrete actions. To work with environments of
+on environments of discrete actions. To work with environments of continuous
+actions, `dist` should be set to `Normal` and the `actor` in the `approximator`
+should be a `GaussianNetwork`. Using it with a `GaussianNetwork` supports
+multi-dimensional action spaces, though only under the assumption that the
+dimensions are independent, since the `GaussianNetwork` outputs a single
+`μ` and `σ` for each dimension, which is used to simplify the calculations.
 """
 mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
     approximator::A
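A note on the independence assumption stated in the docstring: for a diagonal Gaussian, the joint log-probability of a multi-dimensional action is exactly the sum of the per-dimension `Normal` log-pdfs, which is what the summations in the hunks below rely on. A minimal sketch with made-up values (Distributions.jl is used only for the illustration; the policy itself does not depend on it):

```julia
using Distributions, LinearAlgebra

μ = [0.1f0, -0.3f0]   # per-dimension means, as a GaussianNetwork would emit
σ = [0.5f0, 1.2f0]    # per-dimension standard deviations
a = [0.0f0, 0.4f0]    # a sampled 2-dimensional action

# Joint log-pdf of the full diagonal Gaussian...
joint = logpdf(MvNormal(μ, Diagonal(σ .^ 2)), a)
# ...equals the sum of independent per-dimension log-pdfs.
summed = sum(logpdf.(Normal.(μ, σ), a))
joint ≈ summed  # true
```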
 function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
     dist = prob(agent.policy, env)
     action = rand.(agent.policy.rng, dist)
-    EnrichedAction(action; action_log_prob = logpdf.(dist, action))
+    if ndims(action) == 2
+        action_log_prob = sum(logpdf.(dist, action), dims = 1)
+    else
+        action_log_prob = logpdf.(dist, action)
+    end
+    EnrichedAction(action; action_log_prob = vec(action_log_prob))
 end
 
 function RLBase.update!(
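For context on the `ndims(action) == 2` branch above: with a `MultiThreadEnv` and a `GaussianNetwork` actor, `dist` and the sampled `action` are `(action_dim, n_envs)` arrays, so the elementwise log-pdfs have to be summed over the action dimension to get one log-probability per environment. A rough sketch of the assumed shapes:

```julia
using Distributions

μ = randn(Float32, 3, 4)            # 3-dimensional actions, 4 parallel envs
σ = ones(Float32, 3, 4)
dist = Normal.(μ, σ)                # (3, 4) array of per-dimension Normals
action = rand.(dist)                # (3, 4) matrix of sampled actions
log_prob = vec(sum(logpdf.(dist, action), dims = 1))  # length 4, one per env
```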
@@ -227,7 +237,7 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     )
     returns = advantages .+ select_last_dim(states_plus_values, 1:n_rollout)
 
-    actions = select_last_dim(t[:action], 1:n)
+    actions_flatten = flatten_batch(select_last_dim(t[:action], 1:n))
     action_log_probs = select_last_dim(t[:action_log_prob], 1:n)
 
     # TODO: normalize advantage
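The switch from `actions` to `actions_flatten` matters once actions are multi-dimensional: `vec(actions)[inds]` (removed in the next hunk) would index individual scalar components, while flattening only the trailing batch dimensions keeps each action column intact. Assuming `flatten_batch` merges the last two dimensions (its behavior in ReinforcementLearningCore at the time of writing), the idea is:

```julia
# (action_dim = 2, n_envs = 3, n_rollout = 4); values are just for illustration
A = reshape(collect(1:24), 2, 3, 4)
flat = reshape(A, 2, :)     # assumed equivalent of flatten_batch(A): (2, 12)
flat[:, 5]                  # one complete 2-dimensional action
# vec(A)[5], by contrast, picks out a single scalar component
```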
@@ -246,7 +256,7 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
                @error "TODO:"
            end
            s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! performance critical
-           a = vec(actions)[inds]
+           a = send_to_device(D, select_last_dim(actions_flatten, inds))
            r = send_to_device(D, vec(returns)[inds])
            log_p = send_to_device(D, vec(action_log_probs)[inds])
            adv = send_to_device(D, vec(advantages)[inds])
@@ -256,8 +266,12 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
                v′ = AC.critic(s) |> vec
                if AC.actor isa GaussianNetwork
                    μ, σ = AC.actor(s)
-                   log_p′ₐ = normlogpdf(μ, σ, a)
-                   entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
+                   if ndims(a) == 2
+                       log_p′ₐ = sum(normlogpdf(μ, σ, a), dims = 1)
+                   else
+                       log_p′ₐ = normlogpdf(μ, σ, a)
+                   end
+                   entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims = 1))
                else
                    # actor is assumed to return discrete logits
                    logit′ = AC.actor(s)
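The updated `entropy_loss` line is the closed-form differential entropy of a diagonal Gaussian, `H = Σᵢ ((log(2π) + 1) / 2 + log(σᵢ))`, now summed over action dimensions before taking the batch mean. A quick sanity check of that identity against Distributions.jl:

```julia
using Distributions

σ = [0.5, 1.2]
H_closed = sum((log(2π) + 1) / 2 .+ log.(σ))
H_dist = sum(entropy.(Normal.(0.0, σ)))  # entropy doesn't depend on the mean
H_closed ≈ H_dist  # true
```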