Skip to content

Commit 44c3ab0

Browse files
committed
coeffs instead of derivatives
1 parent 2b0b1ce commit 44c3ab0

5 files changed

Lines changed: 142 additions & 172 deletions

File tree

src/chainrules.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P
3333
return partials(t), value_pullback
3434
end
3535

36-
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
37-
i::Integer) where {N, T}
36+
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P},
37+
q::Val{Q}) where {T, P, Q}
3838
function extract_derivative_pullback(d̄)
39-
NoTangent(), TaylorScalar(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
39+
NoTangent(),
40+
TaylorScalar(zero(T), ntuple(j -> j === Q ?* factorial(Q) : zero(T), Val(P))),
4041
NoTangent()
4142
end
42-
return extract_derivative(t, i), extract_derivative_pullback
43+
return extract_derivative(t, q), extract_derivative_pullback
4344
end
4445

4546
function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)

src/codegen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
for unary_func in (
2-
+, -, deg2rad, rad2deg,
2+
deg2rad, rad2deg,
33
sinh, cosh, tanh,
44
asin, acos, atan, asec, acsc, acot,
55
log, log10, log1p, log2,

src/derivative.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,18 @@ function derivatives end
4343

4444
# Added to help Zygote infer types
4545
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
46-
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(make_seed, x, l, Val{P}())
46+
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(
47+
make_seed, x, l, Val{P}())
4748

4849
# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
4950
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
50-
derivatives(f, x, l, p), P)
51+
derivatives(f, x, l, p), p)
5152
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
52-
derivatives(f!, y, x, l, p), P)
53+
derivatives(f!, y, x, l, p), p)
5354
@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
54-
result, derivatives(f, x, l, p), P)
55+
result, derivatives(f, x, l, p), p)
5556
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
56-
result, derivatives(f!, y, x, l, p), P)
57+
result, derivatives(f!, y, x, l, p), p)
5758

5859
# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1`
5960

src/primitive.jl

Lines changed: 105 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ Taylor = Union{TaylorScalar, TaylorArray}
1212

1313
@inline value(t::Taylor) = t.value
1414
@inline partials(t::Taylor) = t.partials
15-
@inline extract_derivative(t::Taylor, i::Integer) = t.partials[i]
16-
@inline extract_derivative(v::AbstractArray{<:TaylorScalar}, i::Integer) = map(
17-
t -> extract_derivative(t, i), v)
18-
@inline extract_derivative(r, i::Integer) = false
19-
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
20-
i::Integer) where {T <: TaylorScalar}
21-
map!(t -> extract_derivative(t, i), result, v)
22-
end
15+
@inline @generated extract_derivative(t::Taylor, ::Val{P}) where {P} = :(t.partials[P] *
16+
$(factorial(P)))
17+
@inline extract_derivative(a::AbstractArray{<:TaylorScalar}, p) = map(
18+
t -> extract_derivative(t, p), a)
19+
@inline extract_derivative(_, p) = false
20+
@inline extract_derivative!(result, a::AbstractArray{<:TaylorScalar}, p) = map!(
21+
t -> extract_derivative(t, p), result, a)
2322

2423
@inline flatten(t::Taylor) = (value(t), partials(t)...)
2524

@@ -33,74 +32,75 @@ function (::Type{F})(x::TaylorScalar{T, P}) where {T, P, F <: AbstractFloat}
3332
end
3433

3534
# Unary
36-
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
37-
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), map(-, partials(b)))
38-
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
39-
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)
40-
41-
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
42-
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
43-
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
44-
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
4535

4636
## Delegated
4737

38+
@inline +(t::TaylorScalar) = t
39+
@inline -(t::TaylorScalar) = TaylorScalar(-value(t), .-partials(t))
4840
@inline sqrt(t::TaylorScalar) = t^0.5
4941
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
5042
@inline inv(t::TaylorScalar) = one(t) / t
5143

5244
for func in (:exp, :expm1, :exp2, :exp10)
53-
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
45+
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
46+
v = [Symbol("v$i") for i in 0:P]
5447
ex = quote
55-
v = flatten(t)
56-
v1 = $($(QuoteNode(func)) == :expm1 ? :(exp(v[1])) : :($$func(v[1])))
48+
$(Expr(:meta, :inline))
49+
p = value(t)
50+
f = flatten(t)
51+
v0 = $($(QuoteNode(func)) == :expm1 ? :(exp(p)) : :($$func(p)))
5752
end
58-
for i in 2:(N + 1)
59-
ex = quote
60-
$ex
61-
$(Symbol('v', i)) = +($([:($(binomial(i - 2, j - 1)) * $(Symbol('v', j)) *
62-
v[$(i + 1 - j)])
63-
for j in 1:(i - 1)]...))
64-
end
53+
for i in 1:P
54+
push!(ex.args,
55+
:(
56+
$(v[begin + i]) = +($([:($(i - j) * $(v[begin + j]) *
57+
f[begin + $(i - j)])
58+
for j in 0:(i - 1)]...)) / $i
59+
))
6560
if $(QuoteNode(func)) == :exp2
66-
ex = :($ex; $(Symbol('v', i)) *= $(log(2)))
61+
push!(ex.args, :($(v[begin + i]) *= log(2)))
6762
elseif $(QuoteNode(func)) == :exp10
68-
ex = :($ex; $(Symbol('v', i)) *= $(log(10)))
63+
push!(ex.args, :($(v[begin + i]) *= log(10)))
6964
end
7065
end
7166
if $(QuoteNode(func)) == :expm1
72-
ex = :($ex; v1 = expm1(v[1]))
67+
push!(ex.args, :(v0 = expm1(f[1])))
7368
end
74-
ex = :($ex; TaylorScalar(tuple($([Symbol('v', i) for i in 1:(N + 1)]...))))
69+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
7570
return :(@inbounds $ex)
7671
end
7772
end
7873

7974
for func in (:sin, :cos)
80-
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
75+
@eval @generated function $func(t::TaylorScalar{T, P}) where {T, P}
76+
s = [Symbol("s$i") for i in 0:P]
77+
c = [Symbol("c$i") for i in 0:P]
8178
ex = quote
82-
v = flatten(t)
83-
s1 = sin(v[1])
84-
c1 = cos(v[1])
79+
$(Expr(:meta, :inline))
80+
f = flatten(t)
81+
s0 = sin(f[1])
82+
c0 = cos(f[1])
8583
end
86-
for i in 2:(N + 1)
87-
ex = :($ex;
88-
$(Symbol('s', i)) = +($([:($(binomial(i - 2, j - 1)) *
89-
$(Symbol('c', j)) *
90-
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
91-
ex = :($ex;
92-
$(Symbol('c', i)) = +($([:($(-binomial(i - 2, j - 1)) *
93-
$(Symbol('s', j)) *
94-
v[$(i + 1 - j)]) for j in 1:(i - 1)]...)))
84+
for i in 1:P
85+
push!(ex.args,
86+
:($(s[begin + i]) = +($([:(
87+
$(i - j) * $(c[begin + j]) *
88+
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
89+
$i)
90+
)
91+
push!(ex.args,
92+
:($(c[begin + i]) = +($([:(
93+
$(i - j) * $(s[begin + j]) *
94+
f[begin + $(i - j)]) for j in 0:(i - 1)]...)) /
95+
-$i)
96+
)
9597
end
9698
if $(QuoteNode(func)) == :sin
97-
ex = :($ex; TaylorScalar(tuple($([Symbol('s', i) for i in 1:(N + 1)]...))))
99+
push!(ex.args, :(TaylorScalar(tuple($(s...)))))
98100
else
99-
ex = :($ex; TaylorScalar(tuple($([Symbol('c', i) for i in 1:(N + 1)]...))))
100-
end
101-
return quote
102-
@inbounds $ex
101+
push!(ex.args, :(TaylorScalar(tuple($(c...)))))
103102
end
103+
return :(@inbounds $ex)
104104
end
105105
end
106106

@@ -109,6 +109,18 @@ end
109109

110110
# Binary
111111

112+
## Easy case
113+
114+
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))
115+
@inline -(a::Number, b::TaylorScalar) = TaylorScalar(a - value(b), .-partials(b))
116+
@inline *(a::Number, b::TaylorScalar) = TaylorScalar(a * value(b), a .* partials(b))
117+
@inline /(a::Number, b::TaylorScalar) = /(promote(a, b)...)
118+
119+
@inline +(a::TaylorScalar, b::Number) = TaylorScalar(value(a) + b, partials(a))
120+
@inline -(a::TaylorScalar, b::Number) = TaylorScalar(value(a) - b, partials(a))
121+
@inline *(a::TaylorScalar, b::Number) = TaylorScalar(value(a) * b, partials(a) .* b)
122+
@inline /(a::TaylorScalar, b::Number) = TaylorScalar(value(a) / b, partials(a) ./ b)
123+
112124
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
113125

114126
for op in [:>, :<, :(==), :(>=), :(<=)]
@@ -126,75 +138,56 @@ end
126138

127139
@generated function *(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
128140
return quote
141+
$(Expr(:meta, :inline))
129142
va, vb = flatten(a), flatten(b)
130-
r = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * va[$j] *
131-
vb[$(i + 1 - j)]) for j in 1:i]...)))
132-
for i in 1:(N + 1)]...))
133-
@inbounds TaylorScalar(r[1], r[2:end])
143+
v = tuple($([:(
144+
+($([:(va[begin + $j] * vb[begin + $(i - j)]) for j in 0:i]...))
145+
) for i in 0:N]...))
146+
@inbounds TaylorScalar(v)
134147
end
135148
end
136149

137-
@generated function /(a::TaylorScalar{T, N}, b::TaylorScalar{T, N}) where {T, N}
150+
@generated function /(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
151+
v = [Symbol("v$i") for i in 0:P]
138152
ex = quote
153+
$(Expr(:meta, :inline))
139154
va, vb = flatten(a), flatten(b)
140-
v1 = va[1] / vb[1]
155+
v0 = va[1] / vb[1]
156+
b0 = vb[1]
141157
end
142-
for i in 2:(N + 1)
143-
ex = quote
144-
$ex
145-
$(Symbol('v', i)) = (va[$i] -
146-
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
147-
vb[$(i + 1 - j)])
148-
for j in 1:(i - 1)]...))) / vb[1]
149-
end
150-
end
151-
ex = quote
152-
$ex
153-
v = tuple($([Symbol('v', i) for i in 1:(N + 1)]...))
154-
TaylorScalar(v)
158+
for i in 1:P
159+
push!(ex.args,
160+
:(
161+
$(v[begin + i]) = (va[begin + $i] -
162+
+($([:($(v[begin + j]) *
163+
vb[begin + $(i - j)])
164+
for j in 0:(i - 1)]...))) / b0
165+
)
166+
)
155167
end
168+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
156169
return :(@inbounds $ex)
157170
end
158171

159172
for R in (Integer, Real)
160-
@eval @generated function ^(t::TaylorScalar{T, N}, n::S) where {S <: $R, T, N}
173+
@eval @generated function ^(t::TaylorScalar{T, P}, n::S) where {S <: $R, T, P}
174+
v = [Symbol("v$i") for i in 0:P]
161175
ex = quote
162-
v = flatten(t)
163-
w11 = 1
164-
u1 = ^(v[1], n)
165-
end
166-
for k in 1:(N + 1)
167-
ex = quote
168-
$ex
169-
$(Symbol('p', k)) = ^(v[1], n - $(k - 1))
170-
end
176+
$(Expr(:meta, :inline))
177+
f = flatten(t)
178+
f0 = f[1]
179+
v0 = ^(f0, n)
171180
end
172-
for i in 2:(N + 1)
173-
subex = quote
174-
$(Symbol('w', i, 1)) = 0
175-
end
176-
for k in 2:i
177-
subex = quote
178-
$subex
179-
$(Symbol('w', i, k)) = +($([:((n * $(binomial(i - 2, j - 1)) -
180-
$(binomial(i - 2, j - 2))) *
181-
$(Symbol('w', j, k - 1)) *
182-
v[$(i + 1 - j)])
183-
for j in (k - 1):(i - 1)]...))
184-
end
185-
end
186-
ex = quote
187-
$ex
188-
$subex
189-
$(Symbol('u', i)) = +($([:($(Symbol('w', i, k)) * $(Symbol('p', k)))
190-
for k in 2:i]...))
191-
end
192-
end
193-
ex = quote
194-
$ex
195-
v = tuple($([Symbol('u', i) for i in 1:(N + 1)]...))
196-
TaylorScalar(v)
181+
for i in 1:P
182+
push!(ex.args,
183+
:(
184+
$(v[begin + i]) = +($([:(
185+
(n * $(i - j) - $j) * $(v[begin + j]) *
186+
f[begin + $(i - j)]
187+
) for j in 0:(i - 1)]...)) / ($i * f0)
188+
))
197189
end
190+
push!(ex.args, :(TaylorScalar(tuple($(v...)))))
198191
return :(@inbounds $ex)
199192
end
200193
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
@@ -204,39 +197,14 @@ end
204197

205198
^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
206199

207-
@generated function raise(f::T, df::TaylorScalar{T, M},
208-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
209-
return quote
210-
$(Expr(:meta, :inline))
211-
vdf, vt = flatten(df), flatten(t)
212-
partials = tuple($([:(+($([:($(binomial(i - 1, j - 1)) * vdf[$j] *
213-
vt[$(i + 2 - j)]) for j in 1:i]...)))
214-
for i in 1:(M + 1)]...))
215-
@inbounds TaylorScalar(f, partials)
216-
end
200+
@inline function lower(t::TaylorScalar{T, P}) where {T, P}
201+
s = partials(t)
202+
TaylorScalar(ntuple(i -> s[i] * i, Val(P)))
217203
end
218-
219-
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
220-
221-
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
222-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
223-
ex = quote
224-
vdf, vt = flatten(df), flatten(t)
225-
v1 = vt[2] / vdf[1]
226-
end
227-
for i in 2:(M + 1)
228-
ex = quote
229-
$ex
230-
$(Symbol('v', i)) = (vt[$(i + 1)] -
231-
+($([:($(binomial(i - 1, j - 1)) * $(Symbol('v', j)) *
232-
vdf[$(i + 1 - j)])
233-
for j in 1:(i - 1)]...))) / vdf[1]
234-
end
235-
end
236-
ex = quote
237-
$ex
238-
v = tuple($([Symbol('v', i) for i in 1:(M + 1)]...))
239-
TaylorScalar(f, v)
240-
end
241-
return :(@inbounds $ex)
204+
@inline function higher(t::TaylorScalar{T, P}) where {T, P}
205+
s = flatten(t)
206+
ntuple(i -> s[i] / i, Val(P + 1))
242207
end
208+
@inline raise(f, df::TaylorScalar, t) = TaylorScalar(f, higher(lower(t) * df))
209+
@inline raise(f, df::Number, t) = df * t
210+
@inline raiseinv(f, df, t) = TaylorScalar(f, higher(lower(t) / df))

0 commit comments

Comments
 (0)