Skip to content

Commit 6f0fcd5

Browse files
committed
Remove redundant frule definitions
1 parent 61326ef commit 6f0fcd5

2 files changed

Lines changed: 8 additions & 30 deletions

File tree

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1111
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
1212
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1313
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
14+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1415
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1516

1617
[compat]
1718
ChainRules = "1"
1819
ChainRulesCore = "1"
1920
ChainRulesOverloadGeneration = "0.1"
21+
IrrationalConstants = "0.2"
2022
SliceMap = "0.2"
2123
SpecialFunctions = "2"
22-
IrrationalConstants = "0.2"
2324
SymbolicUtils = "1"
2425
Zygote = "0.6.55"
2526
julia = "1.6"

src/codegen.jl

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,13 @@
1+
using ChainRules
12
using ChainRulesCore
23
using SpecialFunctions
34
using IrrationalConstants: sqrtπ
5+
using Symbolics: @variables
46
using SymbolicUtils, SymbolicUtils.Code
5-
using SymbolicUtils: BasicSymbolic, Pow
6-
7-
@scalar_rule +(x::BasicSymbolic) true
8-
@scalar_rule -(x::BasicSymbolic) -1
9-
@scalar_rule deg2rad(x::BasicSymbolic) deg2rad(one(x))
10-
@scalar_rule rad2deg(x::BasicSymbolic) rad2deg(one(x))
11-
@scalar_rule asin(x::BasicSymbolic) inv(sqrt(1 - x^2))
12-
@scalar_rule acos(x::BasicSymbolic) inv(-sqrt(1 - x^2))
13-
@scalar_rule atan(x::BasicSymbolic) inv(-(1 + x^2))
14-
@scalar_rule acot(x::BasicSymbolic) inv(-(1 + x^2))
15-
@scalar_rule acsc(x::BasicSymbolic) inv(x^2 * -sqrt(1 - x^-2))
16-
@scalar_rule asec(x::BasicSymbolic) inv(x^2 * sqrt(1 - x^-2))
17-
@scalar_rule log(x::BasicSymbolic) inv(x)
18-
@scalar_rule log10(x::BasicSymbolic) inv(log(10.0) * x)
19-
@scalar_rule log1p(x::BasicSymbolic) inv(x + 1)
20-
@scalar_rule log2(x::BasicSymbolic) inv(log(2.0) * x)
21-
@scalar_rule sinh(x::BasicSymbolic) cosh(x)
22-
@scalar_rule cosh(x::BasicSymbolic) sinh(x)
23-
@scalar_rule tanh(x::BasicSymbolic) 1-Ω^2
24-
@scalar_rule acosh(x::BasicSymbolic) inv(sqrt(x - 1) * sqrt(x + 1))
25-
@scalar_rule acoth(x::BasicSymbolic) inv(1 - x^2)
26-
@scalar_rule acsch(x::BasicSymbolic) inv(x^2 * -sqrt(1 + x^-2))
27-
@scalar_rule asech(x::BasicSymbolic) inv(x * -sqrt(1 - x^2))
28-
@scalar_rule asinh(x::BasicSymbolic) inv(sqrt(x^2 + 1))
29-
@scalar_rule atanh(x::BasicSymbolic) inv(1 - x^2)
30-
@scalar_rule erf(x::BasicSymbolic) exp(-x^2) * 2/sqrtπ
7+
using SymbolicUtils: Pow
318

329
dummy = (NoTangent(), 1)
33-
@syms t₁
10+
@variables z
3411
for func in (+, -, deg2rad, rad2deg,
3512
sinh, cosh, tanh,
3613
asin, acos, atan, asec, acsc, acot,
@@ -43,15 +20,15 @@ for func in (+, -, deg2rad, rad2deg,
4320
t0, t1 = value(t)
4421
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
4522
end
46-
der = frule(dummy, func, t₁)[2]
23+
der = frule(dummy, func, z)[2]
4724
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
4825
# recursion by raising
4926
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
5027
der_expr = $(QuoteNode(toexpr(term)))
5128
f = $func
5229
quote
5330
$(Expr(:meta, :inline))
54-
t₁ = TaylorScalar{T, N - 1}(t)
31+
z = TaylorScalar{T, N - 1}(t)
5532
df = $der_expr
5633
$$raiser($f(value(t)[1]), df, t)
5734
end

0 commit comments

Comments
 (0)