Skip to content

Commit 2555949

Browse files
committed
Initialize array type
1 parent 1005d83 commit 2555949

10 files changed

Lines changed: 181 additions & 124 deletions

File tree

ext/TaylorDiffZygoteExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ using ChainRulesCore: @opt_out
77
# Zygote can't infer this constructor function
88
# defining rrule for this doesn't seem to work for Zygote
99
# so need to use @adjoint
10-
@adjoint TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M} = TaylorScalar{T, N}(t),
11-
-> (TaylorScalar{T, M}(x̄),)
10+
@adjoint TaylorScalar{P}(t::TaylorScalar{T, Q}) where {T, P, Q} = TaylorScalar{P}(t),
11+
-> (TaylorScalar{Q}(x̄),)
1212

1313
# Zygote will try to use ForwardDiff to compute broadcast functions
1414
# However, TaylorScalar is not dual safe, so we opt out of this

src/TaylorDiff.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11
module TaylorDiff
22

3+
"""
4+
TaylorDiff.can_taylorize(V::Type)
5+
6+
Determines whether the type V is allowed as the scalar type in a
7+
Dual. By default, only `<:Real` types are allowed.
8+
"""
9+
can_taylorize(::Type{<:Real}) = true
10+
can_taylorize(::Type) = false
11+
12+
@noinline function throw_cannot_taylorize(V::Type)
13+
throw(ArgumentError("Cannot create a Taylor polynomial over scalar type $V." *
14+
" If the type behaves as a scalar, define TaylorDiff.can_taylorize(::Type{$V}) = true."))
15+
end
16+
317
include("scalar.jl")
18+
include("array.jl")
419
include("primitive.jl")
520
include("utils.jl")
621
include("codegen.jl")

src/array.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
export TaylorArray
2+
3+
"""
4+
TaylorArray{T, N, A, P}
5+
6+
Representation of Taylor polynomials in array mode.
7+
8+
# Fields
9+
10+
- `value::A`: zeroth order coefficient
11+
- `partials::NTuple{P, T}`: i-th element of this stores the i-th derivative
12+
"""
13+
struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
14+
AbstractArray{TaylorScalar{T, P}, N}
15+
value::A
16+
partials::NTuple{P, A}
17+
function TaylorArray(
18+
value::A, partials::NTuple{P, A}) where {P, A <: AbstractArray}
19+
T = eltype(value)
20+
N = ndims(value)
21+
can_taylorize(T) || throw_cannot_taylorize(T)
22+
new{T, N, A, P}(value, partials)
23+
end
24+
end
25+
26+
function TaylorArray{P}(value::A) where {A <: AbstractArray, P}
27+
TaylorArray(value, ntuple(i -> zeros(eltype(value), size(value)), Val(P)))
28+
end
29+
30+
function TaylorArray{P}(value::A, seed::A) where {A <: AbstractArray, P}
31+
TaylorArray(
32+
value, ntuple(i -> i == 1 ? seed : zeros(eltype(value), size(value)), Val(P)))
33+
end
34+
35+
# Indexing
36+
37+
Base.@propagate_inbounds function Base.getindex(a::TaylorArray, i::Int...)
38+
new_value = value(a)[i...]
39+
new_partials = map(p -> p[i...], partials(a))
40+
return TaylorScalar(new_value, new_partials)
41+
end
42+
43+
Base.@propagate_inbounds function Base.setindex!(
44+
a::TaylorArray, s::TaylorScalar, i::Int...)
45+
value(a)[i...] = value(s)
46+
for j in 1:length(partials(a))
47+
partials(a)[j][i...] = partials(s)[j]
48+
end
49+
return a
50+
end
51+
52+
Base.@propagate_inbounds function Base.setindex!(
53+
a::TaylorArray, s::Real, i::Int...)
54+
value(a)[i...] = s
55+
return a
56+
end
57+
58+
# Invariant
59+
for op in Symbol[:size, :eachindex, :IndexStyle]
60+
@eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x))
61+
end

src/chainrules.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
22
using Base.Broadcast: broadcasted
33

4-
function rrule(::Type{TaylorScalar{T, N}}, v::T, p::NTuple{N, T}) where {N, T}
4+
function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T}
55
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
6-
return TaylorScalar{T, N}(v, p), taylor_scalar_pullback
6+
return TaylorScalar(v, p), taylor_scalar_pullback
77
end
88

99
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
@@ -15,16 +15,18 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
1515
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄)
1616
# for structural tangent, convert to tuple
1717
function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
18-
NoTangent(), TaylorScalar{T, N}(zero(T), backing(v̄))
18+
NoTangent(), TaylorScalar(zero(T), backing(v̄))
19+
end
20+
function value_pullback(v̄)
21+
NoTangent(), TaylorScalar(zero(T), map(x -> convert(T, x), Tuple(v̄)))
1922
end
20-
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(zero(T), map(x -> convert(T, x), Tuple(v̄)))
2123
return partials(t), value_pullback
2224
end
2325

2426
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2527
i::Integer) where {N, T}
2628
function extract_derivative_pullback(d̄)
27-
NoTangent(), TaylorScalar{T, N}(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
29+
NoTangent(), TaylorScalar(zero(T), ntuple(j -> j === i ?: zero(T), Val(N))),
2830
NoTangent()
2931
end
3032
return extract_derivative(t, i), extract_derivative_pullback

src/derivative.jl

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,73 +2,73 @@
22
export derivative, derivative!, derivatives, make_seed
33

44
"""
5-
derivative(f, x, l, ::Val{N})
6-
derivative(f!, y, x, l, ::Val{N})
5+
derivative(f, x, l, ::Val{P})
6+
derivative(f!, y, x, l, ::Val{P})
77
8-
Computes `N`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
8+
Computes `P`-th directional derivative of `f` w.r.t. vector `x` in direction `l`.
99
"""
1010
function derivative end
1111

1212
"""
13-
derivative!(result, f, x, l, ::Val{N})
14-
derivative!(result, f!, y, x, l, ::Val{N})
13+
derivative!(result, f, x, l, ::Val{P})
14+
derivative!(result, f!, y, x, l, ::Val{P})
1515
1616
In-place derivative calculation APIs. `result` is expected to be pre-allocated and have the same shape as `y`.
1717
"""
1818
function derivative! end
1919

2020
"""
21-
derivatives(f, x, l, ::Val{N})
22-
derivatives(f!, y, x, l, ::Val{N})
21+
derivatives(f, x, l, ::Val{P})
22+
derivatives(f!, y, x, l, ::Val{P})
2323
24-
Computes all derivatives of `f` at `x` up to order `N`.
24+
Computes all derivatives of `f` at `x` up to order `P`.
2525
"""
2626
function derivatives end
2727

2828
# Convenience wrapper for adding unit seed to the input
2929

30-
@inline derivative(f, x, order::Int64) = derivative(f, x, one(eltype(x)), order)
30+
@inline derivative(f, x, p::Int64) = derivative(f, x, one(eltype(x)), p)
3131

32-
# Convenience wrappers for converting orders to value types
32+
# Convenience wrappers for converting ps to value types
3333
# and forward work to core APIs
3434

35-
@inline derivative(f, x, l, order::Int64) = derivative(f, x, l, Val{order}())
36-
@inline derivative(f!, y, x, l, order::Int64) = derivative(f!, y, x, l, Val{order}())
37-
@inline derivative!(result, f, x, l, order::Int64) = derivative!(
38-
result, f, x, l, Val{order}())
39-
@inline derivative!(result, f!, y, x, l, order::Int64) = derivative!(
40-
result, f!, y, x, l, Val{order}())
35+
@inline derivative(f, x, l, p::Int64) = derivative(f, x, l, Val{p}())
36+
@inline derivative(f!, y, x, l, p::Int64) = derivative(f!, y, x, l, Val{p}())
37+
@inline derivative!(result, f, x, l, p::Int64) = derivative!(
38+
result, f, x, l, Val{p}())
39+
@inline derivative!(result, f!, y, x, l, p::Int64) = derivative!(
40+
result, f!, y, x, l, Val{p}())
4141

4242
# Core APIs
4343

4444
# Added to help Zygote infer types
45-
@inline function make_seed(x::T, l::S, ::Val{N}) where {T <: Real, S <: Real, N}
46-
TaylorScalar{T, N}(x, convert(T, l))
45+
@inline function make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P}
46+
TaylorScalar{P}(x, convert(T, l))
4747
end
4848

49-
@inline function make_seed(x::AbstractArray{T}, l, vN::Val{N}) where {T <: Real, N}
50-
broadcast(make_seed, x, l, vN)
49+
@inline function make_seed(x::AbstractArray{T}, l, p::Val{P}) where {T <: Real, P}
50+
broadcast(make_seed, x, l, p)
5151
end
5252

53-
# `derivative` API: computes the `N - 1`-th derivative of `f` at `x`
54-
@inline derivative(f, x, l, vN::Val{N}) where {N} = extract_derivative(
55-
derivatives(f, x, l, vN), N)
56-
@inline derivative(f!, y, x, l, vN::Val{N}) where {N} = extract_derivative(
57-
derivatives(f!, y, x, l, vN), N)
58-
@inline derivative!(result, f, x, l, vN::Val{N}) where {N} = extract_derivative!(
59-
result, derivatives(f, x, l, vN), N)
60-
@inline derivative!(result, f!, y, x, l, vN::Val{N}) where {N} = extract_derivative!(
61-
result, derivatives(f!, y, x, l, vN), N)
53+
# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
54+
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
55+
derivatives(f, x, l, p), P)
56+
@inline derivative(f!, y, x, l, p::Val{P}) where {P} = extract_derivative(
57+
derivatives(f!, y, x, l, p), P)
58+
@inline derivative!(result, f, x, l, p::Val{P}) where {P} = extract_derivative!(
59+
result, derivatives(f, x, l, p), P)
60+
@inline derivative!(result, f!, y, x, l, p::Val{P}) where {P} = extract_derivative!(
61+
result, derivatives(f!, y, x, l, p), P)
6262

63-
# `derivatives` API: computes all derivatives of `f` at `x` up to order `N - 1`
63+
# `derivatives` API: computes all derivatives of `f` at `x` up to p `P - 1`
6464

6565
# Out-of-place function
66-
@inline derivatives(f, x, l, vN::Val{N}) where {N} = f(make_seed(x, l, vN))
66+
@inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p))
6767

6868
# In-place function
69-
@inline function derivatives(f!, y::AbstractArray{T}, x, l, vN::Val{N}) where {T, N}
70-
buffer = similar(y, TaylorScalar{T, N})
71-
f!(buffer, make_seed(x, l, vN))
69+
@inline function derivatives(f!, y::AbstractArray{T}, x, l, p::Val{P}) where {T, P}
70+
buffer = similar(y, TaylorScalar{T, P})
71+
f!(buffer, make_seed(x, l, p))
7272
map!(value, y, buffer)
7373
return buffer
7474
end

src/primitive.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,31 @@ import Base: sinc, cosc
66
import Base: +, -, *, /, \, ^, >, <, >=, <=, ==
77
import Base: hypot, max, min
88
import Base: tail
9+
import Base: convert, promote_rule
10+
11+
Taylor = Union{TaylorScalar, TaylorArray}
12+
13+
@inline value(t::Taylor) = t.value
14+
@inline partials(t::Taylor) = t.partials
15+
@inline extract_derivative(t::Taylor, i::Integer) = t.partials[i]
16+
@inline extract_derivative(v::AbstractArray{<:TaylorScalar}, i::Integer) = map(
17+
t -> extract_derivative(t, i), v)
18+
@inline extract_derivative(r, i::Integer) = false
19+
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
20+
i::Integer) where {T <: TaylorScalar}
21+
map!(t -> extract_derivative(t, i), result, v)
22+
end
23+
24+
@inline flatten(t::Taylor) = (value(t), partials(t)...)
25+
26+
function promote_rule(::Type{TaylorScalar{T, P}},
27+
::Type{S}) where {T, S, P}
28+
TaylorScalar{promote_type(T, S), P}
29+
end
30+
31+
function (::Type{F})(x::TaylorScalar{T, P}) where {T, P, F <: AbstractFloat}
32+
F(value(x))
33+
end
934

1035
# Unary
1136
@inline +(a::Number, b::TaylorScalar) = TaylorScalar(a + value(b), partials(b))

src/scalar.jl

Lines changed: 26 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,58 @@
1-
import Base: convert, promote_rule
2-
31
export TaylorScalar
42

53
"""
6-
TaylorDiff.can_taylor(V::Type)
7-
8-
Determines whether the type V is allowed as the scalar type in a
9-
Dual. By default, only `<:Real` types are allowed.
10-
"""
11-
can_taylorize(::Type{<:Real}) = true
12-
can_taylorize(::Type) = false
13-
14-
@noinline function throw_cannot_taylorize(V::Type)
15-
throw(ArgumentError("Cannot create a Taylor polynomial over scalar type $V." *
16-
" If the type behaves as a scalar, define TaylorDiff.can_taylorize(::Type{$V}) = true."))
17-
end
18-
19-
"""
20-
TaylorScalar{T, N}
4+
TaylorScalar{T, P}
215
226
Representation of Taylor polynomials.
237
248
# Fields
259
2610
- `value::T`: zeroth order coefficient
27-
- `partials::NTuple{N, T}`: i-th element of this stores the i-th derivative
11+
- `partials::NTuple{P, T}`: i-th element of this stores the i-th derivative
2812
"""
29-
struct TaylorScalar{T, N} <: Real
13+
struct TaylorScalar{T, P} <: Real
3014
value::T
31-
partials::NTuple{N, T}
32-
function TaylorScalar{T, N}(value::T, partials::NTuple{N, T}) where {T, N}
15+
partials::NTuple{P, T}
16+
function TaylorScalar(value::T, partials::NTuple{P, T}) where {T, P}
3317
can_taylorize(T) || throw_cannot_taylorize(T)
34-
new{T, N}(value, partials)
18+
new{T, P}(value, partials)
3519
end
3620
end
3721

38-
function TaylorScalar(value::T, partials::NTuple{N, T}) where {T, N}
39-
TaylorScalar{T, N}(value, partials)
40-
end
22+
# Allowing promotion of basic Number types
23+
TaylorScalar{T, P}(x) where {T, P} = TaylorScalar{P}(T(x))
4124

42-
function TaylorScalar(value_and_partials::NTuple{N, T}) where {T, N}
43-
TaylorScalar(value_and_partials[1], value_and_partials[2:end])
44-
end
25+
# Allowing construction with flattened value and partials in a tuple
26+
TaylorScalar(all::NTuple{P, T}) where {T, P} = TaylorScalar(all[1], all[2:end])
4527

4628
"""
47-
TaylorScalar{T, N}(x::T) where {T, N}
29+
TaylorScalar{P}(value::T) where {T, P}
4830
49-
Construct a Taylor polynomial with zeroth order coefficient.
31+
Convenience function: construct a Taylor polynomial with zeroth order coefficient.
5032
"""
51-
TaylorScalar{T, N}(x::S) where {T, S <: Real, N} = TaylorScalar(
52-
T(x), ntuple(i -> zero(T), Val(N)))
33+
TaylorScalar{P}(value::T) where {T, P} = TaylorScalar(value, ntuple(i -> zero(T), Val(P)))
5334

5435
"""
55-
TaylorScalar{T, N}(x::T, d::T) where {T, N}
36+
TaylorScalar{P}(value::T, seed::T)
5637
57-
Construct a Taylor polynomial with zeroth and first order coefficient, acting as a seed.
38+
Convenience function: construct a Taylor polynomial with zeroth and first order coefficient, acting as a seed.
5839
"""
59-
TaylorScalar{T, N}(x::S, d::S) where {T, S <: Real, N} = TaylorScalar(
60-
T(x), ntuple(i -> i == 1 ? T(d) : zero(T), Val(N)))
40+
TaylorScalar{P}(value::T, seed::T) where {T, P} = TaylorScalar(
41+
value, ntuple(i -> i == 1 ? seed : zero(T), Val(P)))
6142

62-
function TaylorScalar{T, N}(t::TaylorScalar{T, M}) where {T, N, M}
43+
function TaylorScalar{P}(t::TaylorScalar{T, Q}) where {T, P, Q}
6344
v = value(t)
6445
p = partials(t)
65-
N <= M ? TaylorScalar(v, p[1:N]) :
66-
TaylorScalar(v, ntuple(i -> i <= M ? p[i] : zero(T), Val(N)))
46+
P <= Q ? TaylorScalar(v, p[1:P]) :
47+
TaylorScalar(v, ntuple(i -> i <= Q ? p[i] : zero(T), Val(P)))
6748
end
6849

69-
@inline value(t::TaylorScalar) = t.value
70-
@inline partials(t::TaylorScalar) = t.partials
71-
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.partials[i]
72-
@inline function extract_derivative(v::AbstractArray{T},
73-
i::Integer) where {T <: TaylorScalar}
74-
map(t -> extract_derivative(t, i), v)
75-
end
76-
@inline extract_derivative(r, i::Integer) = false
77-
@inline function extract_derivative!(result::AbstractArray, v::AbstractArray{T},
78-
i::Integer) where {T <: TaylorScalar}
79-
map!(t -> extract_derivative(t, i), result, v)
50+
# Covariant: operate on the value, and reconstruct with the partials
51+
for op in Symbol[:nextfloat, :prevfloat]
52+
@eval Base.$(op)(x::TaylorScalar) = TaylorScalar(Base.$(op)(value(x)), partials(x))
8053
end
8154

82-
@inline flatten(t::TaylorScalar) = (value(t), partials(t)...)
83-
84-
function promote_rule(::Type{TaylorScalar{T, N}},
85-
::Type{S}) where {T, S, N}
86-
TaylorScalar{promote_type(T, S), N}
87-
end
88-
89-
function (::Type{F})(x::TaylorScalar{T, N}) where {T, N, F <: AbstractFloat}
90-
F(value(x))
91-
end
92-
93-
const COVARIANT_OPS = Symbol[:nextfloat, :prevfloat]
94-
95-
for op in COVARIANT_OPS
96-
@eval Base.$(op)(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar($(op)(value(x)), partials(x))
97-
end
98-
99-
const UNARY_PREDICATES = Symbol[
100-
:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
101-
102-
for pred in UNARY_PREDICATES
103-
@eval Base.$(pred)(x::TaylorScalar) = $(pred)(value(x))
55+
# Invariant: operate on the value, and drop the partials
56+
for op in Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
57+
@eval Base.$(op)(x::TaylorScalar) = Base.$(op)(value(x))
10458
end

0 commit comments

Comments
 (0)