This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 7913db6

Allow multidimensional actions in ppo (#151)
* Hack to allow multidim actions in ppo
* Fix for single dim envs
* Handle single and multi actions separately
* Update PPOPolicy docstring for multidim actions
* Update docstring for PPOPolicy
1 parent a9ad731 commit 7913db6

2 files changed: 27 additions & 7 deletions


src/algorithms/policy_gradient/multi_thread_env.jl

Lines changed: 7 additions & 1 deletion
```diff
@@ -82,9 +82,14 @@ end
 MacroTools.@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.iterate
 
 function (env::MultiThreadEnv)(actions)
+    N = ndims(actions)
     @sync for i in 1:length(env)
         @spawn begin
-            env[i](actions[i])
+            if N == 1
+                env[i](actions[i])
+            else
+                env[i](selectdim(actions, N, i))
+            end
         end
     end
 end
@@ -126,6 +131,7 @@ function RLBase.is_terminated(env::MultiThreadEnv)
 end
 
 function RLBase.legal_action_space_mask(env::MultiThreadEnv)
+    N = ndims(env.states)
     @sync for i in 1:length(env)
         @spawn selectdim(env.legal_action_space_mask, N, i) .=
             legal_action_space_mask(env[i])
```
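The dispatch above hinges on `Base.selectdim`, which returns a view along the given dimension, so each threaded environment receives its own slice of the batched action array. A minimal sketch of the two branches, with a toy action batch that is illustrative rather than taken from the commit:

```julia
# Multidimensional case: a 2-dimensional action for each of 3 environments,
# batched as an (action_dim × n_envs) matrix.
actions = [0.1 0.2 0.3;
           0.4 0.5 0.6]

N = ndims(actions)        # 2, so the `else` branch runs
selectdim(actions, N, 2)  # view of column 2: the action [0.2, 0.5] for env 2

# Single-dimensional case: a plain vector of scalar actions,
# so `env[i]` simply receives `actions[i]`.
scalar_actions = [1, 2, 3]
ndims(scalar_actions) == 1  # true
```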

src/algorithms/policy_gradient/ppo.jl

Lines changed: 20 additions & 6 deletions
```diff
@@ -71,7 +71,12 @@ end
 - `rng = Random.GLOBAL_RNG`,
 
 By default, `dist` is set to `Categorical`, which means it will only works
-on environments of discrete actions. To work with environments of
+on environments of discrete actions. To work with environments of continuous
+actions `dist` should be set to `Normal` and the `actor` in the `approximator`
+should be a `GaussianNetwork`. Using it with a `GaussianNetwork` supports
+multi-dimensional action spaces, though it only supports it under the assumption
+that the dimensions are independent since the `GaussianNetwork` outputs a single
+`μ` and `σ` for each dimension which is used to simplify the calculations.
 """
 mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
     approximator::A
```
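The independence assumption stated in the docstring is what makes the joint log-density tractable: it factorizes into a sum of per-dimension `Normal` log-densities. A minimal sketch with Distributions.jl, where the values and the `MvNormal` comparison are illustrative assumptions rather than part of the commit:

```julia
using Distributions, LinearAlgebra

# Hypothetical per-dimension outputs of a GaussianNetwork for one action:
μ = [0.0, 1.0]
σ = [1.0, 0.5]
a = [0.3, 0.8]

# Under the independence assumption, the joint log-probability is the
# sum of the per-dimension Normal log-densities ...
logp_joint = sum(logpdf.(Normal.(μ, σ), a))

# ... which matches a multivariate Normal with diagonal covariance.
logp_joint ≈ logpdf(MvNormal(μ, Diagonal(σ .^ 2)), a)  # true
```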
```diff
@@ -178,7 +183,12 @@ end
 function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
     dist = prob(agent.policy, env)
     action = rand.(agent.policy.rng, dist)
-    EnrichedAction(action; action_log_prob = logpdf.(dist, action))
+    if ndims(action) == 2
+        action_log_prob = sum(logpdf.(dist, action), dims=1)
+    else
+        action_log_prob = logpdf.(dist, action)
+    end
+    EnrichedAction(action; action_log_prob=vec(action_log_prob))
 end
 
 function RLBase.update!(
```
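The `ndims(action) == 2` branch relies on broadcasting over a matrix of per-dimension distributions: one row per action dimension, one column per environment. A toy walkthrough of the shapes, with made-up `Normal` parameters for illustration:

```julia
using Distributions, Random

rng = MersenneTwister(123)

# 2 action dimensions × 3 parallel environments:
dist = [Normal(0.0, 1.0) Normal(0.0, 1.0) Normal(0.0, 1.0);
        Normal(1.0, 0.5) Normal(1.0, 0.5) Normal(1.0, 0.5)]

action = rand.(rng, dist)                  # 2×3 matrix, one column per env
lp = sum(logpdf.(dist, action), dims = 1)  # 1×3: joint log-prob per env
vec(lp)                                    # length-3 vector stored in the trajectory
```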
```diff
@@ -227,7 +237,7 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     )
     returns = advantages .+ select_last_dim(states_plus_values, 1:n_rollout)
 
-    actions = select_last_dim(t[:action], 1:n)
+    actions_flatten = flatten_batch(select_last_dim(t[:action], 1:n))
     action_log_probs = select_last_dim(t[:action_log_prob], 1:n)
 
     # TODO: normalize advantage
```
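`flatten_batch` is a helper from the package's utilities; the commit applies it here so that the sampled minibatch columns line up with the already-flattened states. A plausible stand-in, assuming it merges the trailing (env, step) dimensions into one batch dimension:

```julia
# Hypothetical reimplementation for illustration only; the real helper is
# defined in the package, and its exact signature may differ.
flatten_batch_demo(x::AbstractArray) =
    reshape(x, size(x)[1:end-2]..., size(x)[end-1] * size(x)[end])

# Actions stored as (action_dim × n_envs × n_rollout):
A = rand(2, 4, 8)
size(flatten_batch_demo(A))  # (2, 32): one column per (env, step) pair
```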
```diff
@@ -246,7 +256,7 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
         @error "TODO:"
     end
     s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! performance critical
-    a = vec(actions)[inds]
+    a = send_to_device(D, select_last_dim(actions_flatten, inds))
     r = send_to_device(D, vec(returns)[inds])
     log_p = send_to_device(D, vec(action_log_probs)[inds])
     adv = send_to_device(D, vec(advantages)[inds])
```
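Replacing `vec(actions)[inds]` matters because with multi-dimensional actions each sample is now a column, not a scalar, so minibatch indexing has to slice the last dimension. In plain Julia terms (a hedged illustration; `select_last_dim` is the package's helper for this):

```julia
actions_flatten = [0.1 0.2 0.3 0.4;
                   0.5 0.6 0.7 0.8]  # (action_dim × n_samples)
inds = [2, 4]

actions_flatten[:, inds]  # 2×2: the two sampled actions, kept intact as columns
# whereas vec(...)[inds] would pick individual scalars and scramble dimensions
```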
```diff
@@ -256,8 +266,12 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     v′ = AC.critic(s) |> vec
     if AC.actor isa GaussianNetwork
         μ, σ = AC.actor(s)
-        log_p′ₐ = normlogpdf(μ, σ, a)
-        entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
+        if ndims(a) == 2
+            log_p′ₐ = sum(normlogpdf(μ, σ, a), dims=1)
+        else
+            log_p′ₐ = normlogpdf(μ, σ, a)
+        end
+        entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims=1))
     else
         # actor is assumed to return discrete logits
         logit′ = AC.actor(s)
```
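The entropy term uses the closed form for a Gaussian, H(N(μ, σ²)) = (log 2π + 1)/2 + log σ; for independent dimensions the log σ terms add, which is what `sum(log.(σ), dims=1)` computes. Since an additive constant does not affect the gradient, only the sum of log σ matters for optimization. A quick check against Distributions.jl, with illustrative values:

```julia
using Distributions

σ = 0.5
h_manual = (log(2π) + 1) / 2 + log(σ)  # closed form used in the diff
h_manual ≈ entropy(Normal(0.0, σ))     # true

# For independent dimensions the entropies add, up to the shared constant:
σs = [0.5, 0.3]
sum(entropy.(Normal.(0.0, σs))) ≈ length(σs) * (log(2π) + 1) / 2 + sum(log.(σs))
```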
