Skip to content

Commit b37aa31

Browse files
committed
Add tests for some rrules
1 parent a3e588a commit b37aa31

4 files changed

Lines changed: 28 additions & 11 deletions

File tree

benchmark/pinn.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ model = Chain(
1515
trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
1616

1717
M = 100
18-
data = [rand(input) for _ in 1:M]
18+
data = [rand(Float32, input) for _ in 1:M]
1919
function loss_by_finitediff(model, x)
2020
ε = cbrt(eps(Float32))
2121
ε₁ = [ε, 0]
@@ -27,7 +27,7 @@ function loss_by_finitediff(model, x)
2727
end
2828
function loss_by_taylordiff(model, x)
2929
f(x) = trial(model, x)
30-
error = derivative(f, x, [1., 0.], 2) + derivative(f, x, [0., 1.], 2) + sin(π * x[1]) * sin(π * x[2])
30+
error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) + sin(π * x[1]) * sin(π * x[2])
3131
abs2(error)
3232
end
3333

src/chainrules.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,35 @@ function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Numbe
3838
return TaylorScalar(v), taylor_scalar_pullback
3939
end
4040

41-
function rrule(::typeof(value), t::TaylorScalar)
42-
value_pullback(v̄::NTuple) = NoTangent(), TaylorScalar(v̄)
41+
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
42+
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
43+
value_pullback(v̄::Tuple) = NoTangent(), TaylorScalar(map(x -> convert(T, x), v̄))
4344
# for structural tangent, convert to tuple
44-
value_pullback(v̄) = NoTangent(), TaylorScalar(Tuple(v̄))
45+
value_pullback(v̄) = NoTangent(), TaylorScalar(map(x -> convert(T, x), Tuple(v̄)))
4546
return value(t), value_pullback
4647
end
4748

4849
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
4950
i::Integer) where {N, T <: Number}
5051
function extract_derivative_pullback(d̄)
51-
NoTangent(), TaylorScalar((zeros(T, i - 1)..., d̄, zeros(T, N - i)...)), NoTangent()
52+
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? d̄ : zero(T), Val(N))), NoTangent()
5253
end
5354
return extract_derivative(t, i), extract_derivative_pullback
5455
end
5556

5657
function rrule(::typeof(*), A::Matrix{S}, t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T}
5758
project_A = ProjectTo(A)
58-
gemv_pullback(x̄) = NoTangent(), project_A(contract.(x̄, transpose(t))), transpose(A) * x̄
59+
gemv_pullback(x̄) = NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A) * x̄)
5960
return A * t, gemv_pullback
6061
end
6162

6263
function rrule(::typeof(+), v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
63-
vadd_pullback(x̄) = NoTangent(), map(primal, x̄), x̄
64+
vadd_pullback(x̄) = NoTangent(), ProjectTo(v)(x̄), x̄
6465
return v + t, vadd_pullback
6566
end
6667

6768
function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
68-
vadd_pullback(x̄) = NoTangent(), x̄, map(primal, x̄)
69+
vadd_pullback(x̄) = NoTangent(), x̄, ProjectTo(v)(x̄)
6970
return t + v, vadd_pullback
7071
end
7172

src/primitive.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import Base: hypot, max, min
1313

1414
@inline sqrt(t::TaylorScalar) = t^0.5
1515
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
16-
@inline inv(t::TaylorScalar) = 1 / t
16+
@inline inv(t::TaylorScalar) = one(t) / t
1717
@inline abs(t::TaylorScalar) = primal(t) >= 0 ? t : -t
1818

1919
for func in (:exp, :expm1, :exp2, :exp10)

test/zygote.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
using Zygote
22

3-
@testset "Zygote compatibility" begin @test gradient(x -> derivative(x -> x * x, x, 1),
3+
@testset "Zygote for mixed derivative" begin
4+
some_number = 0.7
5+
for f in (exp, log, sqrt, sin, asin, sinh, asinh)
6+
@test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3)
7+
end
8+
@test gradient(x -> derivative(x -> x * x, x, 1),
49
5.0)[1] ≈ 2.0
510

611
g(x) = x[1] * x[1] + x[2] * x[2]
712
@test gradient(x -> derivative(g, x, [1., 0.], 1), [1., 2.])[1] ≈ [2., 0.]
813
end
14+
15+
@testset "Zygote for parameter optimization" begin
16+
linear_model(x, p) = exp.(p * x)[1]
17+
some_x, some_v, some_p = [.58, .36], [.23, .11], [.49 .96]
18+
loss_taylor(p) = derivative(x -> linear_model(x, p), some_x, some_v, 1)
19+
ε = cbrt(eps(Float64))
20+
loss_finite(p) = let f = x -> linear_model(x, p)
21+
(f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
22+
end
23+
@test gradient(loss_taylor, some_p)[1] ≈ gradient(loss_finite, some_p)[1]
24+
end

0 commit comments

Comments
 (0)