Skip to content

Commit 08c8279

Browse files
committed
Fix gradient for gemv and vadd
1 parent 6b34351 commit 08c8279

4 files changed

Lines changed: 56 additions & 13 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
1010
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
11+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1112

1213
[compat]
1314
ChainRules = "1"
1415
ChainRulesCore = "1"
1516
ChainRulesOverloadGeneration = "0.1"
1617
SymbolicUtils = "0.19, 0.20"
18+
ZygoteRules = "0.2"
1719
julia = "1.6"

benchmark/linearmodel.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using TaylorDiff, Zygote
2+
using LinearAlgebra: dot
3+
4+
sigmoid(x) = (one(x) + tanh(x)) / 2
5+
6+
struct LinearModel
7+
c::Matrix
8+
end
9+
10+
(m::LinearModel)(x) = sigmoid((m.c * x)[1])
11+
12+
data = rand(2)
13+
14+
function loss(model)
15+
derivative(model, data, [1., 0.], Val(2))
16+
end
17+
18+
model = LinearModel(rand(1, 2))
19+
gradient(loss, model)

benchmark/pinn.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using TaylorDiff
1+
using TaylorDiff, Zygote
22

33
const input = 2
44
const hidden = 16
@@ -10,7 +10,7 @@ struct PINN
1010
b₂
1111
end
1212

13-
(pinn::PINN)(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * first(pinn.W₂ * (pinn.W₁ * x + pinn.b₁) + pinn.b₂)
13+
(pinn::PINN)(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * first(pinn.W₂ * exp.(pinn.W₁ * x + pinn.b₁) + pinn.b₂)
1414

1515
dataset = [rand(input) for i in 1:10]
1616
function loss(pinn)
@@ -21,12 +21,6 @@ function loss(pinn)
2121
out
2222
end
2323

24-
# function simple(w)
25-
# derivative(x -> sum(w * x), [0.5, 0.7], [1., 0.], Val(2))
26-
# end
27-
# w = rand(hidden, input)
28-
# gradient(simple, w)
29-
3024
myPINN = PINN(rand(hidden, input), rand(hidden), rand(1, hidden), rand(1))
3125

3226
gradient(loss, myPINN)

src/chainrules.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,33 @@
11
import ChainRulesCore: rrule, RuleConfig
2+
using ZygoteRules: @adjoint
3+
4+
contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} = mapreduce(*, +, value(a), value(b))
5+
6+
NONLINEAR_UNARY_FUNCTIONS = Function[
7+
exp, exp2, exp10, expm1,
8+
log, log2, log10, log1p,
9+
sin, cos, tan, cot, sec, csc,
10+
asin, acos, atan, acot, asec, acsc,
11+
sinh, cosh, tanh, coth, sech, csch,
12+
asinh, acosh, atanh, acoth, asech, acsch,
13+
]
14+
15+
for func in NONLINEAR_UNARY_FUNCTIONS
16+
@eval @opt_out rrule(::typeof($func), ::TaylorScalar)
17+
end
18+
19+
NONLINEAR_BINARY_FUNCTIONS = Function[
20+
*, /, ^
21+
]
22+
23+
for func in NONLINEAR_BINARY_FUNCTIONS
24+
@eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::TaylorScalar)
25+
@eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::Number)
26+
@eval @opt_out rrule(::typeof($func), ::Number, ::TaylorScalar)
27+
end
28+
29+
# Other special cases
230

3-
@opt_out rrule(::Any, ::TaylorScalar, ::TaylorScalar)
4-
@opt_out rrule(::Any, ::TaylorScalar, ::Any)
5-
@opt_out rrule(::typeof(*), ::TaylorScalar, ::TaylorScalar)
6-
@opt_out rrule(::typeof(^), ::TaylorScalar, ::Any)
731
@opt_out rrule(::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar, ::Val{p}) where {p}
832
@opt_out rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar,
933
::Val{p}) where {p}
@@ -29,7 +53,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
2953
end
3054

3155
function rrule(::typeof(*), A::Matrix{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
32-
gemv_pullback(x̄) = NoTangent(), map(primal, x̄) * transpose(map(primal, t)), transpose(A) *
56+
gemv_pullback(x̄) = NoTangent(), contract.(x̄, transpose(t)), transpose(A) *
3357
return A * t, gemv_pullback
3458
end
3559

@@ -42,3 +66,7 @@ function rrule(::typeof(+), t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {
4266
vadd_pullback(x̄) = NoTangent(), x̄, map(primal, x̄)
4367
return t + v, vadd_pullback
4468
end
69+
70+
@adjoint +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} = t + v, x̄ -> (x̄, map(primal, x̄))
71+
72+
@adjoint +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} = v + t, x̄ -> (map(primal, x̄), x̄)

0 commit comments

Comments
 (0)