Skip to content

Commit e6ab873

Browse files
committed
Migrate to value + partials
1 parent e06c8fd commit e6ab873

7 files changed

Lines changed: 115 additions & 104 deletions

File tree

src/chainrules.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,30 @@
11
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
22
using Base.Broadcast: broadcasted
33

4-
function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
5-
mapreduce(*, +, value(a), value(b))
4+
function rrule(::Type{TaylorScalar{T, N}}, v::T, p::NTuple{N, T}) where {N, T}
5+
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
6+
return TaylorScalar{T, N}(v, p), taylor_scalar_pullback
67
end
78

8-
function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T}
9-
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄)
10-
return TaylorScalar(v), taylor_scalar_pullback
9+
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
10+
value_pullback(v̄::T) = NoTangent(), TaylorScalar{T, N}(v̄)
11+
return value(t), value_pullback
1112
end
1213

13-
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
14-
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
14+
function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
15+
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄)
1516
# for structural tangent, convert to tuple
1617
function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
17-
NoTangent(), TaylorScalar{T, N}(backing(v̄))
18+
NoTangent(), TaylorScalar{T, N}(zero(T), backing(v̄))
1819
end
19-
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄)))
20-
return value(t), value_pullback
20+
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(zero(T), map(x -> convert(T, x), Tuple(v̄)))
21+
return partials(t), value_pullback
2122
end
2223

2324
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2425
i::Integer) where {N, T}
2526
function extract_derivative_pullback(d̄)
26-
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
27+
NoTangent(), TaylorScalar{T, N}(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
2728
NoTangent()
2829
end
2930
return extract_derivative(t, i), extract_derivative_pullback
@@ -53,7 +54,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
5354
return A * B, gemm_pullback
5455
end
5556

56-
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
57+
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = value(dx)
5758

5859
# opt-outs
5960

src/derivative.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ export derivative, derivative!, derivatives, make_seed
55
derivative(f, x, l, ::Val{N})
66
derivative(f!, y, x, l, ::Val{N})
77
8-
Computes `order`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
8+
Computes `N`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
99
"""
1010
function derivative end
1111

@@ -21,7 +21,7 @@ function derivative! end
2121
derivatives(f, x, l, ::Val{N})
2222
derivatives(f!, y, x, l, ::Val{N})
2323
24-
Computes all derivatives of `f` at `x` up to order `N - 1`.
24+
Computes all derivatives of `f` at `x` up to order `N`.
2525
"""
2626
function derivatives end
2727

@@ -32,12 +32,12 @@ function derivatives end
3232
# Convenience wrappers for converting orders to value types
3333
# and forward work to core APIs
3434

35-
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order + 1}())
36-
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order + 1}())
35+
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order}())
36+
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order}())
3737
@inline derivative!(result, f, x, l, order::Int64) = derivative!(
38-
result, f, x, l, Val{order + 1}())
38+
result, f, x, l, Val{order}())
3939
@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!(
40-
result, f!, y, x, l, Val{order + 1}())
40+
result, f!, y, x, l, Val{order}())
4141

4242
# Core APIs
4343

@@ -69,6 +69,6 @@ end
6969
@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
7070
buffer = similar(y, TaylorScalar{T, N})
7171
f!(buffer, make_seed(x, l, vN))
72-
map!(primal, y, buffer)
72+
map!(value, y, buffer)
7373
return buffer
7474
end

src/primitive.jl

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ import Base: hypot, max, min
88
import Base: tail
99

1010
# Unary
11-
@inline +(a::Number, b::TaylorScalar) = TaylorScalar((a + value(b)[1]), tail(value(b))...)
12-
@inline -(a::Number, b::TaylorScalar) = TaylorScalar((a - value(b)[1]), .-tail(value(b))...)
13-
@inline *(a::Number, b::TaylorScalar) = TaylorScalar((a .* value(b))...)
11+
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
12+
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b)))
13+
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
1414
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)
1515

16-
@inline +(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] + b), tail(value(a))...)
17-
@inline -(a::TaylorScalar, b::Number) = TaylorScalar((value(a)[1] - b), tail(value(a))...)
18-
@inline *(a::TaylorScalar, b::Number) = TaylorScalar((value(a) .* b)...)
19-
@inline /(a::TaylorScalar, b::Number) = TaylorScalar((value(a) ./ b)...)
16+
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
17+
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
18+
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
19+
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
2020

2121
## Delegated
2222

@@ -27,10 +27,10 @@ import Base: tail
2727
for func in (:exp, :expm1, :exp2, :exp10)
2828
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
2929
ex = quote
30-
v = value(t)
30+
v = flatten(t)
3131
v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1])))
3232
end
33-
for i in 2:N
33+
for i in 2:(N + 1)
3434
ex = quote
3535
$ex
3636
$(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) *
@@ -46,19 +46,19 @@ for func in (:exp, :expm1, :exp2, :exp10)
4646
if $(QuoteNode(func)) == :expm1
4747
ex = :($ex; v1 = expm1(v[1]))
4848
end
49-
ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...))))
49+
ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...))))
5050
return :(@inbounds $ex)
5151
end
5252
end
5353

5454
for func in (:sin, :cos)
5555
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
5656
ex = quote
57-
v = value(t)
57+
v = flatten(t)
5858
s1 = sin(v[1])
5959
c1 = cos(v[1])
6060
end
61-
for i in 2:N
61+
for i in 2:(N + 1)
6262
ex = :($ex;
6363
$(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) *
6464
$(Symbol('c', j)) *
@@ -69,9 +69,9 @@ for func in (:sin, :cos)
6969
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
7070
end
7171
if $(QuoteNode(func)) == :sin
72-
ex = :($ex; TaylorScalar($([Symbol('s', i) for i in 1:N]...)))
72+
ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...))))
7373
else
74-
ex = :($ex; TaylorScalar($([Symbol('c', i) for i in 1:N]...)))
74+
ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...))))
7575
end
7676
return quote
7777
@inbounds $ex
@@ -94,24 +94,27 @@ for op in [:>, :<, :(==), :(>=), :(<=)]
9494
@eval @inline $op(a::TaylorScalar, b::TaylorScalar) = $op(value(a)[1], value(b)[1])
9595
end
9696

97-
@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(+, value(a), value(b)))
98-
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(map(-, value(a), value(b)))
97+
@inline +(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
98+
value(a) + value(b), map(+, partials(a), partials(b)))
99+
@inline -(a::TaylorScalar, b::TaylorScalar) = TaylorScalar(
100+
value(a) - value(b), map(-, partials(a), partials(b)))
99101

100102
@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
101103
return quote
102-
va, vb = value(a), value(b)
103-
@inbounds TaylorScalar($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
104-
vb[$(i + 1 - j)]) for j in 1:i]...)))
105-
for i in 1:N]...))
104+
va, vb = flatten(a), flatten(b)
105+
r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
106+
vb[$(i + 1 - j)]) for j in 1:i]...)))
107+
for i in 1:(N + 1)]...))
108+
@inbounds TaylorScalar(r[1], r[2:end])
106109
end
107110
end
108111

109112
@generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
110113
ex = quote
111-
va, vb = value(a), value(b)
114+
va, vb = flatten(a), flatten(b)
112115
v1 = va[1] / vb[1]
113116
end
114-
for i in 2:N
117+
for i in 2:(N + 1)
115118
ex = quote
116119
$ex
117120
$(Symbol('v', i)) = (va[$i] -
@@ -120,24 +123,28 @@ end
120123
for j in 1:(i - 1)]...))) / vb[1]
121124
end
122125
end
123-
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
126+
ex = quote
127+
$ex
128+
v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...))
129+
TaylorScalar(v)
130+
end
124131
return :(@inbounds $ex)
125132
end
126133

127134
for R in (Integer, Real)
128135
@eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N}
129136
ex = quote
130-
v = value(t)
137+
v = flatten(t)
131138
w11 = 1
132139
u1 = ^(v[1], n)
133140
end
134-
for k in 1:N
141+
for k in 1:(N + 1)
135142
ex = quote
136143
$ex
137144
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
138145
end
139146
end
140-
for i in 2:N
147+
for i in 2:(N + 1)
141148
subex = quote
142149
$(Symbol('w', i, 1)) = 0
143150
end
@@ -158,7 +165,11 @@ for R in (Integer, Real)
158165
for k in 2:i]...))
159166
end
160167
end
161-
ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
168+
ex = quote
169+
$ex
170+
v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...))
171+
TaylorScalar(v)
172+
end
162173
return :(@inbounds $ex)
163174
end
164175
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
@@ -172,11 +183,11 @@ end
172183
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
173184
return quote
174185
$(Expr(:meta, :inline))
175-
vdf, vt = value(df), value(t)
176-
@inbounds TaylorScalar(f,
177-
$([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
178-
vt[$(i + 2 - j)]) for j in 1:i]...)))
179-
for i in 1:M]...))
186+
vdf, vt = flatten(df), flatten(t)
187+
partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
188+
vt[$(i + 2 - j)]) for j in 1:i]...)))
189+
for i in 1:(M + 1)]...))
190+
@inbounds TaylorScalar(f, partials)
180191
end
181192
end
182193

@@ -185,10 +196,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
185196
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
186197
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
187198
ex = quote
188-
vdf, vt = value(df), value(t)
199+
vdf, vt = flatten(df), flatten(t)
189200
v1 = vt[2] / vdf[1]
190201
end
191-
for i in 2:M
202+
for i in 2:(M + 1)
192203
ex = quote
193204
$ex
194205
$(Symbol('v', i)) = (vt[$(i + 1)] -
@@ -197,6 +208,10 @@ raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
197208
for j in 1:(i - 1)]...))) / vdf[1]
198209
end
199210
end
200-
ex = :($ex; TaylorScalar(f, $([Symbol('v', i) for i in 1:M]...)))
211+
ex = quote
212+
$ex
213+
v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...))
214+
TaylorScalar(f, v)
215+
end
201216
return :(@inbounds $ex)
202217
end

0 commit comments

Comments
 (0)