|
1 | 1 | using ChainRules |
2 | 2 | using ChainRulesCore |
3 | | -using Symbolics: @variables |
4 | | -using SymbolicUtils, SymbolicUtils.Code |
5 | | -using SymbolicUtils: Pow |
| 3 | +using Symbolics: @variables, @rule, unwrap, isdiv |
| 4 | +using SymbolicUtils.Code: toexpr |
6 | 5 |
|
7 | | -dummy = (NoTangent(), 1) |
8 | | -@variables z |
| 6 | +""" |
| 7 | +Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule. |
| 8 | +""" |
| 9 | +function get_term_raiser(func) |
| 10 | + @variables z |
| 11 | + r1 = @rule -1 * (1 / ~x) => (-1) / ~x |
| 12 | + der = frule((NoTangent(), true), func, z)[2] |
| 13 | + term = unwrap(der) |
| 14 | + maybe_rewrite = r1(term) |
| 15 | + if maybe_rewrite !== nothing |
| 16 | + term = maybe_rewrite |
| 17 | + end |
| 18 | + if isdiv(term) && (term.num == 1 || term.num == -1) |
| 19 | + term.den * term.num, raiseinv |
| 20 | + else |
| 21 | + term, raise |
| 22 | + end |
| 23 | +end |
9 | 24 |
|
10 | 25 | function define_unary_function(func, m) |
11 | 26 | F = typeof(func) |
12 | | - # base case |
| 27 | + # First order: call frule directly |
13 | 28 | @eval m function (op::$F)(t::TaylorScalar{T, 1}) where {T} |
14 | 29 | t0 = value(t) |
15 | 30 | t1 = first(partials(t)) |
16 | 31 | f0, f1 = frule((NoTangent(), t1), op, t0) |
17 | 32 | TaylorScalar{T, 1}(f0, zero_tangent(f0) + f1) |
18 | 33 | end |
19 | | - der = frule(dummy, func, z)[2] |
20 | | - term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise) |
21 | | - # recursion by raising |
| 34 | + term, raiser = get_term_raiser(func) |
| 35 | + # Higher order: recursion by raising |
22 | 36 | @eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N} |
23 | | - der_expr = $(QuoteNode(toexpr(term))) |
| 37 | + expr = $(QuoteNode(toexpr(term))) |
24 | 38 | f = $func |
25 | 39 | quote |
26 | 40 | $(Expr(:meta, :inline)) |
27 | 41 | z = TaylorScalar{T, N - 1}(t) |
28 | 42 | f0 = $f(value(t)[1]) |
29 | | - df = zero_tangent(z) + $der_expr |
| 43 | + df = zero_tangent(z) + $expr |
30 | 44 | $$raiser(f0, df, t) |
31 | 45 | end |
32 | 46 | end |
|
0 commit comments