
Commit 96de067

Add behavior cloning (#146)
* add behavior cloning
* add TODO
* add experiment for bc
* fix test error
* update readme
1 parent: f3da80b

7 files changed: 124 additions & 2 deletions


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ Flux = "0.11"
 IntervalSets = "0.5"
 MacroTools = "0.5"
 ReinforcementLearningBase = "0.9"
-ReinforcementLearningCore = "0.7"
+ReinforcementLearningCore = "0.7.2"
 Requires = "1"
 Setfield = "0.6, 0.7"
 StableRNGs = "1.0"

README.md

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@ This project aims to provide some implementations of the most typical reinforcem
 - SAC
 - CFR/OS-MCCFR/ES-MCCFR/DeepCFR
 - Minimax
+- Behavior Cloning
 
 If you are looking for tabular reinforcement learning algorithms, you may refer [ReinforcementLearningAnIntroduction.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearningAnIntroduction.jl).
 
@@ -58,6 +59,7 @@ Some built-in experiments are exported to help new users to easily run benchmark
 - ``E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` ``
 - ``E`JuliaRL_DeepCFR_OpenSpiel(leduc_poker)` ``
 - ``E`JuliaRL_DQN_SnakeGame` ``
+- ``E`JuliaRL_BC_CartPole` ``
 - ``E`Dopamine_DQN_Atari(pong)` ``
 - ``E`Dopamine_Rainbow_Atari(pong)` ``
 - ``E`Dopamine_IQN_Atari(pong)` ``
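For reference, exported experiments like these are run with the `E` string macro; a minimal sketch, assuming a session where ReinforcementLearningZoo is loaded along with the `E` macro and `run` from ReinforcementLearningCore:

using ReinforcementLearningZoo

# `E` resolves the experiment name to a registered Experiment;
# `run` executes it (here: the new behavior cloning experiment).
run(E`JuliaRL_BC_CartPole`)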

src/algorithms/algorithms.jl

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@ include("dqns/dqns.jl")
 include("policy_gradient/policy_gradient.jl")
 include("searching/searching.jl")
 include("cfr/cfr.jl")
+include("offline_rl/offline_rl.jl")
src/algorithms/offline_rl/behavior_cloning.jl

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+export BehaviorCloningPolicy
+
+"""
+    BehaviorCloningPolicy(;kw...)
+
+# Keyword Arguments
+
+- `approximator`: calculate the logits of possible actions directly
+- `explorer=GreedyExplorer()`
+
+"""
+Base.@kwdef struct BehaviorCloningPolicy{A} <: AbstractPolicy
+    approximator::A
+    explorer::Any = GreedyExplorer()
+end
+
+function (p::BehaviorCloningPolicy)(env::AbstractEnv)
+    s = state(env)
+    s_batch = Flux.unsqueeze(s, ndims(s) + 1)
+    logits = p.approximator(s_batch) |> vec  # drop the batch dimension
+    p.explorer(logits)
+end
+
+function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)})
+    s, a = batch.state, batch.action
+    m = p.approximator
+    gs = gradient(params(m)) do
+        ŷ = m(s)
+        y = Flux.onehotbatch(a, axes(ŷ, 1))
+        logitcrossentropy(ŷ, y)
+    end
+    update!(m, gs)
+end
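For context, a minimal usage sketch of the new policy. The dimensions (4-dimensional states, 2 actions) and the random batch are hypothetical placeholders, and it assumes `NeuralNetworkApproximator` and `update!` are re-exported by ReinforcementLearningZoo as in this version of the package:

using Flux
using ReinforcementLearningZoo

# Toy approximator: maps 4-dimensional states to logits over 2 actions.
app = NeuralNetworkApproximator(
    model = Chain(Dense(4, 32, relu), Dense(32, 2)),
    optimizer = ADAM(),
)
bc = BehaviorCloningPolicy(approximator = app)

# One supervised update on a hypothetical batch of 16 (state, action) pairs:
# states are a 4×16 Float32 matrix, actions are indices in 1:2.
batch = (state = rand(Float32, 4, 16), action = rand(1:2, 16))
update!(bc, batch)

Each such call computes the logit cross-entropy between `m(s)` and the one-hot encoded actions, then applies one optimizer step to the approximator.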
src/algorithms/offline_rl/offline_rl.jl

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+include("behavior_cloning.jl")
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+Base.@kwdef struct RecordStateAction <: AbstractHook
+    records::Any = VectorSATrajectory(; state = Vector{Float32})
+end
+
+function (h::RecordStateAction)(::PreActStage, policy, env, action)
+    push!(h.records; state = copy(state(env)), action = action)
+end
+
+function RLCore.Experiment(
+    ::Val{:JuliaRL},
+    ::Val{:BC},
+    ::Val{:CartPole},
+    ::Nothing;
+    seed = 123,
+    save_dir = nothing,
+)
+    rng = StableRNG(seed)
+
+    env = CartPoleEnv(; T = Float32, rng = rng)
+    ns, na = length(state(env)), length(action_space(env))
+    agent = Agent(
+        policy = QBasedPolicy(
+            learner = BasicDQNLearner(
+                approximator = NeuralNetworkApproximator(
+                    model = Chain(
+                        Dense(ns, 128, relu; initW = glorot_uniform(rng)),
+                        Dense(128, 128, relu; initW = glorot_uniform(rng)),
+                        Dense(128, na; initW = glorot_uniform(rng)),
+                    ) |> cpu,
+                    optimizer = ADAM(),
+                ),
+                batch_size = 32,
+                min_replay_history = 100,
+                loss_func = huber_loss,
+                rng = rng,
+            ),
+            explorer = EpsilonGreedyExplorer(
+                kind = :exp,
+                ϵ_stable = 0.01,
+                decay_steps = 500,
+                rng = rng,
+            ),
+        ),
+        trajectory = CircularArraySARTTrajectory(
+            capacity = 1000,
+            state = Vector{Float32} => (ns,),
+        ),
+    )
+
+    stop_condition = StopAfterStep(10_000)
+    hook = RecordStateAction()
+    run(agent, env, stop_condition, hook)
+
+    bc = BehaviorCloningPolicy(
+        approximator = NeuralNetworkApproximator(
+            model = Chain(
+                Dense(ns, 128, relu; initW = glorot_uniform(rng)),
+                Dense(128, 128, relu; initW = glorot_uniform(rng)),
+                Dense(128, na; initW = glorot_uniform(rng)),
+            ) |> cpu,
+            optimizer = ADAM(),
+        )
+    )
+
+    s = BatchSampler{(:state, :action)}(32)
+
+    for i in 1:300
+        _, batch = s(hook.records)
+        RLBase.update!(bc, batch)
+    end
+
+    description = """
+    # Behavior Cloning with CartPole
+
+    This experiment uses the transitions collected during the experiment
+    `JuliaRL_BasicDQN_CartPole` to train a behavior cloning policy.
+    """
+
+    hook = ComposedHook(
+        TotalRewardPerEpisode(),
+        TimePerStep(),
+    )
+
+    Experiment(bc, env, StopAfterEpisode(100), hook, description)
+end
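Note that both training phases, the 10_000-step BasicDQN data-collection run and the 300 behavior cloning updates, happen inside the `Experiment` constructor itself; the returned `Experiment` only evaluates the cloned policy for 100 episodes. A sketch of invoking it directly (keyword defaults taken from the signature above; assumes `Experiment` and `run(::Experiment)` are available as in ReinforcementLearningCore of this era):

# Construct (this trains the DQN agent, records its transitions, and fits
# the behavior cloning policy), then evaluate the cloned policy.
ex = Experiment(Val(:JuliaRL), Val(:BC), Val(:CartPole), nothing; seed = 123)
run(ex)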

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ end
 
 @testset "training" begin
     mktempdir() do dir
-        for method in (:BasicDQN, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
+        for method in (:BasicDQN, :BC, :DQN, :PrioritizedDQN, :Rainbow, :IQN, :VPG)
             res = run(Experiment(
                 Val(:JuliaRL),
                 Val(method),
