export TD3Policy, TD3Critic

using Random
using Flux
using Flux.Losses: mse # on older Flux versions `mse` is exported by Flux itself
using Statistics: mean
using Zygote: ignore

# NeuralNetworkApproximator, device, send_to_device, send_to_host, select_last_dim, etc.
# are assumed to be in scope from the surrounding ReinforcementLearning.jl module.

# Twin critic: two independent Q-networks scored on the same (state, action) input;
# the TD target below uses the minimum of the two estimates (clipped double-Q learning).
struct TD3Critic
    critic_1::Flux.Chain
    critic_2::Flux.Chain
end
Flux.@functor TD3Critic
(c::TD3Critic)(s, a) = (inp = vcat(s, a); (c.critic_1(inp), c.critic_2(inp)))
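
# A minimal usage sketch (the sizes `ns`, `na`, and `batch` below are hypothetical, not
# part of this file); each critic maps stacked (state; action) columns to Q-values:
#
#     ns, na, batch = 3, 1, 32
#     critic = TD3Critic(
#         Chain(Dense(ns + na, 64, relu), Dense(64, 1)),
#         Chain(Dense(ns + na, 64, relu), Dense(64, 1)),
#     )
#     q1, q2 = critic(rand(Float32, ns, batch), rand(Float32, na, batch)) # two 1 × batch matrices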

mutable struct TD3Policy{
    BA<:NeuralNetworkApproximator,
    BC<:NeuralNetworkApproximator,
    TA<:NeuralNetworkApproximator,
    TC<:NeuralNetworkApproximator,
    P,
    R<:AbstractRNG,
} <: AbstractPolicy

    behavior_actor::BA
    behavior_critic::BC
    target_actor::TA
    target_critic::TC
    γ::Float32
    ρ::Float32
    batch_size::Int
    start_steps::Int
    start_policy::P
    update_after::Int
    update_every::Int
    policy_freq::Int
    target_act_limit::Float64
    target_act_noise::Float64
    act_limit::Float64
    act_noise::Float64
    step::Int
    rng::R
    replay_counter::Int
    # for logging
    actor_loss::Float32
    critic_loss::Float32
end

"""
    TD3Policy(;kwargs...)

# Keyword arguments

- `behavior_actor`,
- `behavior_critic`,
- `target_actor`,
- `target_critic`,
- `start_policy`,
- `γ = 0.99f0`,
- `ρ = 0.995f0`,
- `batch_size = 64`,
- `start_steps = 10000`,
- `update_after = 1000`,
- `update_every = 50`,
- `policy_freq = 2`, # the frequency at which the actor takes a gradient step and the target networks are updated
- `target_act_limit = 1.0`, # the noise added to the target actor's action is clipped to ±target_act_limit
- `target_act_noise = 0.1`, # standard deviation of the noise added to the target actor's action
- `act_limit = 1.0`, # the output action (after exploration noise) is clamped to ±act_limit
- `act_noise = 0.1`, # standard deviation of the exploration noise added to the output action
- `step = 0`,
- `rng = Random.GLOBAL_RNG`,
"""
function TD3Policy(;
    behavior_actor,
    behavior_critic,
    target_actor,
    target_critic,
    start_policy,
    γ = 0.99f0,
    ρ = 0.995f0,
    batch_size = 64,
    start_steps = 10000,
    update_after = 1000,
    update_every = 50,
    policy_freq = 2,
    target_act_limit = 1.0,
    target_act_noise = 0.1,
    act_limit = 1.0,
    act_noise = 0.1,
    step = 0,
    rng = Random.GLOBAL_RNG,
)
    copyto!(behavior_actor, target_actor) # force sync
    copyto!(behavior_critic, target_critic) # force sync
    TD3Policy(
        behavior_actor,
        behavior_critic,
        target_actor,
        target_critic,
        γ,
        ρ,
        batch_size,
        start_steps,
        start_policy,
        update_after,
        update_every,
        policy_freq,
        target_act_limit,
        target_act_noise,
        act_limit,
        act_noise,
        step,
        rng,
        1, # replay_counter, counts updates since the last delayed policy update
        0.0f0, # actor_loss
        0.0f0, # critic_loss
    )
end
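
# A hypothetical construction sketch, assuming the `NeuralNetworkApproximator` and
# `RandomPolicy`/`ContinuousSpace` constructors of ReinforcementLearning.jl at the time
# of writing (the network sizes and `ns`/`na` below are illustrative only):
#
#     rng = MersenneTwister(123)
#     ns, na = 3, 1
#     create_actor() = Chain(Dense(ns, 30, relu), Dense(30, 30, relu), Dense(30, na, tanh))
#     create_critic() = TD3Critic(
#         Chain(Dense(ns + na, 30, relu), Dense(30, 1)),
#         Chain(Dense(ns + na, 30, relu), Dense(30, 1)),
#     )
#     policy = TD3Policy(
#         behavior_actor = NeuralNetworkApproximator(model = create_actor(), optimizer = ADAM()),
#         behavior_critic = NeuralNetworkApproximator(model = create_critic(), optimizer = ADAM()),
#         target_actor = NeuralNetworkApproximator(model = create_actor()),
#         target_critic = NeuralNetworkApproximator(model = create_critic()),
#         start_policy = RandomPolicy(ContinuousSpace(-1.0, 1.0); rng = rng),
#         rng = rng,
#     )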

# TODO: handle Training/Testing mode
function (p::TD3Policy)(env)
    p.step += 1

    if p.step <= p.start_steps
        p.start_policy(env)
    else
        D = device(p.behavior_actor)
        s = get_state(env)
        s = Flux.unsqueeze(s, ndims(s) + 1)
        action = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host
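        # add Gaussian exploration noise, then clamp to the action bounds
        # (assumes a scalar action; `action[]` extracts it)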
        clamp(action[] + randn(p.rng) * p.act_noise, -p.act_limit, p.act_limit)
    end
end

function RLBase.update!(p::TD3Policy, traj::CircularCompactSARTSATrajectory)
    length(traj[:terminal]) > p.update_after || return
    p.step % p.update_every == 0 || return

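    # sample a random minibatch; presumably the last index is excluded so that every
    # sampled transition has a fully recorded successor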
    inds = rand(p.rng, 1:(length(traj[:terminal])-1), p.batch_size)
    s = select_last_dim(traj[:state], inds)
    a = select_last_dim(traj[:action], inds)
    r = select_last_dim(traj[:reward], inds)
    t = select_last_dim(traj[:terminal], inds)
    s′ = select_last_dim(traj[:next_state], inds)

    actor = p.behavior_actor
    critic = p.behavior_critic

    # !!! several assumptions are made here; revisit when we support more complex environments:
    #     the state is a vector and the action is a scalar
    target_noise = clamp.(
        randn(p.rng, Float32, 1, p.batch_size) .* p.target_act_noise,
        -p.target_act_limit,
        p.target_act_limit,
    )
    # add noise and clip to tanh bounds
    a′ = clamp.(p.target_actor(s′) + target_noise, -1f0, 1f0)

    q_1′, q_2′ = p.target_critic(s′, a′)
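    # TD3 target with clipped double-Q learning:
    #     y = r + γ (1 - t) min(Q′₁(s′, a′), Q′₂(s′, a′))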
    y = r .+ p.γ .* (1 .- t) .* (min.(q_1′, q_2′) |> vec)
    a = Flux.unsqueeze(a, 1) # reshape scalar actions into a 1 × batch_size matrix

    gs1 = gradient(Flux.params(critic)) do
        q1, q2 = critic(s, a)
        loss = mse(q1 |> vec, y) + mse(q2 |> vec, y)
        ignore() do # keep logging out of the gradient
            p.critic_loss = loss
        end
        loss
    end
    update!(critic, gs1)

    if p.replay_counter % p.policy_freq == 0
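        # delayed policy update: improve the actor by maximizing Q₁(s, μ(s)),
        # i.e. by descending the negated mean Q-value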
        gs2 = gradient(Flux.params(actor)) do
            actions = actor(s)
            loss = -mean(critic.model.critic_1(vcat(s, actions)))
            ignore() do
                p.actor_loss = loss
            end
            loss
        end
        update!(actor, gs2)
        # polyak averaging
        for (dest, src) in zip(
            Flux.params([p.target_actor, p.target_critic]),
            Flux.params([actor, critic]),
        )
            dest .= p.ρ .* dest .+ (1 - p.ρ) .* src
        end
        p.replay_counter = 0 # reset to 0 (not 1) so the actor fires again only after `policy_freq` more updates
    end
    p.replay_counter += 1
end
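
# TD3's three mechanisms, all visible in `update!` above:
#   1. clipped double-Q learning: the TD target uses min(Q′₁, Q′₂)
#   2. delayed policy updates: the actor and both targets update once per `policy_freq` critic updates
#   3. target policy smoothing: clipped Gaussian noise is added to the target action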