Skip to content

Commit 37c8657

Browse files
committed
Add more primitives
1 parent fb023e6 commit 37c8657

4 files changed

Lines changed: 30 additions & 10 deletions

File tree

src/codegen.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,31 @@ using SymbolicUtils: Pow
88
@scalar_rule rad2deg(x::Any) rad2deg(one(x))
99
@scalar_rule asin(x::Any) inv(sqrt(1 - x^2))
1010
@scalar_rule acos(x::Any) inv(-sqrt(1 - x^2))
11-
@scalar_rule atan(x::Any) inv(1 + x^2)
11+
@scalar_rule atan(x::Any) inv(-(1 + x^2))
12+
@scalar_rule acot(x::Any) inv(-(1 + x^2))
13+
@scalar_rule acsc(x::Any) inv(x^2 * -sqrt(1 - x^-2))
14+
@scalar_rule asec(x::Any) inv(x^2 * sqrt(1 - x^-2))
1215
@scalar_rule log(x::Any) inv(x)
1316
@scalar_rule log10(x::Any) inv(log(10.0) * x)
1417
@scalar_rule log1p(x::Any) inv(x + 1)
1518
@scalar_rule log2(x::Any) inv(log(2.0) * x)
19+
@scalar_rule sinh(x::Any) cosh(x)
20+
@scalar_rule cosh(x::Any) sinh(x)
21+
@scalar_rule tanh(x::Any) 1-Ω^2
22+
@scalar_rule acosh(x::Any) inv(sqrt(x - 1) * sqrt(x + 1))
23+
@scalar_rule acoth(x::Any) inv(1 - x^2)
24+
@scalar_rule acsch(x::Any) inv(x^2 * -sqrt(1 + x^-2))
25+
@scalar_rule asech(x::Any) inv(x * -sqrt(1 - x^2))
26+
@scalar_rule asinh(x::Any) inv(sqrt(x^2 + 1))
27+
@scalar_rule atanh(x::Any) inv(1 - x^2)
1628

1729
dummy = (NoTangent(), 1)
1830
@syms t₁
19-
for func in (+, -, deg2rad, rad2deg, asin, acos, atan, log, log10, log1p, log2)
31+
for func in (+, -, deg2rad, rad2deg,
32+
sinh, cosh, tanh,
33+
asin, acos, atan, asec, acsc, acot,
34+
log, log10, log1p, log2,
35+
asinh, acosh, atanh, asech, acsch, acoth)
2036
F = typeof(func)
2137
# base case
2238
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}

src/primitive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575

7676
for op in [:>, :<, :(==), :(>=), :(<=)]
7777
@eval @inline $op(a::Number, b::TaylorScalar) = $op(a, value(b)[1])
78-
@eval @inline $op(a::TaylorScalar, b::Number) = $op(a, value(b)[1])
78+
@eval @inline $op(a::TaylorScalar, b::Number) = $op(value(a)[1], b)
7979
@eval @inline $op(a::TaylorScalar, b::TaylorScalar) = $op(value(a)[1], value(b)[1])
8080
end
8181

@@ -158,7 +158,7 @@ end
158158
end
159159
end
160160

161-
raise(::T, df::S, t::TaylorScalar{T}) where {S <: Number, T <: Number} = df * t
161+
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Real, T <: Number, N} = df * t
162162

163163
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
164164
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N

src/scalar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ end
5454

5555
@inline value(t::TaylorScalar) = t.value
5656
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i]
57+
@inline extract_derivative(r, i::Integer) = false
5758

5859
@inline zero(::Type{TaylorScalar{T, N}}) where {T, N} = TaylorScalar{T, N}(zero(T))
5960
@inline one(::Type{TaylorScalar{T, N}}) where {T, N} = TaylorScalar{T, N}(one(T))

test/primitive.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,34 @@
11
using FiniteDifferences
22

33
@testset "No derivative or linear" begin
4-
some_number = 3.7
5-
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg), order in (2, 4)
4+
some_number, another_number = 1.9, 2.6
5+
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg), order in (2,)
66
@test derivative(f, some_number, order) 0.0
77
end
8+
for f in (+, -, <, <=, >, >=, ==), order in (2,)
9+
@test derivative(x -> f(x, another_number), some_number, order) 0.0
10+
end
811
end
912

1013
@testset "Unary functions" begin
1114
some_number = 3.7
12-
for f in (exp, expm1, exp2, exp10, sin, cos, sqrt, cbrt, inv), order in (2, 4)
15+
for f in (exp, expm1, exp2, exp10, sin, cos, sqrt, cbrt, inv), order in (1, 4)
1316
fdm = central_fdm(12, order)
1417
@test derivative(f, some_number, order)fdm(f, some_number) rtol=1e-6
1518
end
1619
end
1720

1821
@testset "Codegen" begin
1922
some_number = 0.6
20-
for f in (log, ), order in (2, 4)
21-
fdm = central_fdm(12, order, max_range=0.5)
23+
for f in (log, sinh), order in (1, 4)
24+
fdm = central_fdm(12, order, max_range = 0.5)
2225
@test derivative(f, some_number, order)fdm(f, some_number) rtol=1e-6
2326
end
2427
end
2528

2629
@testset "Binary functions" begin
2730
some_number, another_number = 1.9, 2.6
28-
for f in (*, /), order in (2, 4)
31+
for f in (*, /), order in (1, 4)
2932
fdm = central_fdm(12, order)
3033
closure = x -> exp(f(x, another_number))
3134
@test derivative(closure, some_number, order)fdm(closure, some_number) rtol=1e-6

0 commit comments

Comments
 (0)