Skip to content

Commit 2b0b1ce

Browse files
committed
Add broadcasting
1 parent 2555949 commit 2b0b1ce

6 files changed

Lines changed: 76 additions & 24 deletions

File tree

src/array.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Representation of Taylor polynomials in array mode.
88
# Fields
99
1010
- `value::A`: zeroth order coefficient
11-
- `partials::NTuple{P, T}`: i-th element of this stores the i-th derivative
11+
- `partials::NTuple{P, A}`: i-th element of this stores the i-th derivative
1212
"""
1313
struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
1414
AbstractArray{TaylorScalar{T, P}, N}
@@ -24,15 +24,28 @@ struct TaylorArray{T, N, A <: AbstractArray{T, N}, P} <:
2424
end
2525

2626
function TaylorArray{P}(value::A) where {A <: AbstractArray, P}
27-
TaylorArray(value, ntuple(i -> zeros(eltype(value), size(value)), Val(P)))
27+
TaylorArray(value, ntuple(i -> broadcast(zero, value), Val(P)))
2828
end
2929

3030
function TaylorArray{P}(value::A, seed::A) where {A <: AbstractArray, P}
3131
TaylorArray(
32-
value, ntuple(i -> i == 1 ? seed : zeros(eltype(value), size(value)), Val(P)))
32+
value, ntuple(i -> i == 1 ? seed : broadcast(zero, seed), Val(P)))
3333
end
3434

35-
# Indexing
35+
# Necessary AbstractArray interface methods for TaylorArray to work
36+
# https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array
37+
38+
## 1. Invariant
39+
for op in Symbol[:size, :strides, :eachindex, :IndexStyle]
40+
@eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x))
41+
end
42+
43+
## 2. Indexing
44+
function Base.similar(a::TaylorArray, ::Type{<:TaylorScalar{T}}, dims::Dims) where {T}
45+
new_value = similar(value(a), T, dims)
46+
new_partials = map(p -> similar(p, T, dims), partials(a))
47+
return TaylorArray(new_value, new_partials)
48+
end
3649

3750
Base.@propagate_inbounds function Base.getindex(a::TaylorArray, i::Int...)
3851
new_value = value(a)[i...]
@@ -55,7 +68,31 @@ Base.@propagate_inbounds function Base.setindex!(
5568
return a
5669
end
5770

58-
# Invariant
59-
for op in Symbol[:size, :eachindex, :IndexStyle]
60-
@eval Base.$(op)(x::TaylorArray) = Base.$(op)(value(x))
71+
## 3. Broadcasting
72+
struct TaylorArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
73+
TaylorArrayStyle(::Val{N}) where {N} = TaylorArrayStyle{N}()
74+
TaylorArrayStyle{M}(::Val{N}) where {N, M} = TaylorArrayStyle{N}()
75+
76+
Base.BroadcastStyle(::Type{<:TaylorArray{T, N}}) where {T, N} = TaylorArrayStyle{N}()
77+
# This is added to make Zygote custom broadcasting work
78+
# However, we might implement custom broadcasting semantics for TaylorArray in the future
79+
# function Base.BroadcastStyle(::Type{<:Array{
80+
# <:Tuple{TaylorScalar{T, P}, Any}, N}}) where {T, N, P}
81+
# TaylorArrayStyle{N}()
82+
# end
83+
84+
function Base.similar(
85+
bc::Broadcast.Broadcasted{<:TaylorArrayStyle}, ::Type{ElType}) where {ElType}
86+
A = find_taylor(bc)
87+
similar(A, ElType, axes(bc))
88+
end
89+
90+
find_taylor(bc::Broadcast.Broadcasted) = find_taylor(bc.args)
91+
find_taylor(args::Tuple) = find_taylor(find_taylor(args[1]), Base.tail(args))
92+
find_taylor(x) = x
93+
find_taylor(::Tuple{}) = nothing
94+
find_taylor(a::TaylorArray, rest) = a
95+
function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where {T, P, N}
96+
TaylorArray{P}(zeros(T, size(a)))
6197
end
98+
find_taylor(::Any, rest) = find_taylor(rest)

src/chainrules.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing, @opt_out
22
using Base.Broadcast: broadcasted
33

44
function rrule(::Type{TaylorScalar}, v::T, p::NTuple{N, T}) where {N, T}
5-
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
6-
return TaylorScalar(v, p), taylor_scalar_pullback
5+
constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
6+
return TaylorScalar(v, p), constructor_pullback
7+
end
8+
9+
function rrule(::Type{TaylorArray}, v::T, p::NTuple{N, T}) where {N, T}
10+
constructor_pullback(t̄) = NoTangent(), value(t̄), partials(t̄)
11+
return TaylorArray(v, p), constructor_pullback
712
end
813

914
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
@@ -23,6 +28,11 @@ function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
2328
return partials(t), value_pullback
2429
end
2530

31+
function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P}
32+
value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
33+
return partials(t), value_pullback
34+
end
35+
2636
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2737
i::Integer) where {N, T}
2838
function extract_derivative_pullback(d̄)
@@ -32,6 +42,16 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
3242
return extract_derivative(t, i), extract_derivative_pullback
3343
end
3444

45+
function rrule(::typeof(Base.getindex), a::TaylorArray, i::Int...)
46+
function getindex_pullback(t̄)
47+
ā = similar(a)
48+
ā .= zero(eltype(a))
49+
ā[i...] = t̄
50+
NoTangent(), ā, map(Returns(NoTangent()), i)
51+
end
52+
return getindex(a, i...), getindex_pullback
53+
end
54+
3555
function rrule(::typeof(*), A::AbstractMatrix{S},
3656
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Real, T <: Real}
3757
project_A = ProjectTo(A)

src/derivative.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function derivatives end
2727

2828
# Convenience wrapper for adding unit seed to the input
2929

30-
@inline derivative(f, x, p::Int64) = derivative(f, x, one(eltype(x)), p)
30+
@inline derivative(f, x, p::Int64) = derivative(f, x, broadcast(one, x), p)
3131

3232
# Convenience wrappers for converting ps to value types
3333
# and forward work to core APIs
@@ -42,13 +42,8 @@ function derivatives end
4242
# Core APIs
4343

4444
# Added to help Zygote infer types
45-
@inline function make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P}
46-
TaylorScalar{P}(x, convert(T, l))
47-
end
48-
49-
@inline function make_seed(x::AbstractArray{T}, l, p::Val{P}) where {T <: Real, P}
50-
broadcast(make_seed, x, l, p)
51-
end
45+
@inline make_seed(x::T, l::T, ::Val{P}) where {T <: Real, P} = TaylorScalar{P}(x, l)
46+
@inline make_seed(x::A, l::A, ::Val{P}) where {A <: AbstractArray, P} = broadcast(make_seed, x, l, Val{P}())
5247

5348
# `derivative` API: computes the `P - 1`-th derivative of `f` at `x`
5449
@inline derivative(f, x, l, p::Val{P}) where {P} = extract_derivative(
@@ -66,8 +61,8 @@ end
6661
@inline derivatives(f, x, l, p::Val{P}) where {P} = f(make_seed(x, l, p))
6762

6863
# In-place function
69-
@inline function derivatives(f!, y::AbstractArray{T}, x, l, p::Val{P}) where {T, P}
70-
buffer = similar(y, TaylorScalar{T, P})
64+
@inline function derivatives(f!, y, x, l, p::Val{P}) where {P}
65+
buffer = similar(y, TaylorScalar{eltype(y), P})
7166
f!(buffer, make_seed(x, l, p))
7267
map!(value, y, buffer)
7368
return buffer

test/derivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@test derivative(g1, [1.0, 2.0], [1.0, 0.0], 1) ≈ 2.0
1111

1212
h1(x) = sum(x, dims = 1)
13-
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) ≈ [2.0 2.0]
13+
@test derivative(h1, [1.0 2.0; 2.0 3.0], [1.0 1.0; 1.0 1.0], 1) ≈ [2.0 2.0]
1414
end
1515

1616
@testset "I-function, O-derivative" begin

test/downstream.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ backend = AutoZygote()
2727
# Matrix functions
2828
some_matrix = [0.7 0.1; 0.4 0.2]
2929
f(x) = sum(exp.(x), dims = 1)
30-
dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
31-
dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
30+
dfdx1(x) = derivative(f, x, [1.0 1.0; 0.0 0.0], 1)
31+
dfdx2(x) = derivative(f, x, [0.0 0.0; 1.0 1.0], 1)
3232
res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
3333
grad = DI.gradient(res, backend, some_matrix)
3434
@test grad ≈ [1 0; 0 2] * exp.(some_matrix)

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using TaylorDiff
22
using Test
33

4-
# include("primitive.jl")
5-
# include("derivative.jl")
4+
include("primitive.jl")
5+
include("derivative.jl")
66
include("downstream.jl")
77
# include("lux.jl")

0 commit comments

Comments
 (0)