This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 1780ac6

Adding Mean Actor Critic (#108)
* Update policy_gradient.jl
* CartPole MAC experiment
* MAC.jl
* Adding Test for MAC
* Update MAC.jl
1 parent 2e5a43f commit 1780ac6

4 files changed: 236 additions & 1 deletion

src/algorithms/policy_gradient/MAC.jl

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
export MACLearner

using Flux
using Zygote
using Statistics: mean

"""
    MACLearner(;kwargs...)

Keyword arguments

- `approximator`::[`ActorCritic`](@ref)
- `γ::Float32`, reward discount rate
- `bootstrap::Bool`, if `false` the Q function is approximated with Monte Carlo returns.
"""
Base.@kwdef mutable struct MACLearner{A<:ActorCritic} <: AbstractLearner
    approximator::A
    γ::Float32
    max_grad_norm::Union{Nothing,Float32} = nothing
    norm::Float32 = 0.0f0
    actor_loss::Float32 = 0.0f0
    critic_loss::Float32 = 0.0f0
    loss::Float32 = 0.0f0
    bootstrap::Bool = true
end

function (learner::MACLearner)(env::MultiThreadEnv)
    learner.approximator.actor(send_to_device(
        device(learner.approximator),
        get_state(env),
    )) |> send_to_host
end

function (learner::MACLearner)(env)
    s = get_state(env)
    s = Flux.unsqueeze(s, ndims(s) + 1)
    s = send_to_device(device(learner.approximator), s)
    learner.approximator.actor(s) |> vec |> send_to_host
end

function RLBase.update!(learner::MACLearner, t::AbstractTrajectory)
    isfull(t) || return

    states = t[:state]
    actions = t[:action]
    rewards = t[:reward]
    terminals = t[:terminal]

    AC = learner.approximator
    γ = learner.γ
    D = device(AC)

    states = send_to_device(D, states)
    states_flattened = flatten_batch(states) # (state_size..., n_thread * update_step)

    actions = flatten_batch(actions)
    actions = CartesianIndex.(actions, 1:length(actions))

    if learner.bootstrap
        next_state = select_last_frame(t[:next_state])
        next_state = send_to_device(D, next_state)
        next_state_values = AC.critic(next_state)

        gains = discount_rewards(
            rewards,
            γ;
            dims = 2,
            init = send_to_host(next_state_values),
            terminal = terminals,
        )
        gains = send_to_device(D, gains)
    else
        next_state_flattened = flatten_batch(t[:next_state])
        next_state_flattened = send_to_device(D, next_state_flattened)
        rewards_flattened = flatten_batch(rewards)
        rewards_flattened = send_to_device(D, rewards_flattened)
    end

    # Q(s, ⋅), used as a constant in the actor update below.
    action_values = AC.critic(states_flattened)

    ps1 = Flux.params(AC.actor)
    gs1 = gradient(ps1) do
        logits = AC.actor(states_flattened)
        probs = softmax(logits)
        # MAC actor objective: maximize the mean over states of ∑ₐ π(a|s) Q(s, a).
        actor_loss = -mean(sum(probs .* Zygote.dropgrad(action_values), dims = 1))
        Zygote.ignore() do
            learner.actor_loss = actor_loss
        end
        actor_loss
    end
    if !isnothing(learner.max_grad_norm)
        learner.norm = clip_by_global_norm!(gs1, ps1, learner.max_grad_norm)
    end
    update!(AC.actor, gs1)

    ps2 = Flux.params(AC.critic)
    gs2 = gradient(ps2) do
        # Recompute Q inside the gradient block so the critic parameters are tracked;
        # `action_values` computed above is a constant to Zygote here.
        q_values = AC.critic(states_flattened)
        if learner.bootstrap
            critic_loss = mean((vec(gains) .- vec(q_values[actions])) .^ 2)
        else
            # One-step expected target: r + γ ∑ₐ π(a|s′) Q(s′, a), held constant.
            next_state_values = AC.critic(next_state_flattened)
            target_action_values =
                vec(rewards_flattened) .+
                γ .* vec(Zygote.dropgrad(sum(
                    next_state_values .* softmax(AC.actor(next_state_flattened)),
                    dims = 1,
                )))
            critic_loss = mean((target_action_values .- vec(q_values[actions])) .^ 2)
        end
        Zygote.ignore() do
            learner.critic_loss = critic_loss
        end
        critic_loss
    end
    if !isnothing(learner.max_grad_norm)
        learner.norm = clip_by_global_norm!(gs2, ps2, learner.max_grad_norm)
    end
    update!(AC.critic, gs2)
end

function (agent::Agent{<:QBasedPolicy{<:MACLearner},<:CircularCompactSARTSATrajectory})(
    ::Training{PreActStage},
    env,
)
    action = agent.policy(env)
    state = get_state(env)
    push!(agent.trajectory; state = state, action = action)
    update!(agent.policy, agent.trajectory)

    # The main difference from the generic agent: flush the buffer after each update!
    if isfull(agent.trajectory)
        empty!(agent.trajectory)
        push!(agent.trajectory; state = state, action = action)
    end

    action
end
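The actor update above is what gives Mean Actor-Critic its name: rather than using the log-likelihood trick on sampled actions, the policy gradient is taken through the full expectation ∑ₐ π(a|s) Q̂(s, a) over the discrete action set. Below is a minimal standalone sketch of both loss computations on dummy arrays; every name and shape here is illustrative only, not part of the commit.

using Flux                             # for softmax
using Statistics: mean

na, nbatch = 2, 5                      # number of actions × flattened batch size
logits = randn(Float32, na, nbatch)    # stand-in for the actor's output
q_vals = randn(Float32, na, nbatch)    # stand-in for the critic's output (held constant)
probs = softmax(logits)

# Actor loss: negative mean over states of ∑ₐ π(a|s) Q(s, a)
actor_loss = -mean(sum(probs .* q_vals, dims = 1))

# One-step expected target for the taken actions, as in the non-bootstrap branch above:
rewards = randn(Float32, nbatch)
q_next = randn(Float32, na, nbatch)           # critic at the next states
p_next = softmax(randn(Float32, na, nbatch))  # actor at the next states
γ = 0.99f0
targets = rewards .+ γ .* vec(sum(q_next .* p_next, dims = 1))
taken = CartesianIndex.(rand(1:na, nbatch), 1:nbatch)
critic_loss = mean((targets .- vec(q_vals[taken])) .^ 2)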

src/algorithms/policy_gradient/policy_gradient.jl

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ include("vpg.jl")
 include("A2C.jl")
 include("ppo.jl")
 include("A2CGAE.jl")
+include("MAC.jl")
 include("ddpg.jl")
 include("td3.jl")
 include("sac.jl")

src/experiments/rl_envs.jl

Lines changed: 101 additions & 0 deletions
@@ -801,6 +801,107 @@ function RLCore.Experiment(
)
end

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:MAC},
    ::Val{:CartPole},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    if isnothing(save_dir)
        t = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
        save_dir = joinpath(pwd(), "checkpoints", "JuliaRL_MAC_CartPole_$(t)")
    end

    lg = TBLogger(joinpath(save_dir, "tb_log"), min_level = Logging.Info)
    rng = MersenneTwister(seed)
    N_ENV = 16
    UPDATE_FREQ = 20
    env = MultiThreadEnv([
        CartPoleEnv(; T = Float32, rng = MersenneTwister(hash(seed + i))) for i in 1:N_ENV
    ])
    ns, na = length(get_state(env[1])), length(get_actions(env[1]))
    RLBase.reset!(env, is_force = true)

    agent = Agent(
        policy = QBasedPolicy(
            learner = MACLearner(
                approximator = ActorCritic(
                    actor = NeuralNetworkApproximator(
                        model = Chain(
                            Dense(ns, 30, relu; initW = glorot_uniform(rng)),
                            Dense(30, 30, relu; initW = glorot_uniform(rng)),
                            Dense(30, na; initW = glorot_uniform(rng)),
                        ),
                        optimizer = ADAM(1e-2),
                    ),
                    critic = NeuralNetworkApproximator(
                        model = Chain(
                            Dense(ns, 30, relu; initW = glorot_uniform(rng)),
                            Dense(30, 30, relu; initW = glorot_uniform(rng)),
                            Dense(30, na; initW = glorot_uniform(rng)),
                        ),
                        optimizer = ADAM(3e-3),
                    ),
                ) |> cpu,
                γ = 0.99f0,
                bootstrap = true,
            ),
            explorer = BatchExplorer(GumbelSoftmaxExplorer()), #= seed = nothing =#
        ),
        trajectory = CircularCompactSARTSATrajectory(;
            capacity = UPDATE_FREQ,
            state_type = Float32,
            state_size = (ns, N_ENV),
            action_type = Int,
            action_size = (N_ENV,),
            reward_type = Float32,
            reward_size = (N_ENV,),
            terminal_type = Bool,
            terminal_size = (N_ENV,),
        ),
    )

    stop_condition = StopAfterStep(haskey(ENV, "CI") ? 10_000 : 100_000)
    total_reward_per_episode = TotalBatchRewardPerEpisode(N_ENV)
    time_per_step = TimePerStep()
    hook = ComposedHook(
        total_reward_per_episode,
        time_per_step,
        DoEveryNStep() do t, agent, env
            with_logger(lg) do
                @info(
                    "training",
                    actor_loss = agent.policy.learner.actor_loss,
                    critic_loss = agent.policy.learner.critic_loss,
                )
                for i in 1:length(env)
                    if get_terminal(env[i])
                        @info "training" reward = total_reward_per_episode.rewards[i][end] log_step_increment = 0
                        break
                    end
                end
            end
        end,
        DoEveryNStep(10000) do t, agent, env
            RLCore.save(save_dir, agent)
            BSON.@save joinpath(save_dir, "stats.bson") total_reward_per_episode time_per_step
        end,
    )
    Experiment(
        agent,
        env,
        stop_condition,
        hook,
        Description("# MAC with CartPole", save_dir),
    )
end

function RLCore.Experiment(
    ::Val{:JuliaRL},
    ::Val{:TD3},
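For reference, the MAC experiment defined above should be runnable the same way as the other CartPole entries; a sketch, assuming this package (or the umbrella ReinforcementLearning.jl package) is loaded so that `run` and `Experiment` are in scope:

res = run(Experiment(Val(:JuliaRL), Val(:MAC), Val(:CartPole), nothing))

The positional `nothing` matches the `::Nothing` argument in the method signature; `save_dir` and `seed` can be passed as keyword arguments.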

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ end
     1 / mean(res.hook[2].times)
 end

-for method in (:A2C, :A2CGAE, :PPO)
+for method in (:A2C, :A2CGAE, :PPO, :MAC)
     res = run(Experiment(
         Val(:JuliaRL),
         Val(method),
