This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit c5439e6

add GridWorlds environments (#152)
* ignore vim generated temp files
* add JuliaRL_BasicDQN_EmptyRoom experiment
* add per-step penalty and max-timeout per episode
* add test for JuliaRL_BasicDQN_EmptyRoom
* add JuliaRL_BasicDQN_EmptyRoom to README
* add note on importing GridWorlds
1 parent 7913db6 commit c5439e6

7 files changed

Lines changed: 118 additions & 3 deletions

File tree

- .gitignore
- Project.toml
- README.md
- src/ReinforcementLearningZoo.jl
- src/experiments/gridworlds/JuliaRL_BasicDQN_EmptyRoom.jl
- src/experiments/gridworlds/gridworlds.jl
- test/runtests.jl

.gitignore

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,4 +1,8 @@
 .DS_Store
 /Manifest.toml
 /dev/
-**/checkpoints/
+**/checkpoints/
+
+# add vim generated temp files
+*~
+*.swp
```

Project.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -51,9 +51,10 @@ Zygote = "0.5, 0.6"
 julia = "1.4"

 [extras]
+GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043"
 OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
 ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]
+test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel", "GridWorlds"]
```

README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -60,6 +60,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
 - ``E`JuliaRL_DQN_SnakeGame` ``
 - ``E`JuliaRL_BC_CartPole` ``
+- ``E`JuliaRL_BasicDQN_EmptyRoom` ``
 - ``E`Dopamine_DQN_Atari(pong)` ``
 - ``E`Dopamine_Rainbow_Atari(pong)` ``
 - ``E`Dopamine_IQN_Atari(pong)` ``
@@ -87,7 +88,7 @@ julia> run(E`rlpyt_PPO_Atari(pong)`) # the Atari environment is provided in Arc
 - Experiments on `CartPole` usually run faster with CPU only due to the overhead of sending data between CPU and GPU.
 - It shouldn't surprise you that our experiments on `CartPole` are much faster than those written in Python. The secret is that our environment is written in Julia!
 - Remember to set `JULIA_NUM_THREADS` to enable multi-threading when using algorithms like `A2C` and `PPO`.
-- Experiments on `Atari` (`OpenSpiel`, `SnakeGame`) are only available after you have `ArcadeLearningEnvironment.jl` (`OpenSpiel.jl`, `SnakeGame.jl`) installed and `using ArcadeLearningEnvironment` (`using OpenSpiel`, `using SnakeGame`).
+- Experiments on `Atari` (`OpenSpiel`, `SnakeGame`, `GridWorlds`) are only available after you have `ArcadeLearningEnvironment.jl` (`OpenSpiel.jl`, `SnakeGame.jl`, `GridWorlds.jl`) installed and `using ArcadeLearningEnvironment` (`using OpenSpiel`, `using SnakeGame`, `import GridWorlds`).

 ### Speed
```
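As the updated note says, the GridWorlds experiments only register once `GridWorlds` is imported. A typical session might look like this sketch (assuming both packages are installed):

```julia
import GridWorlds                   # triggers the conditional experiment loading
using ReinforcementLearningZoo

run(E`JuliaRL_BasicDQN_EmptyRoom`)  # trains BasicDQN on EmptyRoom for 10_000 steps
```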

src/ReinforcementLearningZoo.jl

Lines changed: 3 additions & 0 deletions
```diff
@@ -42,6 +42,9 @@ function __init__()
     @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include(
         "experiments/open_spiel/open_spiel.jl",
     )
+    @require GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" include(
+        "experiments/gridworlds/gridworlds.jl",
+    )
 end
 end
```
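`@require` is Requires.jl's lazy-loading hook: the `include` is deferred until a package with that exact name and UUID is first imported, which is what keeps GridWorlds.jl optional. The pattern in isolation, as a hypothetical sketch (the `MyPkg` module and `optional.jl` file are illustrative, not from this repository):

```julia
module MyPkg

using Requires

function __init__()
    # Deferred: this include runs only after the user's own `import GridWorlds`,
    # and only if the loaded package matches this exact UUID.
    @require GridWorlds = "e15a9946-cd7f-4d03-83e2-6c30bacb0043" include("optional.jl")
end

end # module
```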

src/experiments/gridworlds/JuliaRL_BasicDQN_EmptyRoom.jl (new file)

Lines changed: 84 additions & 0 deletions

```julia
function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:BasicDQN},
    ::Val{:EmptyRoom},
    ::Nothing;
    seed = 123,
    save_dir = nothing,
)
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_BasicDQN_EmptyRoom$(t)")
    end
    log_dir = joinpath(save_dir, "tb_log")
    lg = TBLogger(log_dir, min_level = Logging.Info)
    rng = StableRNG(seed)

    inner_env = GridWorlds.EmptyRoom(rng = rng)
    action_space_mapping = x -> Base.OneTo(length(RLBase.action_space(inner_env)))
    action_mapping = i -> RLBase.action_space(inner_env)[i]
    env = RLEnvs.ActionTransformedEnv(inner_env, action_space_mapping = action_space_mapping, action_mapping = action_mapping)
    env = RLEnvs.StateOverriddenEnv(env, x -> vec(Float32.(x)))
    env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
    env = MaxTimeoutEnv(env, 240)

    ns, na = length(state(env)), length(action_space(env))
    agent = Agent(
        policy = QBasedPolicy(
            learner = BasicDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
                        Dense(128, na; initW = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                batch_size = 32,
                min_replay_history = 100,
                loss_func = huber_loss,
                rng = rng,
            ),
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )

    stop_condition = StopAfterStep(10_000)

    total_reward_per_episode = TotalRewardPerEpisode()
    time_per_step = TimePerStep()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        DoEveryNStep() do t, agent, env
            with_logger(lg) do
                @info "training" loss = agent.policy.learner.loss
            end
        end,
        DoEveryNEpisode() do t, agent, env
            with_logger(lg) do
                @info "training" reward = total_reward_per_episode.rewards[end] log_step_increment =
                    0
            end
        end,
    )

    description = """
    This experiment uses three dense layers to approximate the Q value.
    The testing environment is EmptyRoom.

    You can view the runtime logs with `tensorboard --logdir $log_dir`.
    Some useful statistics are stored in the `hook` field of this experiment.
    """

    Experiment(agent, env, stop_condition, hook, description)
end
```
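The wrapper stack above is where the commit's per-step penalty and timeout live: actions are re-indexed to `Base.OneTo(na)` for the DQN, the grid observation is flattened to a `Float32` vector, every reward is shifted by `-0.01`, and episodes are cut off after 240 steps. The `Val`-typed method is what the exported `E` string macro ultimately dispatches to; constructing the experiment explicitly, as the new testset below does (a sketch with a hypothetical `save_dir`):

```julia
import GridWorlds
using ReinforcementLearningZoo

# Equivalent to E`JuliaRL_BasicDQN_EmptyRoom`, but with an explicit save_dir.
e = Experiment(
    Val(:JuliaRL),
    Val(:BasicDQN),
    Val(:EmptyRoom),
    nothing;
    save_dir = joinpath(pwd(), "emptyroom_demo"),
)
run(e)
```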
src/experiments/gridworlds/gridworlds.jl (new file)

Lines changed: 3 additions & 0 deletions

```julia
import .GridWorlds

include("JuliaRL_BasicDQN_EmptyRoom.jl")
```

The leading dot in `import .GridWorlds` is deliberate: inside a Requires.jl `@require` block, the triggering package is bound as a name in the enclosing module rather than added as a regular dependency, so included files reach it through a relative import.

test/runtests.jl

Lines changed: 19 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@ using Statistics
 using Random
 using OpenSpiel
 using StableRNGs
+import GridWorlds

 function get_optimal_kuhn_policy(α = 0.2)
     TabularRandomPolicy(
@@ -96,6 +97,24 @@ end
     @test e.hook[1][] == e.hook[2][] == [0.0]
 end

+@testset "GridWorlds" begin
+    mktempdir() do dir
+        for method in (:BasicDQN,)
+            res = run(
+                Experiment(
+                    Val(:JuliaRL),
+                    Val(method),
+                    Val(:EmptyRoom),
+                    nothing;
+                    save_dir = joinpath(dir, "EmptyRoom", string(method)),
+                ),
+            )
+            @info "stats for $method" avg_reward = mean(res.hook[1].rewards) avg_fps =
+                1 / mean(res.hook[2].times)
+        end
+    end
+end
+
 @testset "TabularCFR" begin
     e = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
     run(e)
```
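To exercise the new testset locally, run the package's test suite; the `[targets]` change in Project.toml makes `GridWorlds` available to it (a sketch, assuming the test dependencies resolve):

```julia
# Runs test/runtests.jl, including the new "GridWorlds" testset.
using Pkg
Pkg.test("ReinforcementLearningZoo")
```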
