import ChainRulesCore: frule, rrule, RuleConfig, ProjectTo, backing, @opt_out
22using Base. Broadcast: broadcasted
3- import Zygote: @adjoint , accum_sum, unbroadcast, Numeric, ∇getindex, _project
43
54function contract (a:: TaylorScalar{T, N} , b:: TaylorScalar{S, N} ) where {T, S, N}
65 mapreduce (* , + , value (a), value (b))
@@ -31,7 +30,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
3130end
3231
3332function rrule (:: typeof (* ), A:: AbstractMatrix{S} ,
34- t:: AbstractVector{TaylorScalar{T, N}} ) where {N, S, T}
33+ t:: AbstractVector{TaylorScalar{T, N}} ) where {N, S <: Real , T <: Real }
3534 project_A = ProjectTo (A)
3635 function gemv_pullback (x̄)
3736 x̂ = reinterpret (reshape, T, x̄)
@@ -42,7 +41,7 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
4241end
4342
4443function rrule (:: typeof (* ), A:: AbstractMatrix{S} ,
45- B:: AbstractMatrix{TaylorScalar{T, N}} ) where {N, S, T}
44+ B:: AbstractMatrix{TaylorScalar{T, N}} ) where {N, S <: Real , T <: Real }
4645 project_A = ProjectTo (A)
4746 project_B = ProjectTo (B)
4847 function gemm_pullback (x̄)
@@ -54,88 +53,36 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
5453 return A * B, gemm_pullback
5554end
5655
57- @adjoint function + (t:: Vector{TaylorScalar{T, N}} , v:: Vector{T} ) where {N, T}
58- project_v = ProjectTo (v)
59- t + v, x̄ -> (x̄, project_v (x̄))
60- end
61-
62- @adjoint function + (v:: Vector{T} , t:: Vector{TaylorScalar{T, N}} ) where {N, T}
63- project_v = ProjectTo (v)
64- v + t, x̄ -> (project_v (x̄), x̄)
65- end
66-
67- (project:: ProjectTo{T} )(dx:: TaylorScalar{T, N} ) where {N, T} = primal (dx)
68-
69- # Not-a-number patches
70-
71- ProjectTo (:: T ) where {T <: TaylorScalar } = ProjectTo {T} ()
72- (p:: ProjectTo{T} )(x:: T ) where {T <: TaylorScalar } = x
73- function ProjectTo (x:: AbstractArray{T} ) where {T <: TaylorScalar }
74- ProjectTo {AbstractArray} (; element = ProjectTo (zero (T)), axes = axes (x))
75- end
76- (p:: ProjectTo{AbstractArray{T}} )(x:: AbstractArray{T} ) where {T <: TaylorScalar } = x
77- accum_sum (xs:: AbstractArray{T} ; dims = :) where {T <: TaylorScalar } = sum (xs, dims = dims)
78-
79- TaylorNumeric{T <: TaylorScalar } = Union{T, AbstractArray{<: T }}
80-
81- @adjoint function broadcasted (:: typeof (+ ), xs:: TaylorNumeric... )
82- broadcast (+ , xs... ), ȳ -> (nothing , map (x -> unbroadcast (x, ȳ), xs)... )
83- end
# Applying a `ProjectTo{T}` (with `T <: Number`) to a `TaylorScalar{T, N}` returns
# `primal(dx)`. NOTE(review): `primal` is defined elsewhere; presumably it extracts the
# zeroth-order value of the Taylor polynomial — confirm against its definition.
function (project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number}
    return primal(dx)
end
8457
85- struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
86- val:: T
87- ind:: I
88- axes:: A
89- function TaylorOneElement (val:: T , ind:: I ,
90- axes:: A ) where {T <: TaylorScalar , I <: NTuple{N, Int} ,
91- A <: NTuple{N, AbstractUnitRange} } where {N}
92- new {T, N, I, A} (val, ind, axes)
93- end
94- end
58+ # opt-outs
9559
96- Base. size (A:: TaylorOneElement ) = map (length, A. axes)
97- Base. axes (A:: TaylorOneElement ) = A. axes
98- function Base. getindex (A:: TaylorOneElement{T, N} , i:: Vararg{Int, N} ) where {T, N}
99- ifelse (i == A. ind, A. val, zero (T))
100- end
101-
102- function ∇getindex (x:: AbstractArray{T, N} , inds) where {T <: TaylorScalar , N}
103- dy -> begin
104- dx = TaylorOneElement (dy, inds, axes (x))
105- return (_project (x, dx), map (_ -> nothing , inds)... )
106- end
107- end
60+ # Unary functions
10861
109- @generated function mul_adjoint (Ω:: TaylorScalar{T, N} , x:: TaylorScalar{T, N} ) where {T, N}
110- return quote
111- vΩ, vx = value (Ω), value (x)
112- @inbounds TaylorScalar ($ ([:(+ ($ ([:($ (binomial (j - 1 , i - 1 )) * vΩ[$ j] *
113- vx[$ (j + 1 - i)]) for j in i: N]. .. )))
114- for i in 1 : N]. .. ))
115- end
# For every listed single-argument function, opt out of any `frule`/`rrule` that would
# otherwise apply when the argument is a `TaylorScalar`.
#
# `frule` receives the tangent tuple as its first positional argument
# (`frule(tangents, f, args...)`), so the opt-out signature needs a leading `::Any`
# placeholder. Without it the generated method treats the function as the tangents and
# never matches a real `frule` call. `rrule` has no such leading argument.
for f in (
    exp, exp10, exp2, expm1,
    sin, cos, tan, sec, csc, cot,
    sinh, cosh, tanh, sech, csch, coth,
    log, log10, log2, log1p,
    asin, acos, atan, asec, acsc, acot,
    asinh, acosh, atanh, asech, acsch, acoth,
    sqrt, cbrt, inv,
)
    @eval @opt_out frule(::Any, ::typeof($f), x::TaylorScalar)
    @eval @opt_out rrule(::typeof($f), x::TaylorScalar)
end
11774
118- rrule (:: typeof (* ), x:: TaylorScalar ) = rrule (identity, x)
119-
120- function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar )
121- function times_pullback2 (Ω̇)
122- ΔΩ = unthunk (Ω̇)
123- return (NoTangent (), ProjectTo (x)(mul_adjoint (ΔΩ, y)),
124- ProjectTo (y)(mul_adjoint (ΔΩ, x)))
125- end
126- return x * y, times_pullback2
127- end
75+ # Binary functions
12876
129- function rrule ( :: typeof ( * ), x :: TaylorScalar , y :: TaylorScalar , z :: TaylorScalar ,
130- more :: TaylorScalar... )
131- Ω2, back2 = rrule ( * , x, y )
132- Ω3, back3 = rrule ( * , Ω2, z)
133- Ω4, back4 = rrule ( * , Ω3, more ... )
134- function times_pullback4 (Ω̇)
135- Δ4 = back4 ( unthunk (Ω̇)) # (0, ΔΩ3, Δmore... )
136- Δ3 = back3 (Δ4[ 2 ]) # (0, ΔΩ2, Δz )
137- Δ2 = back2 (Δ3[ 2 ]) # (0, Δx, Δy )
138- return (Δ2 ... , Δ3[ 3 ], Δ4[ 3 : end ] . .. )
# For each binary operator, opt out of any `frule`/`rrule` that would otherwise apply
# when at least one operand is a `TaylorScalar` (both operands, or one mixed with a
# plain `Number` on either side).
#
# `frule` receives the tangent tuple as its first positional argument
# (`frule(tangents, f, args...)`), so the opt-out signature needs a leading `::Any`
# placeholder. Without it the generated method never matches a real `frule` call.
for f in (*, /, ^)
    for (tlhs, trhs) in (
        (TaylorScalar, TaylorScalar),
        (TaylorScalar, Number),
        (Number, TaylorScalar),
    )
        @eval @opt_out frule(::Any, ::typeof($f), x::$tlhs, y::$trhs)
        @eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
    end
end
0 commit comments