
Commit d7077e8

Merge pull request #131 from JuliaReinforcementLearning/auto-juliaformatter-pr
Automatic JuliaFormatter.jl run
2 parents eb967ca + 79b6a60 commit d7077e8

36 files changed

Lines changed: 182 additions & 185 deletions
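The source changes below were generated by the formatter rather than written by hand. For context, a run like this is typically produced by calling JuliaFormatter's format function on the repository root; a minimal sketch, assuming JuliaFormatter.jl is installed (the exact CI invocation behind this PR is not shown in the diff):

using JuliaFormatter

# Reformat every .jl file under the current directory in place.
# Returns true if all files were already formatted.
format(".")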

src/algorithms/cfr/best_response_policy.jl

Lines changed: 1 addition & 5 deletions
@@ -66,11 +66,7 @@ function init_cfr_reach_prob!(p, env, reach_prob = 1.0)
        end
    else # opponents
        for a in legal_action_space(env)
-            init_cfr_reach_prob!(
-                p,
-                child(env, a),
-                reach_prob * prob(p.policy, env, a),
-            )
+            init_cfr_reach_prob!(p, child(env, a), reach_prob * prob(p.policy, env, a))
        end
    end
end

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 1 addition & 2 deletions
@@ -18,8 +18,7 @@ end

(p::ExternalSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

-RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) =
-    prob(p.behavior_policy, env)
+RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG)
    ExternalSamplingMCCFRPolicy(

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ end

(p::OutcomeSamplingMCCFRPolicy)(env::AbstractEnv) = p.behavior_policy(env)

-RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) =
-    prob(p.behavior_policy, env)
+RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv) = prob(p.behavior_policy, env)

function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
    OutcomeSamplingMCCFRPolicy(

src/algorithms/dqns/basic_dqn.jl

Lines changed: 1 addition & 3 deletions
@@ -38,9 +38,7 @@ end
(learner::BasicDQNLearner)(env) =
    env |>
    state |>
-    x -> send_to_device(device(learner), x) |>
-    learner.approximator |>
-    send_to_host
+    x -> send_to_device(device(learner), x) |> learner.approximator |> send_to_host

function BasicDQNLearner(;
    approximator::Q,

src/algorithms/dqns/common.jl

Lines changed: 8 additions & 3 deletions
@@ -4,7 +4,7 @@

const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}

-function RLBase.update!(learner::Union{DQNLearner, PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(learner::Union{DQNLearner,PERLearners}, t::AbstractTrajectory)
    length(t[:terminal]) < learner.min_replay_history && return

    learner.update_step += 1
@@ -21,11 +21,16 @@ function RLBase.update!(learner::Union{DQNLearner, PERLearners}, t::AbstractTraj
        priorities = update!(learner, batch)
        t[:priority][inds] .= priorities
    else
-        update!(learner,batch)
+        update!(learner, batch)
    end
end

-function RLBase.update!(trajectory::PrioritizedTrajectory, p::QBasedPolicy{<:PERLearners}, env::AbstractEnv, ::PostActStage)
+function RLBase.update!(
+    trajectory::PrioritizedTrajectory,
+    p::QBasedPolicy{<:PERLearners},
+    env::AbstractEnv,
+    ::PostActStage,
+)
    push!(trajectory[:reward], reward(env))
    push!(trajectory[:terminal], is_terminated(env))
    push!(trajectory[:priority], p.learner.default_priority)

src/algorithms/dqns/dqn.jl

Lines changed: 14 additions & 7 deletions
@@ -55,7 +55,12 @@ function DQNLearner(;
    rng = Random.GLOBAL_RNG,
) where {Tq,Tt,Tf}
    copyto!(approximator, target_approximator)
-    sampler = NStepBatchSampler{traces}(;γ=γ, n=update_horizon,stack_size=stack_size,batch_size=batch_size)
+    sampler = NStepBatchSampler{traces}(;
+        γ = γ,
+        n = update_horizon,
+        stack_size = stack_size,
+        batch_size = batch_size,
+    )
    DQNLearner(
        approximator,
        target_approximator,
@@ -81,11 +86,13 @@ end
function (learner::DQNLearner)(env)
    env |>
    state |>
-    x -> Flux.unsqueeze(x, ndims(x) + 1) |>
-    x -> send_to_device(device(learner), x) |>
-    learner.approximator |>
-    vec |>
-    send_to_host
+    x ->
+        Flux.unsqueeze(x, ndims(x) + 1) |>
+        x ->
+            send_to_device(device(learner), x) |>
+            learner.approximator |>
+            vec |>
+            send_to_host
end

function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
@@ -103,7 +110,7 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
        target_q = Qₜ(s′)
        if haskey(batch, :next_legal_actions_mask)
            l′ = send_to_device(D, batch[:next_legal_actions_mask])
-            target_q .+= ifelse.(l′, 0.f0, typemin(Float32))
+            target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
        end

        q′ = dropdims(maximum(target_q; dims = 1), dims = 1)

src/algorithms/dqns/iqn.jl

Lines changed: 11 additions & 3 deletions
@@ -114,7 +114,12 @@ function IQNLearner(;
    if device(approximator) !== device(device_rng)
        throw(ArgumentError("device of `approximator` doesn't match the device of `device_rng`: $(device(approximator)) !== $(device_rng)"))
    end
-    sampler = NStepBatchSampler{traces}(;γ=γ, n=update_horizon,stack_size=stack_size,batch_size=batch_size)
+    sampler = NStepBatchSampler{traces}(;
+        γ = γ,
+        n = update_horizon,
+        stack_size = stack_size,
+        batch_size = batch_size,
+    )
    IQNLearner(
        approximator,
        target_approximator,
@@ -158,7 +163,8 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
    batch_size = learner.sampler.batch_size

    D = device(Z)
-    s, r, t, s′ = (send_to_device(D, batch[x]) for x in (:state, :reward, :terminal, :next_state))
+    s, r, t, s′ =
+        (send_to_device(D, batch[x]) for x in (:state, :reward, :terminal, :next_state))

    τ′ = rand(learner.device_rng, Float32, N′, batch_size) # TODO: support β distribution
    τₑₘ′ = embed(τ′, Nₑₘ)
@@ -174,7 +180,9 @@ function RLBase.update!(learner::IQNLearner, batch::NamedTuple)
    aₜ = argmax(avg_zₜ, dims = 1)
    aₜ = aₜ .+ typeof(aₜ)(CartesianIndices((0, 0:N′-1, 0)))
    qₜ = reshape(zₜ[aₜ], :, batch_size)
-    target = reshape(r, 1, batch_size) .+ learner.sampler.γ * reshape(1 .- t, 1, batch_size) .* qₜ # reshape to allow broadcast
+    target =
+        reshape(r, 1, batch_size) .+
+        learner.sampler.γ * reshape(1 .- t, 1, batch_size) .* qₜ # reshape to allow broadcast

    τ = rand(learner.device_rng, Float32, N, batch_size)
    τₑₘ = embed(τ, Nₑₘ)

src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 15 additions & 8 deletions
@@ -60,7 +60,12 @@ function PrioritizedDQNLearner(;
    rng = Random.GLOBAL_RNG,
) where {Tq,Tt,Tf}
    copyto!(approximator, target_approximator)
-    sampler = NStepBatchSampler{traces}(;γ=γ, n=update_horizon,stack_size=stack_size,batch_size=batch_size)
+    sampler = NStepBatchSampler{traces}(;
+        γ = γ,
+        n = update_horizon,
+        stack_size = stack_size,
+        batch_size = batch_size,
+    )
    PrioritizedDQNLearner(
        approximator,
        target_approximator,
@@ -94,11 +99,13 @@ end
function (learner::PrioritizedDQNLearner)(env)
    env |>
    state |>
-    x -> Flux.unsqueeze(x, ndims(x) + 1) |>
-    x -> send_to_device(device(learner), x) |>
-    learner.approximator |>
-    vec |>
-    send_to_host
+    x ->
+        Flux.unsqueeze(x, ndims(x) + 1) |>
+        x ->
+            send_to_device(device(learner), x) |>
+            learner.approximator |>
+            vec |>
+            send_to_host
end

function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
@@ -111,7 +118,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
    batch_size = learner.sampler.batch_size

    D = device(Q)
-    s, a, r, t, s′ = (send_to_device(D,batch[x]) for x in SARTS)
+    s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
    a = CartesianIndex.(a, 1:batch_size)

    updated_priorities = Vector{Float32}(undef, batch_size)
@@ -122,7 +129,7 @@ function RLBase.update!(learner::PrioritizedDQNLearner, batch::NamedTuple)
        target_q = Qₜ(s′)
        if haskey(batch, :next_legal_actions_mask)
            l′ = send_to_device(D, batch[:next_legal_actions_mask])
-            target_q .+= ifelse.(l′, 0.f0, typemin(Float32))
+            target_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
        end

        q′ = dropdims(maximum(target_q; dims = 1), dims = 1)

src/algorithms/dqns/rainbow.jl

Lines changed: 7 additions & 2 deletions
@@ -87,7 +87,12 @@ function RainbowLearner(;
    default_priority >= 1.0f0 || error("default value must be >= 1.0f0")
    copyto!(approximator, target_approximator) # force sync
    support = send_to_device(device(approximator), support)
-    sampler = NStepBatchSampler{traces}(;γ=γ, n=update_horizon,stack_size=stack_size,batch_size=batch_size)
+    sampler = NStepBatchSampler{traces}(;
+        γ = γ,
+        n = update_horizon,
+        stack_size = stack_size,
+        batch_size = batch_size,
+    )
    RainbowLearner(
        approximator,
        target_approximator,
@@ -147,7 +152,7 @@ function RLBase.update!(learner::RainbowLearner, batch::NamedTuple)
    next_q = reshape(sum(support .* next_probs, dims = 1), n_actions, :)
    if haskey(batch, :next_legal_actions_mask)
        l′ = send_to_device(D, batch[:next_legal_actions_mask])
-        next_q .+= ifelse.(l′, 0.f0, typemin(Float32))
+        next_q .+= ifelse.(l′, 0.0f0, typemin(Float32))
    end
    next_prob_select = select_best_probs(next_probs, next_q)

src/algorithms/policy_gradient/A2C.jl

Lines changed: 13 additions & 19 deletions
@@ -29,13 +29,10 @@ Base.@kwdef mutable struct A2CLearner{A<:ActorCritic} <: AbstractLearner
    norm::Float32 = 0.0f0
end

-Flux.functor(x::A2CLearner) = (app = x.approximator, ), y -> @set x.approximator = y.app
+Flux.functor(x::A2CLearner) = (app = x.approximator,), y -> @set x.approximator = y.app

function (learner::A2CLearner)(env::MultiThreadEnv)
-    learner.approximator.actor(send_to_device(
-        device(learner),
-        state(env),
-    )) |> send_to_host
+    learner.approximator.actor(send_to_device(device(learner), state(env))) |> send_to_host
end

function (learner::A2CLearner)(env)
@@ -70,20 +67,17 @@ function _update!(learner::A2CLearner, t::CircularArraySARTTrajectory)
    actions = flatten_batch(actions)
    actions = CartesianIndex.(actions, 1:length(actions))

-    next_state_values = t[:state] |>
-        select_last_frame |>
-        Array |>
-        to_device |>
-        AC.critic |>
-        send_to_host
+    next_state_values =
+        t[:state] |> select_last_frame |> Array |> to_device |> AC.critic |> send_to_host

-    gains = discount_rewards(
-        t[:reward],
-        γ;
-        dims = 2,
-        init = send_to_host(next_state_values),
-        terminal = t[:terminal],
-    ) |> to_device
+    gains =
+        discount_rewards(
+            t[:reward],
+            γ;
+            dims = 2,
+            init = send_to_host(next_state_values),
+            terminal = t[:terminal],
+        ) |> to_device

    ps = Flux.params(AC)
    gs = gradient(ps) do
@@ -111,4 +105,4 @@ function _update!(learner::A2CLearner, t::CircularArraySARTTrajectory)
    update!(AC, gs)
end

-RLCore.check(::QBasedPolicy{<:A2CLearner}, ::MultiThreadEnv) = nothing
+RLCore.check(::QBasedPolicy{<:A2CLearner}, ::MultiThreadEnv) = nothing
