export SACPolicy, SACPolicyNetwork

using Random
using Flux
using Flux.Losses: mse
using Statistics: mean
using Distributions: Normal, logpdf

# The SAC actor: a shared trunk (`pre`) feeding two heads that output the
# mean and log standard deviation of a Gaussian over pre-tanh actions.
struct SACPolicyNetwork
    pre::Chain
    mean::Chain
    log_std::Chain
end
Flux.@functor SACPolicyNetwork
(m::SACPolicyNetwork)(state) = (x = m.pre(state); (m.mean(x), m.log_std(x)))

mutable struct SACPolicy{
    BA<:NeuralNetworkApproximator,
    BC1<:NeuralNetworkApproximator,
    BC2<:NeuralNetworkApproximator,
    P,
    R<:AbstractRNG,
} <: AbstractPolicy

    policy::BA
    qnetwork1::BC1
    qnetwork2::BC2
    target_qnetwork1::BC1
    target_qnetwork2::BC2
    γ::Float32
    ρ::Float32
    α::Float32
    batch_size::Int
    start_steps::Int
    start_policy::P
    update_after::Int
    update_every::Int
    step::Int
    rng::R
end

| 42 | +""" |
| 43 | + SACPolicy(;kwargs...) |
| 44 | +
|
| 45 | +# Keyword arguments |
| 46 | +
|
| 47 | +- `policy`, |
| 48 | +- `qnetwork1`, |
| 49 | +- `qnetwork2`, |
| 50 | +- `target_qnetwork1`, |
| 51 | +- `target_qnetwork2`, |
| 52 | +- `start_policy`, |
| 53 | +- `γ = 0.99f0`, |
| 54 | +- `ρ = 0.995f0`, |
| 55 | +- `α = 0.2f0`, |
| 56 | +- `batch_size = 32`, |
| 57 | +- `start_steps = 10000`, |
| 58 | +- `update_after = 1000`, |
| 59 | +- `update_every = 50`, |
| 60 | +- `step = 0`, |
| 61 | +- `rng = Random.GLOBAL_RNG`, |
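
# Example

A minimal construction sketch. The dimensions (`ns`, `na`), layer sizes, and
the `RandomPolicy`/`ContinuousSpace` exploration policy are illustrative
assumptions, not requirements of this file:

```julia
ns, na = 3, 1  # state / action dimensions of some hypothetical environment
create_q() = NeuralNetworkApproximator(
    model = Chain(Dense(ns + na, 64, relu), Dense(64, 1)),
    optimizer = ADAM(),
)
sac = SACPolicy(
    policy = NeuralNetworkApproximator(
        model = SACPolicyNetwork(
            Chain(Dense(ns, 64, relu)),  # shared trunk
            Chain(Dense(64, na)),        # mean head
            Chain(Dense(64, na)),        # log-std head
        ),
        optimizer = ADAM(),
    ),
    qnetwork1 = create_q(),
    qnetwork2 = create_q(),
    target_qnetwork1 = create_q(),
    target_qnetwork2 = create_q(),
    start_policy = RandomPolicy(ContinuousSpace(-1.0, 1.0)),
)
```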
| 62 | +""" |
function SACPolicy(;
    policy,
    qnetwork1,
    qnetwork2,
    target_qnetwork1,
    target_qnetwork2,
    start_policy,
    γ = 0.99f0,
    ρ = 0.995f0,
    α = 0.2f0,
    batch_size = 32,
    start_steps = 10000,
    update_after = 1000,
    update_every = 50,
    step = 0,
    rng = Random.GLOBAL_RNG,
)
    copyto!(target_qnetwork1, qnetwork1) # force sync the target critics to the online critics
    copyto!(target_qnetwork2, qnetwork2) # force sync
    SACPolicy(
        policy,
        qnetwork1,
        qnetwork2,
        target_qnetwork1,
        target_qnetwork2,
        γ,
        ρ,
        α,
        batch_size,
        start_steps,
        start_policy,
        update_after,
        update_every,
        step,
        rng,
    )
end

# TODO: handle Training/Testing mode
function (p::SACPolicy)(env)
    p.step += 1

    if p.step <= p.start_steps
        p.start_policy(env)
    else
        D = device(p.policy)  # TODO: move `s` to `D` once GPU support is wired up
        s = get_state(env)
        s = Flux.unsqueeze(s, ndims(s) + 1)
        # trainmode: sample from the squashed Gaussian
        action = evaluate(p, s)[1][] # returns action as scalar

        # testmode:
        # if testing, don't sample an action, but act deterministically by
        # taking the (squashed) "mean" action
        # action = tanh(p.policy(s)[1][])
    end
end

| 121 | +""" |
| 122 | +This function is compatible with a multidimensional action space. |
| 123 | +""" |
| 124 | +function evaluate(p::SACPolicy, state) |
| 125 | + μ, log_σ = p.policy(state) |
| 126 | + π_dist = Normal.(μ, exp.(log_σ)) |
| 127 | + z = rand.(p.rng, π_dist) |
| 128 | + logp_π = sum(logpdf.(π_dist, z), dims = 1) |
| 129 | + logp_π -= sum((2f0 .* (log(2f0) .- z - softplus.(-2f0 * z))), dims = 1) |
| 130 | + return tanh.(z), logp_π |
| 131 | +end |
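
# Note on the correction term in `evaluate`: for a = tanh(z) with z Gaussian,
# the change of variables gives
#   log π(a) = log ρ(z) - Σᵢ log(1 - tanh(zᵢ)²)
# and log(1 - tanh(z)²) = 2(log 2 - z - softplus(-2z)) is its numerically
# stable rewriting (the same form used in OpenAI Spinning Up's SAC).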

function RLBase.update!(p::SACPolicy, traj::CircularCompactSARTSATrajectory)
    length(traj[:terminal]) > p.update_after || return
    p.step % p.update_every == 0 || return

    inds = rand(p.rng, 1:(length(traj[:terminal])-1), p.batch_size)
    s = select_last_dim(traj[:state], inds)
    a = select_last_dim(traj[:action], inds)
    r = select_last_dim(traj[:reward], inds)
    t = select_last_dim(traj[:terminal], inds)
    s′ = select_last_dim(traj[:next_state], inds)

    γ, ρ, α = p.γ, p.ρ, p.α

    # !!! we have several assumptions here, need revisit when we have more complex environments
    # state is vector
    # action is scalar
    a′, log_π = evaluate(p, s′)
    q′_input = vcat(s′, a′)
    q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input))

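    # entropy-regularized TD target with clipped double-Q:
    #   y = r + γ (1 - t) (min(Q̄₁(s′, a′), Q̄₂(s′, a′)) - α log π(a′|s′))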
    y = r .+ γ .* (1 .- t) .* vec(q′ .- α .* log_π)

    # Train Q Networks
    a = Flux.unsqueeze(a, 1)
    q_input = vcat(s, a)

    q_grad_1 = gradient(Flux.params(p.qnetwork1)) do
        q1 = p.qnetwork1(q_input) |> vec
        mse(q1, y)
    end
    update!(p.qnetwork1, q_grad_1)
    q_grad_2 = gradient(Flux.params(p.qnetwork2)) do
        q2 = p.qnetwork2(q_input) |> vec
        mse(q2, y)
    end
    update!(p.qnetwork2, q_grad_2)

    # Train Policy
    p_grad = gradient(Flux.params(p.policy)) do
        a, log_π = evaluate(p, s)
        q_input = vcat(s, a)
        q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input))
        mean(α .* log_π .- q)
    end
    update!(p.policy, p_grad)

    # polyak averaging of the target networks
    for (dest, src) in zip(
        Flux.params([p.target_qnetwork1, p.target_qnetwork2]),
        Flux.params([p.qnetwork1, p.qnetwork2]),
    )
        dest .= ρ .* dest .+ (1 - ρ) .* src
    end
end
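
# A hedged usage sketch (`env`, `ns`, the trajectory parameters, and the stop
# condition are illustrative assumptions from the wider
# ReinforcementLearning.jl ecosystem, not part of this file):
#
#     agent = Agent(
#         policy = sac,
#         trajectory = CircularCompactSARTSATrajectory(
#             capacity = 10_000,
#             state_type = Float32,
#             state_size = (ns,),
#             action_type = Float32,
#         ),
#     )
#     run(agent, env, StopAfterStep(100_000))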