This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit d4ead5e

Prepare for the next release of RLCore (#129)
* sync
* fix experiments in rl_env
* fix experiments of Atari
* bugfix with ppo
* fix tests
* add compat
* create Project.toml in test
* only run tests on ubuntu
1 parent 2f6c1e2 commit d4ead5e

50 files changed

Lines changed: 3691 additions & 3360 deletions


.github/workflows/ci.yml

Lines changed: 0 additions & 3 deletions

@@ -16,11 +16,8 @@ jobs:
       matrix:
         version:
           - '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
-          - 'nightly'
         os:
           - ubuntu-latest
-          - macOS-latest
-          - windows-latest
         arch:
           - x64
     steps:

Project.toml

Lines changed: 2 additions & 9 deletions

@@ -7,6 +7,7 @@ version = "0.2.2"
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -34,7 +35,7 @@ Distributions = "0.24"
 Flux = "0.11"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.8.4"
-ReinforcementLearningCore = "0.5"
+ReinforcementLearningCore = "0.6"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StableRNGs = "1.0"
@@ -43,11 +44,3 @@ StructArrays = "0.4"
 TensorBoardLogger = "0.1"
 Zygote = "0.5"
 julia = "1.4"
-
-[extras]
-OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
-ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[targets]
-test = ["Test", "ReinforcementLearningEnvironments", "OpenSpiel"]

src/ReinforcementLearningZoo.jl

Lines changed: 6 additions & 5 deletions

@@ -7,20 +7,21 @@ using ReinforcementLearningBase
 using ReinforcementLearningCore
 using Setfield: @set
 using StableRNGs
+using Logging
+using Flux.Losses
+using Dates
 
 include("patch.jl")
 include("algorithms/algorithms.jl")
-include("utils.jl")
 
 using Requires
 
 # dynamic loading environments
 function __init__()
     @require ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921" begin
-        include("experiments/rl_envs.jl")
-        @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari.jl")
-        @require SnakeGames = "34dccd9f-48d6-4445-aa0f-8c2e373b5429" include("experiments/snake.jl")
-        @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel.jl")
+        include("experiments/rl_envs/rl_envs.jl")
+        @require ArcadeLearningEnvironment = "b7f77d8d-088d-5e02-8ac0-89aab2acc977" include("experiments/atari/atari.jl")
+        # @require OpenSpiel = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2" include("experiments/open_spiel/open_spiel.jl")
     end
 end
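
For context, the @require hooks above mean the experiment files are only included once the matching optional package is loaded in the user's session. A usage sketch (package names as in the diff; the comments describe Requires.jl's standard behaviour):

    using ReinforcementLearningZoo            # __init__() registers the @require hooks
    using ReinforcementLearningEnvironments   # triggers include("experiments/rl_envs/rl_envs.jl")
    using ArcadeLearningEnvironment           # additionally pulls in the Atari experiments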

src/algorithms/dqns/basic_dqn.jl

Lines changed: 18 additions & 33 deletions

@@ -12,9 +12,11 @@ This is the very basic implementation of DQN. Compared to the traditional Q learning,
 in the updating step it uses a batch of transitions sampled from an experience buffer instead of current transition.
 And the `approximator` is usually a [`NeuralNetworkApproximator`](@ref).
 You can start from this implementation to understand how everything is organized and how to write your own customized algorithm.
+
 # Keywords
+
 - `approximator`::[`AbstractApproximator`](@ref): used to get Q-values of a state.
-- `loss_func`: the loss function to use. TODO: provide a default [`huber_loss`](@ref)?
+- `loss_func`: the loss function to use.
 - `γ::Float32=0.99f0`: discount rate.
 - `batch_size::Int=32`
 - `min_replay_history::Int=32`: number of transitions that should be experienced before updating the `approximator`.
@@ -24,9 +26,10 @@ mutable struct BasicDQNLearner{Q,F,R} <: AbstractLearner
     approximator::Q
     loss_func::F
     γ::Float32
-    batch_size::Int
+    sampler::BatchSampler
     min_replay_history::Int
     rng::R
+    # for debugging
     loss::Float32
 end
 
@@ -38,14 +41,13 @@ end
 (learner::BasicDQNLearner)(env) =
     env |>
     get_state |>
-    x ->
-        send_to_device(device(learner.approximator), x) |>
-        learner.approximator |>
-        send_to_host
+    x -> send_to_device(device(learner), x) |>
+    learner.approximator |>
+    send_to_host
 
 function BasicDQNLearner(;
     approximator::Q,
-    loss_func::F,
+    loss_func::F = huber_loss,
     γ = 0.99f0,
     batch_size = 32,
     min_replay_history = 32,
@@ -55,45 +57,28 @@ function BasicDQNLearner(;
         approximator,
         loss_func,
         γ,
-        batch_size,
+        BatchSampler{SARTS}(batch_size),
         min_replay_history,
         rng,
         0.0,
     )
 end
 
-function RLBase.update!(learner::BasicDQNLearner, T::AbstractTrajectory)
-    length(T[:terminal]) < learner.min_replay_history && return
-
-    inds = rand(learner.rng, 1:length(T[:terminal]), learner.batch_size)
-
-    batch = (
-        state = consecutive_view(T[:state], inds),
-        action = consecutive_view(T[:action], inds),
-        reward = consecutive_view(T[:reward], inds),
-        terminal = consecutive_view(T[:terminal], inds),
-        next_state = consecutive_view(T[:next_state], inds),
-    )
-
-    update!(learner, batch)
+function RLBase.update!(learner::BasicDQNLearner, traj::AbstractTrajectory)
+    if length(traj) >= learner.min_replay_history
+        inds, batch = sample(learner.rng, traj, learner.sampler)
+        update!(learner, batch)
+    end
 end
 
-function RLBase.update!(learner::BasicDQNLearner, batch::NamedTuple)
+function RLBase.update!(learner::BasicDQNLearner, batch::NamedTuple{SARTS})
 
     Q = learner.approximator
-    D = device(Q)
     γ = learner.γ
     loss_func = learner.loss_func
 
-    batch_size = nframes(batch.terminal)
-
-    s = send_to_device(D, batch.state)
-    a = batch.action
-    r = send_to_device(D, batch.reward)
-    t = send_to_device(D, batch.terminal)
-    s′ = send_to_device(D, batch.next_state)
-
-    a = CartesianIndex.(a, 1:batch_size)
+    s, a, r, t, s′ = send_to_device(device(Q), batch)
+    a = CartesianIndex.(a, 1:length(a))
 
     gs = gradient(params(Q)) do
         q = Q(s)[a]
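
The hunk is cut off right after `q = Q(s)[a]`. For readers following the refactor, this is roughly what the body of that `gradient` block computes in a one-step DQN update; a hedged sketch rather than the exact source, with `Q`, `s`, `a`, `r`, `t`, `s′`, `γ` and `loss_func` as in the diff:

    gs = gradient(params(Q)) do            # params/gradient come from Flux/Zygote, as above
        q = Q(s)[a]                        # Q-values of the actions actually taken
        q′ = vec(maximum(Q(s′); dims = 1)) # greedy value of the next state
        G = r .+ γ .* (1 .- t) .* q′       # one-step TD target, masked at terminal states
        loss_func(G, q)                    # e.g. the new default huber_loss
    end
    # the gradients `gs` are then applied to the approximator's parameters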

src/algorithms/dqns/common.jl

Lines changed: 9 additions & 77 deletions

@@ -4,73 +4,7 @@
 
 const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}
 
-function extract_experience(t::AbstractTrajectory, learner::PERLearners)
-    s = learner.stack_size
-    h = learner.update_horizon
-    n = learner.batch_size
-    γ = learner.γ
-
-    # 1. sample indices based on priority
-    valid_ind_range =
-        isnothing(s) ? (1:(length(t[:terminal])-h)) : (s:(length(t[:terminal])-h))
-    if haskey(t, :priority)
-        inds = Vector{Int}(undef, n)
-        priorities = Vector{Float32}(undef, n)
-        for i in 1:n
-            ind, p = sample(learner.rng, t[:priority])
-            while ind ∉ valid_ind_range
-                ind, p = sample(learner.rng, t[:priority])
-            end
-            inds[i] = ind
-            priorities[i] = p
-        end
-    else
-        inds = rand(learner.rng, valid_ind_range, n)
-        priorities = nothing
-    end
-
-    next_inds = inds .+ h
-
-    # 2. extract SARTS
-    states = consecutive_view(t[:state], inds; n_stack = s)
-    actions = consecutive_view(t[:action], inds)
-    next_states = consecutive_view(t[:state], next_inds; n_stack = s)
-
-    if haskey(t, :legal_actions_mask)
-        legal_actions_mask = consecutive_view(t[:legal_actions_mask], inds)
-        next_legal_actions_mask = consecutive_view(t[:next_legal_actions_mask], inds)
-    else
-        legal_actions_mask = nothing
-        next_legal_actions_mask = nothing
-    end
-
-    consecutive_rewards = consecutive_view(t[:reward], inds; n_horizon = h)
-    consecutive_terminals = consecutive_view(t[:terminal], inds; n_horizon = h)
-    rewards, terminals = zeros(Float32, n), fill(false, n)
-
-    rewards = discount_rewards_reduced(
-        consecutive_rewards,
-        γ;
-        terminal = consecutive_terminals,
-        dims = 1,
-    )
-    terminals = mapslices(any, consecutive_terminals; dims = 1) |> vec
-
-    inds,
-    (
-        states = states,
-        legal_actions_mask = legal_actions_mask,
-        actions = actions,
-        rewards = rewards,
-        terminals = terminals,
-        next_states = next_states,
-        next_legal_actions_mask = next_legal_actions_mask,
-        priorities = priorities,
-    )
-end
-
-function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
-    learner = p.learner
+function RLBase.update!(learner::Union{DQNLearner, PERLearners}, t::AbstractTrajectory)
     length(t[:terminal]) < learner.min_replay_history && return
 
     learner.update_step += 1
@@ -81,20 +15,18 @@ function RLBase.update!(p::QBasedPolicy{<:PERLearners}, t::AbstractTrajectory)
 
     learner.update_step % learner.update_freq == 0 || return
 
-    inds, experience = extract_experience(t, p.learner)
+    inds, batch = sample(learner.rng, t, learner.sampler)
 
-    if haskey(t, :priority)
-        priorities = update!(p.learner, experience)
+    if t isa PrioritizedTrajectory
+        priorities = update!(learner, batch)
         t[:priority][inds] .= priorities
     else
-        update!(p.learner, experience)
+        update!(learner,batch)
     end
 end
 
-function (agent::Agent{<:QBasedPolicy{<:PERLearners}})(::RLCore.Training{PostActStage}, env)
-    push!(agent.trajectory; reward = get_reward(env), terminal = get_terminal(env))
-    if haskey(agent.trajectory, :priority)
-        push!(agent.trajectory; priority = agent.policy.learner.default_priority)
-    end
-    nothing
+function RLBase.update!(trajectory::PrioritizedTrajectory, p::QBasedPolicy{<:PERLearners}, env::AbstractEnv, ::PostActStage)
+    push!(trajectory[:reward], get_reward(env))
+    push!(trajectory[:terminal], get_terminal(env))
+    push!(trajectory[:priority], p.learner.default_priority)
 end
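
For readers less familiar with prioritized replay, the round trip above (sample proportionally to stored priorities, update the learner, then write the fresh per-sample errors back via `t[:priority][inds] .= priorities`) can be illustrated with a plain-vector version. This is a minimal sketch assuming priorities live in a simple Vector{Float32}; the helper name is hypothetical and the real trajectory uses a more efficient sampling structure:

    using Random, StatsBase

    # Draw `n` indices with probability proportional to their priority and return
    # them together with the sampled priorities (illustrative only, not RLCore API).
    function sample_by_priority(rng::AbstractRNG, priorities::Vector{Float32}, n::Int)
        inds = sample(rng, 1:length(priorities), Weights(priorities), n)
        inds, priorities[inds]
    end

    # After the learner update returns new TD errors for the sampled transitions,
    # writing them back makes high-error transitions more likely to be drawn again:
    #     priorities[inds] .= new_td_errors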
