This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit c5fa628

Format .jl files (#153)
Co-authored-by: norci <norci@users.noreply.github.com>
1 parent 333a481

3 files changed: 10 additions & 6 deletions


src/algorithms/policy_gradient/multi_thread_env.jl

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ function (env::MultiThreadEnv)(actions)
     N = ndims(actions)
     @sync for i in 1:length(env)
         @spawn begin
-            if N == 1
+            if N == 1
                 env[i](actions[i])
             else
                 env[i](selectdim(actions, N, i))
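
For context, a minimal sketch (not part of this commit) of how this branch dispatches actions: when the agent emits one scalar action per sub-environment the actions arrive as a vector, and when each action is itself multi-dimensional they arrive as an array whose last dimension indexes the sub-environments, so selectdim(actions, N, i) hands environment i its slice. The names ToyEnv and dispatch_actions! below are hypothetical stand-ins, not part of MultiThreadEnv.

    using Base.Threads: @spawn

    # A trivial stand-in "environment" that just records the action it receives.
    mutable struct ToyEnv
        last_action::Any
    end
    (env::ToyEnv)(action) = (env.last_action = action)

    function dispatch_actions!(envs::Vector{ToyEnv}, actions)
        N = ndims(actions)
        @sync for i in 1:length(envs)
            @spawn begin
                if N == 1
                    envs[i](actions[i])                # one scalar action per env
                else
                    envs[i](selectdim(actions, N, i))  # one slice along the last dim per env
                end
            end
        end
    end

    toy_envs = [ToyEnv(nothing) for _ in 1:3]
    dispatch_actions!(toy_envs, [1, 2, 3])      # vector of scalar actions
    dispatch_actions!(toy_envs, rand(2, 3))     # 2-D actions: one column per env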

src/algorithms/policy_gradient/ppo.jl

Lines changed: 4 additions & 4 deletions
@@ -184,11 +184,11 @@ function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
     dist = prob(agent.policy, env)
     action = rand.(agent.policy.rng, dist)
     if ndims(action) == 2
-        action_log_prob = sum(logpdf.(dist, action), dims=1)
+        action_log_prob = sum(logpdf.(dist, action), dims = 1)
     else
         action_log_prob = logpdf.(dist, action)
     end
-    EnrichedAction(action; action_log_prob=vec(action_log_prob))
+    EnrichedAction(action; action_log_prob = vec(action_log_prob))
 end

 function RLBase.update!(
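
For context, a minimal sketch (not part of this commit) of why the log-probabilities are summed with dims = 1 in the 2-D case: each column of the action array holds one multi-dimensional action, and the joint log-probability of independent action dimensions is the sum of the per-dimension log-probabilities. The sketch assumes the Distributions.jl package; the sizes are made up for illustration.

    using Distributions

    dist = [Normal(0.0f0, 1.0f0) for _ in 1:4, _ in 1:8]   # 4 action dims × 8 environments
    action = rand.(dist)                                    # one sample per distribution, 4×8
    logp_per_dim = logpdf.(dist, action)                    # element-wise log-probs, 4×8
    action_log_prob = sum(logp_per_dim, dims = 1)           # joint log-prob per environment, 1×8
    vec(action_log_prob)                                    # flatten to a length-8 vector
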
@@ -267,11 +267,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
     if AC.actor isa GaussianNetwork
         μ, σ = AC.actor(s)
         if ndims(a) == 2
-            log_p′ₐ = sum(normlogpdf(μ, σ, a), dims=1)
+            log_p′ₐ = sum(normlogpdf(μ, σ, a), dims = 1)
         else
             log_p′ₐ = normlogpdf(μ, σ, a)
         end
-        entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims=1))
+        entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims = 1))
     else
         # actor is assumed to return discrete logits
         logit′ = AC.actor(s)
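
For context, a minimal sketch (not part of this commit) of where the entropy term comes from: the entropy of a Gaussian is (log(2π) + 1) / 2 + log(σ) per dimension, so sum(log.(σ), dims = 1) carries the σ-dependent part across action dimensions and mean averages over the batch; any remaining additive constant only shifts the loss and does not change gradients. The σ values below are made up for illustration.

    using Statistics: mean

    σ = Float32[0.5 1.0 2.0;
                0.3 0.7 1.5]                               # 2 action dims × 3 samples
    per_dim_entropy = (log(2.0f0π) + 1) / 2 .+ log.(σ)     # per-dimension Gaussian entropy
    entropy_term = mean((log(2.0f0π) + 1) / 2 .+ sum(log.(σ), dims = 1))  # as in the diff above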

src/experiments/gridworlds/JuliaRL_BasicDQN_EmptyRoom.jl

Lines changed: 5 additions & 1 deletion
@@ -17,7 +17,11 @@ function RLCore.Experiment(
     inner_env = GridWorlds.EmptyRoom(rng = rng)
     action_space_mapping = x -> Base.OneTo(length(RLBase.action_space(inner_env)))
     action_mapping = i -> RLBase.action_space(inner_env)[i]
-    env = RLEnvs.ActionTransformedEnv(inner_env, action_space_mapping = action_space_mapping, action_mapping = action_mapping)
+    env = RLEnvs.ActionTransformedEnv(
+        inner_env,
+        action_space_mapping = action_space_mapping,
+        action_mapping = action_mapping,
+    )
     env = RLEnvs.StateOverriddenEnv(env, x -> vec(Float32.(x)))
     env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
     env = MaxTimeoutEnv(env, 240)
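
For context, a minimal, self-contained sketch (not part of this commit) of what the two mappings passed to ActionTransformedEnv do: the agent only ever sees action indices 1:n, and the wrapper translates an index back to the environment's native action. native_actions below is a made-up stand-in for RLBase.action_space(inner_env).

    native_actions = (:left, :right, :up, :down)

    action_space_mapping = x -> Base.OneTo(length(native_actions))  # agent-facing space: 1:4
    action_mapping = i -> native_actions[i]                         # index back to a native action

    action_space_mapping(nothing)   # Base.OneTo(4)
    action_mapping(3)               # :up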
