Skip to content

Commit 6b34351

Browse files
committed
Add rrule for gemv
1 parent a6b0320 commit 6b34351

4 files changed

Lines changed: 51 additions & 2 deletions

File tree

benchmark/pinn.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using TaylorDiff
2+
3+
const input = 2
4+
const hidden = 16
5+
6+
struct PINN
7+
W₁
8+
b₁
9+
W₂
10+
b₂
11+
end
12+
13+
(pinn::PINN)(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * first(pinn.W₂ * (pinn.W₁ * x + pinn.b₁) + pinn.b₂)
14+
15+
dataset = [rand(input) for i in 1:10]
16+
function loss(pinn)
17+
out = 0.0
18+
for x in dataset
19+
out += derivative(pinn, x, [1., 0.], Val(2))
20+
end
21+
out
22+
end
23+
24+
# function simple(w)
25+
# derivative(x -> sum(w * x), [0.5, 0.7], [1., 0.], Val(2))
26+
# end
27+
# w = rand(hidden, input)
28+
# gradient(simple, w)
29+
30+
myPINN = PINN(rand(hidden, input), rand(hidden), rand(1, hidden), rand(1))
31+
32+
gradient(loss, myPINN)

src/chainrules.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ChainRulesCore: rrule, RuleConfig
22

3-
@opt_out rrule(::Any, ::TaylorScalar)
43
@opt_out rrule(::Any, ::TaylorScalar, ::TaylorScalar)
54
@opt_out rrule(::Any, ::TaylorScalar, ::Any)
65
@opt_out rrule(::typeof(*), ::TaylorScalar, ::TaylorScalar)
@@ -28,3 +27,18 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2827
end
2928
return extract_derivative(t, i), extract_derivative_pullback
3029
end
30+
31+
function rrule(::typeof(*), A::Matrix{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
32+
gemv_pullback(x̄) = NoTangent(), map(primal, x̄) * transpose(map(primal, t)), transpose(A) *
33+
return A * t, gemv_pullback
34+
end
35+
36+
function rrule(::typeof(+), v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
37+
vadd_pullback(x̄) = NoTangent(), map(primal, x̄), x̄
38+
return v + t, vadd_pullback
39+
end
40+
41+
function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
42+
vadd_pullback(x̄) = NoTangent(), x̄, map(primal, x̄)
43+
return t + v, vadd_pullback
44+
end

src/scalar.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ end
5555
@inline value(t::TaylorScalar) = t.value
5656
@inline extract_derivative(t::TaylorScalar, i::Integer) = t.value[i]
5757
@inline extract_derivative(r, i::Integer) = false
58+
@inline primal(t::TaylorScalar) = extract_derivative(t, 1)
5859

5960
@inline zero(::Type{TaylorScalar{T, N}}) where {T, N} = TaylorScalar{T, N}(zero(T))
6061
@inline one(::Type{TaylorScalar{T, N}}) where {T, N} = TaylorScalar{T, N}(one(T))

test/zygote.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,7 @@ using Zygote
22

33
@testset "Zygote compatibility" begin @test gradient(x -> derivative(x -> x * x, x, 1),
44
5.0)[1] 2.0
5-
# @test gradient(x -> derivative(g, x, 1, 1), [1., 2.])[1] ≈ [2., 0.]
5+
6+
g(x) = x[1] * x[1] + x[2] * x[2]
7+
@test gradient(x -> derivative(g, x, [1., 0.], 1), [1., 2.])[1] [2., 0.]
68
end

0 commit comments

Comments
 (0)