Skip to content

Commit eff4ac7

Browse files
zhujch1tansongchen
authored andcommitted
SFext
1 parent 140fe5f commit eff4ac7

3 files changed

Lines changed: 40 additions & 5 deletions

File tree

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@ version = "0.2.1"
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
10-
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
11-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1210
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1311
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1412
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1513

14+
[weakdeps]
15+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
16+
17+
[extensions]
18+
TaylorDiffSFExt = ["SpecialFunctions"]
19+
1620
[compat]
1721
ChainRules = "1"
1822
ChainRulesCore = "1"

ext/TaylorDiffSFExt.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module TaylorDiffSFExt
2+
using TaylorDiff, SpecialFunctions
3+
using Symbolics: @variables
4+
using SymbolicUtils, SymbolicUtils.Code
5+
using SymbolicUtils: Pow
6+
using TaylorDiff: value, raise
7+
using ChainRules, ChainRulesCore
8+
9+
dummy = (NoTangent(), 1)
10+
@variables z
11+
for func in (erf, )
12+
F = typeof(func)
13+
# base case
14+
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
15+
t0, t1 = value(t)
16+
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
17+
end
18+
der = frule(dummy, func, z)[2]
19+
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
20+
# recursion by raising
21+
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
22+
der_expr = $(QuoteNode(toexpr(term)))
23+
f = $func
24+
quote
25+
$(Expr(:meta, :inline))
26+
z = TaylorScalar{T, N - 1}(t)
27+
df = $der_expr
28+
$$raiser($f(value(t)[1]), df, t)
29+
end
30+
end
31+
end
32+
33+
end

src/codegen.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using ChainRules
22
using ChainRulesCore
3-
using SpecialFunctions
4-
using IrrationalConstants: sqrtπ
53
using Symbolics: @variables
64
using SymbolicUtils, SymbolicUtils.Code
75
using SymbolicUtils: Pow
@@ -13,7 +11,7 @@ for func in (+, -, deg2rad, rad2deg,
1311
asin, acos, atan, asec, acsc, acot,
1412
log, log10, log1p, log2,
1513
asinh, acosh, atanh, asech, acsch,
16-
acoth, erf)
14+
acoth)
1715
F = typeof(func)
1816
# base case
1917
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}

0 commit comments

Comments
 (0)