Skip to content

Commit 61326ef

Browse files
committed
Add support for nested derivatives
1 parent 87a3dba commit 61326ef

3 files changed

Lines changed: 11 additions & 11 deletions

File tree

src/chainrules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
66
mapreduce(*, +, value(a), value(b))
77
end
88

9-
function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Number}
9+
function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T}
1010
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄)
1111
return TaylorScalar(v), taylor_scalar_pullback
1212
end
@@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
2222
end
2323

2424
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
25-
i::Integer) where {N, T <: Number}
25+
i::Integer) where {N, T}
2626
function extract_derivative_pullback(d̄)
2727
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
2828
NoTangent()
@@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
3131
end
3232

3333
function rrule(::typeof(*), A::AbstractMatrix{S},
34-
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Number, T}
34+
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
3535
project_A = ProjectTo(A)
3636
function gemv_pullback(x̄)
3737
= reinterpret(reshape, T, x̄)
@@ -41,17 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
4141
return A * t, gemv_pullback
4242
end
4343

44-
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
44+
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
4545
project_v = ProjectTo(v)
4646
t + v, x̄ -> (x̄, project_v(x̄))
4747
end
4848

49-
@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
49+
@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T}
5050
project_v = ProjectTo(v)
5151
v + t, x̄ -> (project_v(x̄), x̄)
5252
end
5353

54-
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
54+
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx)
5555

5656
# Not-a-number patches
5757

src/derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,26 @@ end
4040
# Added to help Zygote infer types
4141
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1))
4242

43-
@inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N}
43+
@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N}
4444
t = TaylorScalar{T, N}(x, one(x))
4545
return extract_derivative(f(t), N)
4646
end
4747

4848
@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
49-
vN::Val{N}) where {T <: Number, S <: Number, N}
49+
vN::Val{N}) where {T <: TN, S <: TN, N}
5050
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
5151
# equivalent to map(TaylorScalar{T, N}, x, l)
5252
return extract_derivative(f(t), N)
5353
end
5454

5555
# shorthand notations for matrices
5656

57-
@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N}
57+
@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
5858
size(x)[1] != 1 && @warn "x is not a row vector."
5959
mapcols(u -> derivative(f, u[1], vN), x)
6060
end
6161

6262
@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
63-
vN::Val{N}) where {T <: Number, S <: Number, N}
63+
vN::Val{N}) where {T <: TN, S <: TN, N}
6464
mapcols(u -> derivative(f, u, l, vN), x)
6565
end

src/scalar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct TaylorScalar{T, N}
1717
value::NTuple{N, T}
1818
end
1919

20-
TaylorOrNumber = Union{TaylorScalar, Number}
20+
TN = Union{TaylorScalar, Number}
2121

2222
@inline TaylorScalar(xs::Vararg{T, N}) where {T, N} = TaylorScalar(xs)
2323

0 commit comments

Comments
 (0)