Skip to content

Commit b1aaf45

Browse files
committed
Improve coverage
1 parent 9b53db8 commit b1aaf45

3 files changed

Lines changed: 10 additions & 24 deletions

File tree

src/chainrules.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ end
3838

3939
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
4040
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
41-
value_pullback(v̄::Tuple) = NoTangent(), TaylorScalar(map(x -> convert(T, x), v̄))
4241
# for structural tangent, convert to tuple
4342
value_pullback(v̄) = NoTangent(), TaylorScalar(map(x -> convert(T, x), Tuple(v̄)))
4443
return value(t), value_pullback
@@ -62,28 +61,14 @@ function rrule(::typeof(*), A::Matrix{S},
6261
return A * t, gemv_pullback
6362
end
6463

65-
function rrule(::typeof(+), v::Vector{T},
66-
t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
67-
vadd_pullback(x̄) = NoTangent(), ProjectTo(v)(x̄), x̄
68-
return v + t, vadd_pullback
69-
end
70-
71-
function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}},
72-
v::Vector{T}) where {N, T <: Number}
73-
vadd_pullback(x̄) = NoTangent(), x̄, ProjectTo(v)(x̄)
74-
return t + v, vadd_pullback
75-
end
76-
7764
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
78-
t + v, x̄ -> (x̄, map(primal, x̄))
65+
project_v = ProjectTo(v)
66+
t + v, x̄ -> (x̄, project_v(x̄))
7967
end
8068

8169
@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
82-
v + t, x̄ -> (map(primal, x̄), x̄)
70+
project_v = ProjectTo(v)
71+
v + t, x̄ -> (project_v(x̄), x̄)
8372
end
8473

8574
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
86-
87-
function (project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real}
88-
project(primal(dx))
89-
end

src/primitive.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import Base: hypot, max, min
1414
@inline sqrt(t::TaylorScalar) = t^0.5
1515
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
1616
@inline inv(t::TaylorScalar) = one(t) / t
17-
@inline abs(t::TaylorScalar) = primal(t) >= 0 ? t : -t
1817

1918
for func in (:exp, :expm1, :exp2, :exp10)
2019
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}

test/zygote.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ using Zygote
1414
end
1515

1616
@testset "Zygote for parameter optimization" begin
17-
linear_model(x, p) = exp.(p * x)[1]
18-
some_x, some_v, some_p = [0.58, 0.36], [0.23, 0.11], [0.49 0.96]
19-
loss_taylor(p) = derivative(x -> linear_model(x, p), some_x, some_v, 1)
17+
gradient(p -> derivative(x -> sum(exp.(x + p)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
18+
gradient(p -> derivative(x -> sum(exp.(p + x)), [1.0, 1.0], [1.0, 0.0], 1), [0.5, 0.7])
19+
linear_model(x, p, b) = exp.(b + p * x + b)[1]
20+
some_x, some_v, some_p, some_b = [0.58, 0.36], [0.23, 0.11], [0.49 0.96], [0.88]
21+
loss_taylor(p) = derivative(x -> linear_model(x, p, some_b), some_x, some_v, 1)
2022
ε = cbrt(eps(Float64))
2123
loss_finite(p) =
22-
let f = x -> linear_model(x, p)
24+
let f = x -> linear_model(x, p, some_b)
2325
(f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
2426
end
2527
@test gradient(loss_taylor, some_p)[1] ≈ gradient(loss_finite, some_p)[1]

0 commit comments

Comments (0)