Skip to content

Commit 770c190

Browse files
zhujch1tansongchen
authored andcommitted
Format
1 parent 9a037b3 commit 770c190

7 files changed

Lines changed: 21 additions & 19 deletions

File tree

benchmark/mlp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function create_benchmark_mlp(mlp_conf::Tuple{Int, Int}, x::Vector{T},
2-
l::Vector{T}) where {T <: Number}
2+
l::Vector{T}) where {T <: Number}
33
input, hidden = mlp_conf
44
W₁, W₂, b₁, b₂ = rand(hidden, input), rand(1, hidden), rand(hidden), rand(1)
55
σ = exp

ext/TaylorDiffSFExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using ChainRules, ChainRulesCore
88

99
dummy = (NoTangent(), 1)
1010
@variables z
11-
for func in (erf, )
11+
for func in (erf,)
1212
F = typeof(func)
1313
# base case
1414
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
@@ -30,4 +30,4 @@ for func in (erf, )
3030
end
3131
end
3232

33-
end
33+
end

src/chainrules.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
2222
end
2323

2424
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
25-
i::Integer) where {N, T}
25+
i::Integer) where {N, T}
2626
function extract_derivative_pullback(d̄)
2727
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
2828
NoTangent()
@@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
3131
end
3232

3333
function rrule(::typeof(*), A::AbstractMatrix{S},
34-
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
34+
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
3535
project_A = ProjectTo(A)
3636
function gemv_pullback(x̄)
3737
= reinterpret(reshape, T, x̄)
@@ -42,12 +42,14 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
4242
end
4343

4444
function rrule(::typeof(*), A::AbstractMatrix{S},
45-
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
45+
B::AbstractMatrix{TaylorScalar{T, N}}) where {N, S, T}
4646
project_A = ProjectTo(A)
4747
project_B = ProjectTo(B)
4848
function gemm_pullback(x̄)
4949
= unthunk(x̄)
50-
NoTangent(), @thunk(project_A(X̄ * transpose(B))), @thunk(project_B(transpose(A) * X̄))
50+
NoTangent(),
51+
@thunk(project_A(X̄ * transpose(B))),
52+
@thunk(project_B(transpose(A) * X̄))
5153
end
5254
return A * B, gemm_pullback
5355
end
@@ -85,8 +87,8 @@ struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
8587
ind::I
8688
axes::A
8789
function TaylorOneElement(val::T, ind::I,
88-
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
89-
A <: NTuple{N, AbstractUnitRange}} where {N}
90+
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
91+
A <: NTuple{N, AbstractUnitRange}} where {N}
9092
new{T, N, I, A}(val, ind, axes)
9193
end
9294
end
@@ -125,7 +127,7 @@ function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
125127
end
126128

127129
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
128-
more::TaylorScalar...)
130+
more::TaylorScalar...)
129131
Ω2, back2 = rrule(*, x, y)
130132
Ω3, back3 = rrule(*, Ω2, z)
131133
Ω4, back4 = rrule(*, Ω3, more...)

src/derivative.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t
4646
end
4747

4848
@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
49-
vN::Val{N}) where {T <: TN, S <: TN, N}
49+
vN::Val{N}) where {T <: TN, S <: TN, N}
5050
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
5151
# equivalent to map(TaylorScalar{T, N}, x, l)
5252
return extract_derivative(f(t), N)
@@ -61,7 +61,7 @@ end
6161
end
6262

6363
@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
64-
vN::Val{N}) where {T <: TN, S <: TN, N}
64+
vN::Val{N}) where {T <: TN, S <: TN, N}
6565
t = make_taylor.(x, l, vN)
6666
return extract_derivative.(f(t), N)
6767
end

src/primitive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ end
148148
end
149149

150150
@generated function raise(f::T, df::TaylorScalar{T, M},
151-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
151+
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
152152
return quote
153153
$(Expr(:meta, :inline))
154154
vdf, vt = value(df), value(t)
@@ -162,7 +162,7 @@ end
162162
raise(::T, df::S, t::TaylorScalar{T, N}) where {S <: Number, T, N} = df * t
163163

164164
@generated function raiseinv(f::T, df::TaylorScalar{T, M},
165-
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
165+
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
166166
ex = quote
167167
vdf, vt = value(df), value(t)
168168
v1 = vt[2] / vdf[1]

src/scalar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
@inline value(t::TaylorScalar) = t.value
5959
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i]
6060
@inline function extract_derivative(v::AbstractArray{T},
61-
i::Integer) where {T <: TaylorScalar}
61+
i::Integer) where {T <: TaylorScalar}
6262
map(t -> extract_derivative(t, i), v)
6363
end
6464
@inline extract_derivative(r, i::Integer) = false
@@ -73,7 +73,7 @@ adjoint(t::TaylorScalar) = t
7373
conj(t::TaylorScalar) = t
7474

7575
function promote_rule(::Type{TaylorScalar{T, N}},
76-
::Type{S}) where {T, S, N}
76+
::Type{S}) where {T, S, N}
7777
TaylorScalar{promote_type(T, S), N}
7878
end
7979

test/derivative.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
g(x) = x^3
44
@test derivative(g, 1.0, 1) 3
55

6-
h(x) = x.^3
6+
h(x) = x .^ 3
77
@test derivative(h, [2.0 3.0], 1) [12.0 27.0]
88
end
99

1010
@testset "Directional derivative" begin
1111
g(x) = x[1] * x[1] + x[2] * x[2]
1212
@test derivative(g, [1.0, 2.0], [1.0, 0.0], 1) 2.0
1313

14-
h(x) = sum(x, dims=1)
15-
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2. 2.]
14+
h(x) = sum(x, dims = 1)
15+
@test derivative(h, [1.0 2.0; 2.0 3.0], [1.0, 1.0], 1) [2.0 2.0]
1616
end

0 commit comments

Comments
 (0)