Skip to content

Commit 1005d83

Browse files
committed
Improve log-like function
1 parent e6ab873 commit 1005d83

1 file changed

Lines changed: 25 additions & 11 deletions

File tree

src/utils.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,46 @@
11
using ChainRules
22
using ChainRulesCore
3-
using Symbolics: @variables
4-
using SymbolicUtils, SymbolicUtils.Code
5-
using SymbolicUtils: Pow
3+
using Symbolics: @variables, @rule, unwrap, isdiv
4+
using SymbolicUtils.Code: toexpr
65

7-
dummy = (NoTangent(), 1)
8-
@variables z
6+
"""
7+
Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule.
8+
"""
9+
function get_term_raiser(func)
10+
@variables z
11+
r1 = @rule -1 * (1 / ~x) => (-1) / ~x
12+
der = frule((NoTangent(), true), func, z)[2]
13+
term = unwrap(der)
14+
maybe_rewrite = r1(term)
15+
if maybe_rewrite !== nothing
16+
term = maybe_rewrite
17+
end
18+
if isdiv(term) && (term.num == 1 || term.num == -1)
19+
term.den * term.num, raiseinv
20+
else
21+
term, raise
22+
end
23+
end
924

1025
function define_unary_function(func, m)
1126
F = typeof(func)
12-
# base case
27+
# First order: call frule directly
1328
@eval m function (op::$F)(t::TaylorScalar{T, 1}) where {T}
1429
t0 = value(t)
1530
t1 = first(partials(t))
1631
f0, f1 = frule((NoTangent(), t1), op, t0)
1732
TaylorScalar{T, 1}(f0, zero_tangent(f0) + f1)
1833
end
19-
der = frule(dummy, func, z)[2]
20-
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
21-
# recursion by raising
34+
term, raiser = get_term_raiser(func)
35+
# Higher order: recursion by raising
2236
@eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
23-
der_expr = $(QuoteNode(toexpr(term)))
37+
expr = $(QuoteNode(toexpr(term)))
2438
f = $func
2539
quote
2640
$(Expr(:meta, :inline))
2741
z = TaylorScalar{T, N - 1}(t)
2842
f0 = $f(value(t)[1])
29-
df = zero_tangent(z) + $der_expr
43+
df = zero_tangent(z) + $expr
3044
$$raiser(f0, df, t)
3145
end
3246
end

0 commit comments

Comments
 (0)