This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit a9ad731: Format .jl files (#142)

Co-authored-by: norci <norci@users.noreply.github.com>

1 parent 4cb26f4 commit a9ad731

40 files changed

Lines changed: 397 additions & 297 deletions
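This commit is mechanical reformatting across the package. For context, a pass like this is typically produced with JuliaFormatter.jl; the exact invocation and options used are not recorded in the commit, so the following is only a plausible sketch:

    # Plausible reproduction of a formatting pass like this one
    # (assumption: JuliaFormatter.jl with default options; the actual
    # options used by the commit are not recorded here).
    using JuliaFormatter

    # Rewrite every .jl file under src/ in place.
    format("src")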

src/ReinforcementLearningZoo.jl

Lines changed: 6 additions & 2 deletions
@@ -36,8 +36,12 @@ using Requires
 function __init__()
     @require ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" begin
         include("experiments/rl_envs/rl_envs.jl")
-        @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari/atari.jl")
-        @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
+        @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include(
+            "experiments/atari/atari.jl",
+        )
+        @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include(
+            "experiments/open_spiel/open_spiel.jl",
+        )
     end
 end
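For readers skimming the diff: `@require` comes from Requires.jl and defers an `include` until the user actually loads the named package, so optional dependencies do not slow down package load. A minimal sketch of the pattern (module name, package name, and UUID below are placeholders, not from this repository):

    module MyZoo

    using Requires

    function __init__()
        # __init__ runs once when MyZoo is loaded; the @require body runs
        # only after the user also loads the package with this UUID.
        @require SomePkg = "00000000-0000-0000-0000-000000000001" include(
            "some_pkg_glue.jl",
        )
    end

    end # module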

src/algorithms/cfr/best_response_policy.jl

Lines changed: 1 addition & 5 deletions
@@ -15,11 +15,7 @@ end
 - `env`, the environment to handle.
 - `best_responder`, the player to choose best response action.
 """
-function BestResponsePolicy(
-    policy,
-    env,
-    best_responder;
-)
+function BestResponsePolicy(policy, env, best_responder;)
     S = eltype(state_space(env))
     A = eltype(action_space(env))
     E = typeof(env)

src/algorithms/cfr/external_sampling_mccfr.jl

Lines changed: 1 addition & 4 deletions
@@ -26,10 +26,7 @@ RLBase.prob(p::ExternalSamplingMCCFRPolicy, env::AbstractEnv, action) =
 function ExternalSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG)
     ExternalSamplingMCCFRPolicy(
         Dict{state_type,InfoStateNode}(),
-        TabularRandomPolicy(;
-            rng = rng,
-            table = Dict{state_type,Vector{Float64}}(),
-        ),
+        TabularRandomPolicy(; rng = rng, table = Dict{state_type,Vector{Float64}}()),
         rng,
     )
 end

src/algorithms/cfr/nash_conv.jl

Lines changed: 2 additions & 2 deletions
@@ -31,8 +31,8 @@ function nash_conv(π, env; is_reduce = true, kw...)
     RLBase.reset!(e)

     σ′ = [
-        best_response_value(BestResponsePolicy(π, e, i; kw...), e)
-        for i in players(e) if i != chance_player(e)
+        best_response_value(BestResponsePolicy(π, e, i; kw...), e) for
+        i in players(e) if i != chance_player(e)
     ]

     σ = expected_policy_values(π, e)

src/algorithms/cfr/outcome_sampling_mccfr.jl

Lines changed: 1 addition & 4 deletions
@@ -27,10 +27,7 @@ RLBase.prob(p::OutcomeSamplingMCCFRPolicy, env::AbstractEnv, action) =
 function OutcomeSamplingMCCFRPolicy(; state_type = String, rng = Random.GLOBAL_RNG, ϵ = 0.6)
     OutcomeSamplingMCCFRPolicy(
         Dict{state_type,InfoStateNode}(),
-        TabularRandomPolicy(;
-            rng = rng,
-            table = Dict{state_type,Vector{Float64}}(),
-        ),
+        TabularRandomPolicy(; rng = rng, table = Dict{state_type,Vector{Float64}}()),
         ϵ,
         rng,
     )

src/algorithms/cfr/tabular_cfr.jl

Lines changed: 2 additions & 10 deletions
@@ -9,12 +9,7 @@ end

 function InfoStateNode(mask)
     n = sum(mask)
-    InfoStateNode(
-        fill(1 / n, n),
-        zeros(n),
-        zeros(n),
-        mask
-    )
+    InfoStateNode(fill(1 / n, n), zeros(n), zeros(n), mask)
 end

 #####
@@ -67,10 +62,7 @@ function TabularCFRPolicy(;
 )
     TabularCFRPolicy(
         Dict{state_type,InfoStateNode}(),
-        TabularRandomPolicy(;
-            rng = rng,
-            table = Dict{state_type,Vector{Float64}}(),
-        ),
+        TabularRandomPolicy(; rng = rng, table = Dict{state_type,Vector{Float64}}()),
         is_reset_neg_regrets,
         is_linear_averaging,
         weighted_averaging_delay,

src/algorithms/dqns/iqn.jl

Lines changed: 5 additions & 1 deletion
@@ -112,7 +112,11 @@ function IQNLearner(;
     )
     copyto!(approximator, target_approximator) # force sync
     if device(approximator) !== device(device_rng)
-        throw(ArgumentError("device of `approximator` doesn't match the device of `device_rng`: $(device(approximator)) !== $(device_rng)"))
+        throw(
+            ArgumentError(
+                "device of `approximator` doesn't match the device of `device_rng`: $(device(approximator)) !== $(device_rng)",
+            ),
+        )
     end
     sampler = NStepBatchSampler{traces}(;
         γ = γ,
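The widened `throw` above wraps a runtime guard: IQN samples its quantile fractions with `device_rng`, so the RNG must live on the same device as the approximator's parameters. A toy sketch of the same guard shape (the `device_of` helper below is an illustrative stand-in for the package's own `device` function, not its implementation):

    using Random

    # Illustrative stand-ins only.
    device_of(::Array) = :cpu
    device_of(::Random.AbstractRNG) = :cpu   # a CUDA RNG would report :gpu

    function check_devices(approximator, device_rng)
        if device_of(approximator) !== device_of(device_rng)
            throw(
                ArgumentError(
                    "device of `approximator` doesn't match the device of `device_rng`",
                ),
            )
        end
    end

    check_devices(zeros(Float32, 4), Random.GLOBAL_RNG)  # ok: both report :cpu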

src/algorithms/offline_rl/behavior_cloning.jl

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@ function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :ac
         logitcrossentropy(ŷ, y)
     end
     update!(m, gs)
-end
+end
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-include("behavior_cloning.jl")
+include("behavior_cloning.jl")

(Both hunks above look unchanged because the difference is whitespace only; the formatter presumably added the missing newline at end of file.)

src/algorithms/policy_gradient/MAC.jl

Lines changed: 14 additions & 5 deletions
@@ -39,7 +39,12 @@ function (learner::MACLearner)(env)
     learner.approximator.actor(s) |> vec |> send_to_host
 end

-function RLBase.update!(learner::MACLearner, t::CircularArraySARTTrajectory, ::AbstractEnv, ::PreActStage)
+function RLBase.update!(
+    learner::MACLearner,
+    t::CircularArraySARTTrajectory,
+    ::AbstractEnv,
+    ::PreActStage,
+)
     length(t) == 0 && return # in the first update, only state & action is inserted into trajectory
     learner.update_step += 1
     if learner.update_step % learner.update_freq == 0
@@ -112,10 +117,14 @@ function _update!(learner::MACLearner, t::CircularArraySARTTrajectory)
         next_state_values = AC.critic(next_state_flattened)
         target_action_values =
             vec(rewards_flattened) .+
-            γ * vec(Zygote.dropgrad(sum(
-                next_state_values .* softmax(AC.actor(next_state_flattened)),
-                dims = 1,
-            )))
+            γ * vec(
+                Zygote.dropgrad(
+                    sum(
+                        next_state_values .* softmax(AC.actor(next_state_flattened)),
+                        dims = 1,
+                    ),
+                ),
+            )
         critic_loss =
             mean((vec(target_action_values) .- vec(action_values[actions])) .^ 2)
     end
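The reformatted expression above is the critic's TD target: the reward plus the discounted next-state value, weighted by the actor's action probabilities, with `Zygote.dropgrad` stopping gradients from flowing through the bootstrap term. A toy sketch of that stop-gradient pattern (the scalar "critic" and all names below are illustrative, not the package's code):

    using Zygote

    critic(w, s) = w * s          # toy one-parameter "critic"

    w, s, s′, r, γ = 2.0, 1.0, 0.5, 1.0, 0.9

    grad = Zygote.gradient(w) do w
        # dropgrad treats the bootstrap target as a constant, so the
        # gradient flows only through the prediction term.
        target = r + γ * Zygote.dropgrad(critic(w, s′))
        (target - critic(w, s))^2
    end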
