11import ChainRulesCore: rrule, RuleConfig
2+ using ZygoteRules: @adjoint
3+
4+ contract (a:: TaylorScalar{T, N} , b:: TaylorScalar{S, N} ) where {T, S, N} = mapreduce (* , + , value (a), value (b))
5+
6+ NONLINEAR_UNARY_FUNCTIONS = Function[
7+ exp, exp2, exp10, expm1,
8+ log, log2, log10, log1p,
9+ sin, cos, tan, cot, sec, csc,
10+ asin, acos, atan, acot, asec, acsc,
11+ sinh, cosh, tanh, coth, sech, csch,
12+ asinh, acosh, atanh, acoth, asech, acsch,
13+ ]
14+
15+ for func in NONLINEAR_UNARY_FUNCTIONS
16+ @eval @opt_out rrule (:: typeof ($ func), :: TaylorScalar )
17+ end
18+
19+ NONLINEAR_BINARY_FUNCTIONS = Function[
20+ * , / , ^
21+ ]
22+
23+ for func in NONLINEAR_BINARY_FUNCTIONS
24+ @eval @opt_out rrule (:: typeof ($ func), :: TaylorScalar , :: TaylorScalar )
25+ @eval @opt_out rrule (:: typeof ($ func), :: TaylorScalar , :: Number )
26+ @eval @opt_out rrule (:: typeof ($ func), :: Number , :: TaylorScalar )
27+ end
28+
29+ # Other special cases
230
3- @opt_out rrule (:: Any , :: TaylorScalar , :: TaylorScalar )
4- @opt_out rrule (:: Any , :: TaylorScalar , :: Any )
5- @opt_out rrule (:: typeof (* ), :: TaylorScalar , :: TaylorScalar )
6- @opt_out rrule (:: typeof (^ ), :: TaylorScalar , :: Any )
731@opt_out rrule (:: typeof (Base. literal_pow), :: typeof (^ ), x:: TaylorScalar , :: Val{p} ) where {p}
832@opt_out rrule (:: RuleConfig , :: typeof (Base. literal_pow), :: typeof (^ ), x:: TaylorScalar ,
933 :: Val{p} ) where {p}
@@ -29,7 +53,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2953end
3054
3155function rrule (:: typeof (* ), A:: Matrix{T} , t:: Vector{TaylorScalar{T, N}} ) where {N, T <: Number }
32- gemv_pullback (x̄) = NoTangent (), map (primal, x̄) * transpose (map (primal, t)), transpose (A) * x̄
56+ gemv_pullback (x̄) = NoTangent (), contract .(x̄, transpose (t)), transpose (A) * x̄
3357 return A * t, gemv_pullback
3458end
3559
@@ -42,3 +66,7 @@ function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {
4266 vadd_pullback (x̄) = NoTangent (), x̄, map (primal, x̄)
4367 return t + v, vadd_pullback
4468end
69+
70+ @adjoint + (t:: Vector{TaylorScalar{T, N}} , v:: Vector{T} ) where {N, T <: Number } = t + v, x̄ -> (x̄, map (primal, x̄))
71+
72+ @adjoint + (v:: Vector{T} , t:: Vector{TaylorScalar{T, N}} ) where {N, T <: Number } = v + t, x̄ -> (map (primal, x̄), x̄)
0 commit comments