@@ -38,34 +38,35 @@ function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Numbe
3838 return TaylorScalar (v), taylor_scalar_pullback
3939end
4040
41- function rrule (:: typeof (value), t:: TaylorScalar )
42- value_pullback (v̄:: NTuple ) = NoTangent (), TaylorScalar (v̄)
41+ function rrule (:: typeof (value), t:: TaylorScalar{T, N} ) where {N, T}
42+ value_pullback (v̄:: NTuple{N, T} ) = NoTangent (), TaylorScalar (v̄)
43+ value_pullback (v̄:: Tuple ) = NoTangent (), TaylorScalar (map (x -> convert (T, x), v̄))
4344 # for structural tangent, convert to tuple
44- value_pullback (v̄) = NoTangent (), TaylorScalar (Tuple (v̄))
45+ value_pullback (v̄) = NoTangent (), TaylorScalar (map (x -> convert (T, x), Tuple (v̄) ))
4546 return value (t), value_pullback
4647end
4748
4849function rrule (:: typeof (extract_derivative), t:: TaylorScalar{T, N} ,
4950 i:: Integer ) where {N, T <: Number }
5051 function extract_derivative_pullback (d̄)
51- NoTangent (), TaylorScalar (( zeros ( T, i - 1 ) ... , d̄, zeros (T, N - i) ... )), NoTangent ()
52+ NoTangent (), TaylorScalar { T, N} ( ntuple (j -> j === i ? d̄ : zero (T), Val (N) )), NoTangent ()
5253 end
5354 return extract_derivative (t, i), extract_derivative_pullback
5455end
5556
5657function rrule (:: typeof (* ), A:: Matrix{S} , t:: Vector{TaylorScalar{T, N}} ) where {N, S <: Number , T}
5758 project_A = ProjectTo (A)
58- gemv_pullback (x̄) = NoTangent (), project_A (contract .(x̄, transpose (t))), transpose (A) * x̄
59+ gemv_pullback (x̄) = NoTangent (), @thunk ( project_A (contract .(x̄, transpose (t)))), @thunk ( transpose (A) * x̄)
5960 return A * t, gemv_pullback
6061end
6162
6263function rrule (:: typeof (+ ), v:: Vector{T} , t:: Vector{TaylorScalar{T, N}} ) where {N, T <: Number }
63- vadd_pullback (x̄) = NoTangent (), map (primal, x̄), x̄
64+ vadd_pullback (x̄) = NoTangent (), ProjectTo (v)( x̄), x̄
6465 return v + t, vadd_pullback
6566end
6667
6768function rrule (:: typeof (+ ), t:: Vector{TaylorScalar{T, N}} , v:: Vector{T} ) where {N, T <: Number }
68- vadd_pullback (x̄) = NoTangent (), x̄, map (primal, x̄)
69+ vadd_pullback (x̄) = NoTangent (), x̄, ProjectTo (v)( x̄)
6970 return t + v, vadd_pullback
7071end
7172
0 commit comments