|
1 | 1 | import ChainRulesCore: rrule, RuleConfig, ProjectTo |
2 | 2 | using ZygoteRules: @adjoint |
3 | 3 |
|
4 | | -contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} = mapreduce(*, +, value(a), value(b)) |
5 | | - |
6 | | -NONLINEAR_UNARY_FUNCTIONS = Function[ |
7 | | - exp, exp2, exp10, expm1, |
8 | | - log, log2, log10, log1p, |
9 | | - inv, sqrt, cbrt, |
10 | | - sin, cos, tan, cot, sec, csc, |
11 | | - asin, acos, atan, acot, asec, acsc, |
12 | | - sinh, cosh, tanh, coth, sech, csch, |
13 | | - asinh, acosh, atanh, acoth, asech, acsch, |
14 | | -] |
| 4 | +function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} |
| 5 | + mapreduce(*, +, value(a), value(b)) |
| 6 | +end |
| 7 | + |
| 8 | +NONLINEAR_UNARY_FUNCTIONS = Function[exp, exp2, exp10, expm1, |
| 9 | + log, log2, log10, log1p, |
| 10 | + inv, sqrt, cbrt, |
| 11 | + sin, cos, tan, cot, sec, csc, |
| 12 | + asin, acos, atan, acot, asec, acsc, |
| 13 | + sinh, cosh, tanh, coth, sech, csch, |
| 14 | + asinh, acosh, atanh, acoth, asech, acsch] |
15 | 15 |
|
16 | 16 | for func in NONLINEAR_UNARY_FUNCTIONS |
17 | 17 | @eval @opt_out rrule(::typeof($func), ::TaylorScalar) |
18 | 18 | end |
19 | 19 |
|
20 | | -NONLINEAR_BINARY_FUNCTIONS = Function[ |
21 | | - *, /, ^ |
22 | | -] |
| 20 | +NONLINEAR_BINARY_FUNCTIONS = Function[*, /, ^] |
23 | 21 |
|
24 | 22 | for func in NONLINEAR_BINARY_FUNCTIONS |
25 | 23 | @eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::TaylorScalar) |
|
49 | 47 | function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N}, |
50 | 48 | i::Integer) where {N, T <: Number} |
51 | 49 | function extract_derivative_pullback(d̄) |
52 | | - NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), NoTangent() |
| 50 | + NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), |
| 51 | + NoTangent() |
53 | 52 | end |
54 | 53 | return extract_derivative(t, i), extract_derivative_pullback |
55 | 54 | end |
56 | 55 |
|
57 | | -function rrule(::typeof(*), A::Matrix{S}, t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T} |
| 56 | +function rrule(::typeof(*), A::Matrix{S}, |
| 57 | + t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T} |
58 | 58 | project_A = ProjectTo(A) |
59 | | - gemv_pullback(x̄) = NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A) * x̄) |
| 59 | + function gemv_pullback(x̄) |
| 60 | + NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A)*x̄) |
| 61 | + end |
60 | 62 | return A * t, gemv_pullback |
61 | 63 | end |
62 | 64 |
|
63 | | -function rrule(::typeof(+), v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} |
| 65 | +function rrule(::typeof(+), v::Vector{T}, |
| 66 | + t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} |
64 | 67 | vadd_pullback(x̄) = NoTangent(), ProjectTo(v)(x̄), x̄ |
65 | 68 | return v + t, vadd_pullback |
66 | 69 | end |
67 | 70 |
|
68 | | -function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} |
| 71 | +function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, |
| 72 | + v::Vector{T}) where {N, T <: Number} |
69 | 73 | vadd_pullback(x̄) = NoTangent(), x̄, ProjectTo(v)(x̄) |
70 | 74 | return t + v, vadd_pullback |
71 | 75 | end |
72 | 76 |
|
73 | | -@adjoint +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} = t + v, x̄ -> (x̄, map(primal, x̄)) |
| 77 | +@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} |
| 78 | + t + v, x̄ -> (x̄, map(primal, x̄)) |
| 79 | +end |
74 | 80 |
|
75 | | -@adjoint +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} = v + t, x̄ -> (map(primal, x̄), x̄) |
| 81 | +@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} |
| 82 | + v + t, x̄ -> (map(primal, x̄), x̄) |
| 83 | +end |
76 | 84 |
|
77 | 85 | (project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx) |
78 | 86 |
|
79 | | -(project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real} = project(primal(dx)) |
| 87 | +function (project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real} |
| 88 | + project(primal(dx)) |
| 89 | +end |
0 commit comments