Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 7bc3013

Browse files
authored
Fix #118 (#119)
1 parent 096535a commit 7bc3013

1 file changed

Lines changed: 18 additions & 28 deletions

File tree

src/experiments/rl_envs.jl

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,20 +1039,15 @@ function RLCore.Experiment(
10391039
agent = Agent(
10401040
policy = PPOPolicy(
10411041
approximator = ActorCritic(
1042-
actor = NeuralNetworkApproximator(
1043-
model = Chain(
1044-
Dense(ns, 256, relu; initW = glorot_uniform(rng)),
1045-
Dense(256, na; initW = glorot_uniform(rng)),
1042+
actor = Chain(
1043+
Dense(ns, 256, relu; initW = glorot_uniform(rng)),
1044+
Dense(256, na; initW = glorot_uniform(rng)),
10461045
),
1047-
optimizer = ADAM(1e-3),
1048-
),
1049-
critic = NeuralNetworkApproximator(
1050-
model = Chain(
1051-
Dense(ns, 256, relu; initW = glorot_uniform(rng)),
1052-
Dense(256, 1; initW = glorot_uniform(rng)),
1046+
critic = Chain(
1047+
Dense(ns, 256, relu; initW = glorot_uniform(rng)),
1048+
Dense(256, 1; initW = glorot_uniform(rng)),
10531049
),
1054-
optimizer = ADAM(1e-3),
1055-
),
1050+
optimizer = ADAM(1e-3),
10561051
) |> cpu,
10571052
γ = 0.99f0,
10581053
λ = 0.95f0,
@@ -1325,25 +1320,20 @@ function RLCore.Experiment(
13251320
agent = Agent(
13261321
policy = PPOPolicy(
13271322
approximator = ActorCritic(
1328-
actor = NeuralNetworkApproximator(
1329-
model = GaussianNetwork(
1330-
pre = Chain(
1331-
Dense(ns, 64, relu; initW = glorot_uniform(rng)),
1332-
Dense(64, 64, relu; initW = glorot_uniform(rng)),
1333-
),
1334-
μ = Chain(Dense(64, 1, tanh; initW = glorot_uniform(rng)), vec),
1335-
σ = Chain(Dense(64, 1; initW = glorot_uniform(rng)), vec),
1336-
),
1337-
optimizer = ADAM(3e-4),
1338-
),
1339-
critic = NeuralNetworkApproximator(
1340-
model = Chain(
1323+
actor = GaussianNetwork(
1324+
pre = Chain(
13411325
Dense(ns, 64, relu; initW = glorot_uniform(rng)),
13421326
Dense(64, 64, relu; initW = glorot_uniform(rng)),
1343-
Dense(64, 1; initW = glorot_uniform(rng)),
13441327
),
1345-
optimizer = ADAM(3e-4),
1346-
),
1328+
μ = Chain(Dense(64, 1, tanh; initW = glorot_uniform(rng)), vec),
1329+
σ = Chain(Dense(64, 1; initW = glorot_uniform(rng)), vec),
1330+
),
1331+
critic = Chain(
1332+
Dense(ns, 64, relu; initW = glorot_uniform(rng)),
1333+
Dense(64, 64, relu; initW = glorot_uniform(rng)),
1334+
Dense(64, 1; initW = glorot_uniform(rng)),
1335+
),
1336+
optimizer = ADAM(3e-4),
13471337
) |> cpu,
13481338
γ = 0.99f0,
13491339
λ = 0.95f0,

0 commit comments

Comments
 (0)