@@ -8,7 +8,7 @@ Representation of Taylor polynomials in array mode.
88# Fields
99
1010- `value::A`: zeroth order coefficient
11- - `partials::NTuple{P, T }`: i-th element of this stores the i-th derivative
11+ - `partials::NTuple{P, A }`: i-th element of this stores the i-th derivative
1212"""
1313struct TaylorArray{T, N, A <: AbstractArray{T, N} , P} < :
1414 AbstractArray{TaylorScalar{T, P}, N}
@@ -24,15 +24,28 @@ struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
2424end
2525
2626function TaylorArray {P} (value:: A ) where {A <: AbstractArray , P}
27- TaylorArray (value, ntuple (i -> zeros ( eltype (value), size ( value) ), Val (P)))
27+ TaylorArray (value, ntuple (i -> broadcast (zero, value), Val (P)))
2828end
2929
3030function TaylorArray {P} (value:: A , seed:: A ) where {A <: AbstractArray , P}
3131 TaylorArray (
32- value, ntuple (i -> i == 1 ? seed : zeros ( eltype (value), size (value) ), Val (P)))
32+ value, ntuple (i -> i == 1 ? seed : broadcast (zero, seed ), Val (P)))
3333end
3434
35- # Indexing
35+ # Necessary AbstractArray interface methods for TaylorArray to work
36+ # https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array
37+
38+ # # 1. Invariant
39+ for op in Symbol[:size , :strides , :eachindex , :IndexStyle ]
40+ @eval Base.$ (op)(x:: TaylorArray ) = Base.$ (op)(value (x))
41+ end
42+
43+ # # 2. Indexing
44+ function Base. similar (a:: TaylorArray , :: Type{<:TaylorScalar{T}} , dims:: Dims ) where {T}
45+ new_value = similar (value (a), T, dims)
46+ new_partials = map (p -> similar (p, T, dims), partials (a))
47+ return TaylorArray (new_value, new_partials)
48+ end
3649
3750Base. @propagate_inbounds function Base. getindex (a:: TaylorArray , i:: Int... )
3851 new_value = value (a)[i... ]
@@ -55,7 +68,31 @@ Base.@propagate_inbounds function Base.setindex!(
5568 return a
5669end
5770
58- # Invariant
59- for op in Symbol[:size , :eachindex , :IndexStyle ]
60- @eval Base.$ (op)(x:: TaylorArray ) = Base.$ (op)(value (x))
71+ # # 3. Broadcasting
72+ struct TaylorArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
73+ TaylorArrayStyle (:: Val{N} ) where {N} = TaylorArrayStyle {N} ()
74+ TaylorArrayStyle {M} (:: Val{N} ) where {N, M} = TaylorArrayStyle {N} ()
75+
76+ Base. BroadcastStyle (:: Type{<:TaylorArray{T, N}} ) where {T, N} = TaylorArrayStyle {N} ()
77+ # This is added to make Zygote custom broadcasting work
78+ # However, we might implement custom broadcasting semantics for TaylorArray in the future
79+ # function Base.BroadcastStyle(::Type{<:Array{
80+ # <:Tuple{TaylorScalar{T, P}, Any}, N}}) where {T, N, P}
81+ # TaylorArrayStyle{N}()
82+ # end
83+
84+ function Base. similar (
85+ bc:: Broadcast.Broadcasted{<:TaylorArrayStyle} , :: Type{ElType} ) where {ElType}
86+ A = find_taylor (bc)
87+ similar (A, ElType, axes (bc))
88+ end
89+
90+ find_taylor (bc:: Broadcast.Broadcasted ) = find_taylor (bc. args)
91+ find_taylor (args:: Tuple ) = find_taylor (find_taylor (args[1 ]), Base. tail (args))
92+ find_taylor (x) = x
93+ find_taylor (:: Tuple{} ) = nothing
94+ find_taylor (a:: TaylorArray , rest) = a
95+ function find_taylor (a:: Array{<:Tuple{TaylorScalar{T, P}, Any}, N} , rest) where {T, P, N}
96+ TaylorArray {P} (zeros (T, size (a)))
6197end
98+ find_taylor (:: Any , rest) = find_taylor (rest)
0 commit comments