Commit 9b53db8

Improve coverage
1 parent b37aa31 commit 9b53db8

5 files changed: 50 additions & 38 deletions

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@ on:
   push:
     branches:
       - main
-    tags: '*'
+    tags: ['*']
   pull_request:
 concurrency:
   # Skip intermediate builds: always.
@@ -36,6 +36,7 @@ jobs:
       - uses: julia-actions/julia-processcoverage@v1
       - uses: codecov/codecov-action@v2
         with:
+          token: ${{ secrets.CODECOV_TOKEN }}
           files: lcov.info
   docs:
     name: Documentation

benchmark/linearmodel.jl

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ end
 data = rand(2)
 
 function loss(model)
-    derivative(model, data, [1., 0.], Val(2))
+    derivative(model, data, [1.0, 0.0], Val(2))
 end
 
 model = LinearModel(rand(1, 2))
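
Note (context, not part of the diff): derivative(model, data, [1.0, 0.0], Val(2)) is the directional-derivative call used throughout this benchmark, taking a second-order derivative of the model along the first input direction. A minimal, self-contained sketch of the same call on a throwaway function (f and x0 below are illustrative, not from the benchmark):

    using TaylorDiff

    # Second-order directional derivative of f at x0 along [1.0, 0.0],
    # i.e. along the first coordinate axis.
    f(x) = exp(x[1]) * x[2]^2
    x0 = [0.5, 0.3]
    d2 = derivative(f, x0, [1.0, 0.0], Val(2))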

benchmark/pinn.jl

Lines changed: 6 additions & 7 deletions
@@ -6,12 +6,10 @@ using Plots
 const input = 2
 const hidden = 16
 
-model = Chain(
-    Dense(input => hidden, sin),
-    Dense(hidden => hidden, sin),
-    Dense(hidden => 1),
-    first
-)
+model = Chain(Dense(input => hidden, sin),
+              Dense(hidden => hidden, sin),
+              Dense(hidden => 1),
+              first)
 trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
 
 M = 100
@@ -27,7 +25,8 @@ function loss_by_finitediff(model, x)
 end
 function loss_by_taylordiff(model, x)
     f(x) = trial(model, x)
-    error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) + sin(π * x[1]) * sin(π * x[2])
+    error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) +
+            sin(π * x[1]) * sin(π * x[2])
     abs2(error)
 end
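
For context (not part of the commit): loss_by_taylordiff assembles a Poisson-style PDE residual, with the two second-order directional derivatives along the coordinate axes playing the role of the Laplacian of the trial function and sin(π x₁) sin(π x₂) the source term. A minimal sketch of the same residual pattern on a hand-written trial function (u and x below are illustrative, not the benchmark's network):

    using TaylorDiff

    # Same zero-boundary polynomial factor that trial() uses, without the network part.
    u(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2])
    x = Float32[0.25, 0.75]
    residual = derivative(u, x, Float32[1, 0], 2) +   # second-order derivative along x₁
               derivative(u, x, Float32[0, 1], 2) +   # second-order derivative along x₂
               sin(π * x[1]) * sin(π * x[2])          # source term
    abs2(residual)                                    # squared residual at this collocation point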

src/chainrules.jl

Lines changed: 32 additions & 22 deletions
@@ -1,25 +1,23 @@
 import ChainRulesCore: rrule, RuleConfig, ProjectTo
 using ZygoteRules: @adjoint
 
-contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} = mapreduce(*, +, value(a), value(b))
-
-NONLINEAR_UNARY_FUNCTIONS = Function[
-    exp, exp2, exp10, expm1,
-    log, log2, log10, log1p,
-    inv, sqrt, cbrt,
-    sin, cos, tan, cot, sec, csc,
-    asin, acos, atan, acot, asec, acsc,
-    sinh, cosh, tanh, coth, sech, csch,
-    asinh, acosh, atanh, acoth, asech, acsch,
-]
+function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
+    mapreduce(*, +, value(a), value(b))
+end
+
+NONLINEAR_UNARY_FUNCTIONS = Function[exp, exp2, exp10, expm1,
+                                     log, log2, log10, log1p,
+                                     inv, sqrt, cbrt,
+                                     sin, cos, tan, cot, sec, csc,
+                                     asin, acos, atan, acot, asec, acsc,
+                                     sinh, cosh, tanh, coth, sech, csch,
+                                     asinh, acosh, atanh, acoth, asech, acsch]
 
 for func in NONLINEAR_UNARY_FUNCTIONS
     @eval @opt_out rrule(::typeof($func), ::TaylorScalar)
 end
 
-NONLINEAR_BINARY_FUNCTIONS = Function[
-    *, /, ^
-]
+NONLINEAR_BINARY_FUNCTIONS = Function[*, /, ^]
 
 for func in NONLINEAR_BINARY_FUNCTIONS
     @eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::TaylorScalar)
@@ -49,31 +47,43 @@ end
 function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
                i::Integer) where {N, T <: Number}
     function extract_derivative_pullback(d̄)
-        NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? one(T) : zero(T), Val(N))), NoTangent()
+        NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ? one(T) : zero(T), Val(N))),
+        NoTangent()
     end
     return extract_derivative(t, i), extract_derivative_pullback
 end
 
-function rrule(::typeof(*), A::Matrix{S}, t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T}
+function rrule(::typeof(*), A::Matrix{S},
+               t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T}
     project_A = ProjectTo(A)
-    gemv_pullback(x̄) = NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A) * x̄)
+    function gemv_pullback(x̄)
+        NoTangent(), @thunk(project_A(contract.(x̄, transpose(t)))), @thunk(transpose(A)*x̄)
+    end
     return A * t, gemv_pullback
 end
 
-function rrule(::typeof(+), v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
+function rrule(::typeof(+), v::Vector{T},
+               t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
     vadd_pullback(x̄) = NoTangent(), ProjectTo(v)(x̄), x̄
     return v + t, vadd_pullback
 end
 
-function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
+function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}},
+               v::Vector{T}) where {N, T <: Number}
     vadd_pullback(x̄) = NoTangent(), x̄, ProjectTo(v)(x̄)
     return t + v, vadd_pullback
 end
 
-@adjoint +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} = t + v, x̄ -> (x̄, map(primal, x̄))
+@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
+    t + v, x̄ -> (x̄, map(primal, x̄))
+end
 
-@adjoint +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} = v + t, x̄ -> (map(primal, x̄), x̄)
+@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
+    v + t, x̄ -> (map(primal, x̄), x̄)
+end
 
 (project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
 
-(project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real} = project(primal(dx))
+function (project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real}
+    project(primal(dx))
+end
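
Taken together (context, not part of the diff): these rrule methods and @opt_out declarations are what let a reverse-mode AD such as Zygote propagate gradients through TaylorDiff's forward-mode TaylorScalar values, so a gradient can be taken of a higher-order derivative. A small sketch of the pattern this enables, mirroring the test file below:

    using TaylorDiff, Zygote

    # Reverse mode (Zygote) over forward mode (TaylorDiff): differentiating a
    # second derivative gives the third derivative.
    g = gradient(x -> derivative(sin, x, 2), 0.7)[1]
    g ≈ derivative(sin, 0.7, 3)   # the first @testset in test/zygote.jl checks this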

test/zygote.jl

Lines changed: 9 additions & 7 deletions
@@ -3,22 +3,24 @@ using Zygote
 @testset "Zygote for mixed derivative" begin
     some_number = 0.7
     for f in (exp, log, sqrt, sin, asin, sinh, asinh)
-        @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈ derivative(f, some_number, 3)
+        @test gradient(x -> derivative(f, x, 2), some_number)[1] ≈
+              derivative(f, some_number, 3)
     end
     @test gradient(x -> derivative(x -> x * x, x, 1),
-                   5.0)[1] ≈ 2.0
+        5.0)[1] ≈ 2.0
 
     g(x) = x[1] * x[1] + x[2] * x[2]
-    @test gradient(x -> derivative(g, x, [1., 0.], 1), [1., 2.])[1] ≈ [2., 0.]
+    @test gradient(x -> derivative(g, x, [1.0, 0.0], 1), [1.0, 2.0])[1] ≈ [2.0, 0.0]
 end
 
 @testset "Zygote for parameter optimization" begin
     linear_model(x, p) = exp.(p * x)[1]
-    some_x, some_v, some_p = [.58, .36], [.23, .11], [.49 .96]
+    some_x, some_v, some_p = [0.58, 0.36], [0.23, 0.11], [0.49 0.96]
     loss_taylor(p) = derivative(x -> linear_model(x, p), some_x, some_v, 1)
     ε = cbrt(eps(Float64))
-    loss_finite(p) = let f = x -> linear_model(x, p)
-        (f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
-    end
+    loss_finite(p) =
+        let f = x -> linear_model(x, p)
+            (f(some_x + ε * some_v) - f(some_x - ε * some_v)) / 2ε
+        end
     @test gradient(loss_taylor, some_p)[1] ≈ gradient(loss_finite, some_p)[1]
 end
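
One small aside on loss_finite (context, not part of the diff): ε = cbrt(eps(Float64)) is the usual step size for a central difference, since the truncation error grows roughly like ε² while round-off error grows like eps/ε, and the two balance when ε is on the order of eps^(1/3):

    ε = cbrt(eps(Float64))   # ≈ 6.06e-6, balancing truncation against round-off error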
