Skip to content

Commit c11dd72

Browse files
committed
Fix zygote compat
1 parent f1f850b commit c11dd72

5 files changed

Lines changed: 78 additions & 112 deletions

File tree

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,22 @@ version = "0.2.4"
66
[deps]
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9-
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
109
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1110
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
12-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1311

1412
[weakdeps]
1513
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1614
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
15+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1716

1817
[extensions]
1918
TaylorDiffNNlibExt = ["NNlib"]
2019
TaylorDiffSFExt = ["SpecialFunctions"]
20+
TaylorDiffZygoteExt = ["Zygote"]
2121

2222
[compat]
2323
ChainRules = "1"
2424
ChainRulesCore = "1"
25-
ChainRulesOverloadGeneration = "0.1"
2625
NNlib = "0.9"
2726
SpecialFunctions = "2"
2827
SymbolicUtils = "2, 3"

ext/TaylorDiffZygoteExt.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module TaylorDiffZygoteExt
2+
3+
using TaylorDiff
4+
import Zygote: @adjoint, Numeric, _dual_safearg, ZygoteRuleConfig
5+
using ChainRulesCore: @opt_out
6+
7+
# Zygote can't infer this constructor function
8+
# defining rrule for this doesn't seem to work for Zygote
9+
# so need to use @adjoint
10+
@adjoint TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} = TaylorScalar{T, N}(t),
11+
x̄ -> (TaylorScalar{T, M}(x̄),)
12+
13+
# Zygote will try to use ForwardDiff to compute broadcast functions
14+
# However, TaylorScalar is not dual safe, so we opt out of this
15+
_dual_safearg(::Numeric{<:TaylorScalar}) = false
16+
17+
# Zygote has a rule for literal power, need to opt out of this
18+
@opt_out rrule(
19+
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar, ::Val{p}
20+
) where {p}
21+
22+
end

src/chainrules.jl

Lines changed: 28 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
1+
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
22
using Base.Broadcast: broadcasted
3-
import Zygote: @adjoint, accum_sum, unbroadcast, Numeric, ∇getindex, _project
43

54
function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
65
mapreduce(*, +, value(a), value(b))
@@ -31,7 +30,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
3130
end
3231

3332
function rrule(::typeof(*), A::AbstractMatrix{S},
34-
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
33+
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
3534
project_A = ProjectTo(A)
3635
function gemv_pullback(x̄)
3736
= reinterpret(reshape, T, x̄)
@@ -42,7 +41,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
4241
end
4342

4443
function rrule(::typeof(*), A::AbstractMatrix{S},
45-
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
44+
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
4645
project_A = ProjectTo(A)
4746
project_B = ProjectTo(B)
4847
function gemm_pullback(x̄)
@@ -54,88 +53,36 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
5453
return A * B, gemm_pullback
5554
end
5655

57-
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
58-
project_v = ProjectTo(v)
59-
t + v, x̄ -> (x̄, project_v(x̄))
60-
end
61-
62-
@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T}
63-
project_v = ProjectTo(v)
64-
v + t, x̄ -> (project_v(x̄), x̄)
65-
end
66-
67-
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx)
68-
69-
# Not-a-number patches
70-
71-
ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
72-
(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
73-
function ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar}
74-
ProjectTo{AbstractArray}(; element = ProjectTo(zero(T)), axes = axes(x))
75-
end
76-
(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
77-
accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims)
78-
79-
TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}}
80-
81-
@adjoint function broadcasted(::typeof(+), xs::TaylorNumeric...)
82-
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
83-
end
56+
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
8457

85-
struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
86-
val::T
87-
ind::I
88-
axes::A
89-
function TaylorOneElement(val::T, ind::I,
90-
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
91-
A <: NTuple{N, AbstractUnitRange}} where {N}
92-
new{T, N, I, A}(val, ind, axes)
93-
end
94-
end
58+
# opt-outs
9559

96-
Base.size(A::TaylorOneElement) = map(length, A.axes)
97-
Base.axes(A::TaylorOneElement) = A.axes
98-
function Base.getindex(A::TaylorOneElement{T, N}, i::Vararg{Int, N}) where {T, N}
99-
ifelse(i == A.ind, A.val, zero(T))
100-
end
101-
102-
function ∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N}
103-
dy -> begin
104-
dx = TaylorOneElement(dy, inds, axes(x))
105-
return (_project(x, dx), map(_ -> nothing, inds)...)
106-
end
107-
end
60+
# Unary functions
10861

109-
@generated function mul_adjoint(Ω::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N}
110-
return quote
111-
vΩ, vx = value(Ω), value(x)
112-
@inbounds TaylorScalar($([:(+($([:($(binomial(j - 1, i - 1)) * vΩ[$j] *
113-
vx[$(j + 1 - i)]) for j in i:N]...)))
114-
for i in 1:N]...))
115-
end
62+
for f in (
63+
exp, exp10, exp2, expm1,
64+
sin, cos, tan, sec, csc, cot,
65+
sinh, cosh, tanh, sech, csch, coth,
66+
log, log10, log2, log1p,
67+
asin, acos, atan, asec, acsc, acot,
68+
asinh, acosh, atanh, asech, acsch, acoth,
69+
sqrt, cbrt, inv
70+
)
71+
@eval @opt_out frule(::typeof($f), x::TaylorScalar)
72+
@eval @opt_out rrule(::typeof($f), x::TaylorScalar)
11673
end
11774

118-
rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
119-
120-
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
121-
function times_pullback2(Ω̇)
122-
ΔΩ = unthunk(Ω̇)
123-
return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)),
124-
ProjectTo(y)(mul_adjoint(ΔΩ, x)))
125-
end
126-
return x * y, times_pullback2
127-
end
75+
# Binary functions
12876

129-
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
130-
more::TaylorScalar...)
131-
Ω2, back2 = rrule(*, x, y)
132-
Ω3, back3 = rrule(*, Ω2, z)
133-
Ω4, back4 = rrule(*, Ω3, more...)
134-
function times_pullback4(Ω̇)
135-
Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...)
136-
Δ3 = back3(Δ4[2]) # (0, ΔΩ2, Δz)
137-
Δ2 = back2(Δ3[2]) # (0, Δx, Δy)
138-
return (Δ2..., Δ3[3], Δ4[3:end]...)
77+
for f in (
78+
*, /, ^
79+
)
80+
for (tlhs, trhs) in (
81+
(TaylorScalar, TaylorScalar),
82+
(TaylorScalar, Number),
83+
(Number, TaylorScalar)
84+
)
85+
@eval @opt_out frule(::typeof($f), x::$tlhs, y::$trhs)
86+
@eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
13987
end
140-
return Ω4, times_pullback4
14188
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ using Test
33

44
include("primitive.jl")
55
include("derivative.jl")
6-
# include("zygote.jl")
6+
include("zygote.jl")
77
# include("lux.jl")

test/zygote.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,36 @@
1-
using Zygote, LinearAlgebra
1+
using LinearAlgebra
2+
import Zygote # use qualified import to avoid conflict with TaylorDiff
23

3-
@testset "Zygote for mixed derivative" begin
4+
@testset "Zygote-over-TaylorDiff on same variable" begin
5+
# Scalar functions
46
some_number = 0.7
57
some_numbers = [0.3 0.4 0.1;]
6-
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
7-
@test gradient(x -> derivative(f, x, 2), some_number)[1] ≈
8+
for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
9+
@test Zygote.gradient(derivative, f, some_number, 2)[2] ≈
810
derivative(f, some_number, 3)
9-
derivative_result = vec(derivative.(f, some_numbers, 3))
10-
@test Zygote.jacobian(x -> derivative.(f, x, 2), some_numbers)[1] ≈
11-
diagm(derivative_result)
11+
@test Zygote.jacobian(broadcast, derivative, f, some_numbers, 2)[3] ≈
12+
diagm(vec(derivative.(f, some_numbers, 3)))
1213
end
1314

14-
some_matrix = [0.7 0.1; 0.4 0.2]
15-
f = x -> sum(tanh.(x), dims = 1)
16-
dfdx1(m, x) = derivative(m, x, [1.0, 0.0], 1)
17-
dfdx2(m, x) = derivative(m, x, [0.0, 1.0], 1)
18-
res(m, x) = dfdx1(m, x) .+ 2 * dfdx2(m, x)
19-
grads = Zygote.gradient(some_matrix) do x
20-
sum(res(f, x))
21-
end
22-
expected_grads = x -> -2 * sinh(x) / cosh(x)^3
23-
@test grads[1] ≈ [1 0; 0 2] * expected_grads.(some_matrix)
24-
25-
@test gradient(x -> derivative(x -> x * x, x, 1),
26-
5.0)[1] ≈ 2.0
27-
15+
# Vector functions
2816
g(x) = x[1] * x[1] + x[2] * x[2]
29-
@test gradient(x -> derivative(g, x, [1.0, 0.0], 1),
30-
[1.0, 2.0])[1] ≈ [2.0, 0.0]
17+
@test Zygote.gradient(derivative, g, [1.0, 2.0], [1.0, 0.0], 1)[2] ≈ [2.0, 0.0]
18+
19+
# Matrix functions
20+
some_matrix = [0.7 0.1; 0.4 0.2]
21+
f(x) = sum(exp.(x), dims = 1)
22+
dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
23+
dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
24+
res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
25+
grads = Zygote.gradient(res, some_matrix)
26+
@test grads[1] ≈ [1 0; 0 2] * exp.(some_matrix)
3127
end
3228

33-
@testset "Zygote for parameter optimization" begin
34-
gradient(p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
35-
gradient(p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
29+
@testset "Zygote-over-TaylorDiff on different variable" begin
30+
Zygote.gradient(
31+
p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
32+
Zygote.gradient(
33+
p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
3634
linear_model(x, p, b) = exp.(b + p * x + b)[1]
3735
some_x, some_v, some_p, some_b = [0.58, 0.36], [0.23, 0.11], [0.49 0.96], [0.88]
3836
loss_taylor(p) = derivative(x -> linear_model(x, p, some_b), some_x, some_v, 1)
@@ -41,5 +39,5 @@ end
4139
let f = x -> linear_model(x, p, some_b)
4240
(f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
4341
end
44-
@test gradient(loss_taylor, some_p)[1] ≈ gradient(loss_finite, some_p)[1]
42+
@test Zygote.gradient(loss_taylor, some_p)[1] ≈ Zygote.gradient(loss_finite, some_p)[1]
4543
end

0 commit comments

Comments
 (0)