Skip to content

Commit 6e12560

Browse files
committed
"NNlib Extension"
1 parent 770c190 commit 6e12560

5 files changed

Lines changed: 26 additions & 2 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,19 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1212
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1313

1414
[weakdeps]
15+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1516
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1617

1718
[extensions]
19+
TaylorDiffNNlibExt = ["NNlib"]
1820
TaylorDiffSFExt = ["SpecialFunctions"]
1921

2022
[compat]
2123
ChainRules = "1"
2224
ChainRulesCore = "1"
2325
ChainRulesOverloadGeneration = "0.1"
2426
SpecialFunctions = "2"
27+
NNlib = "0.9"
2528
SymbolicUtils = "1"
2629
Symbolics = "5"
2730
Zygote = "0.6.55"

ext/TaylorDiffNNlibExt.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
module TaylorDiffNNlibExt
2+
3+
using TaylorDiff
4+
import NNlib: oftf
5+
import NNlib: sigmoid_fast, tanh_fast, rrelu, leakyrelu
6+
7+
println("revise!")
8+
9+
@inline sigmoid_fast(t::TaylorScalar) = one(t) / (one(t) + exp(-t))
10+
11+
@inline tanh_fast(t::TaylorScalar) = tanh(t)
12+
13+
@inline function rrelu(t::TaylorScalar{T, N}, l=oftf(t, 1/8), u=oftf(t, 1/3)) where {T, N}
14+
a = (u - l) * rand(float(T)) + l
15+
return leakyrelu(t, a)
16+
end
17+
18+
end

src/codegen.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ for func in (+, -, deg2rad, rad2deg,
1111
asin, acos, atan, asec, acsc, acot,
1212
log, log10, log1p, log2,
1313
asinh, acosh, atanh, asech, acsch,
14-
acoth)
14+
acoth,
15+
abs, sign)
1516
F = typeof(func)
1617
# base case
1718
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}

src/derivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838
# Core APIs
3939

4040
# Added to help Zygote infer types
41-
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1))
41+
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, convert(T, t1))
4242

4343
@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N}
4444
t = TaylorScalar{T, N}(x, one(x))

src/scalar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,5 @@ for op in (:+, :-, :*, :/)
8888
@eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)
8989
end
9090
transpose(t::TaylorScalar) = t
91+
92+
Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N} = TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value))

0 commit comments

Comments
 (0)