Skip to content

Commit 36737d5

Browse files
committed
Experiment with Zygote codegen
1 parent b1aaf45 commit 36737d5

4 files changed

Lines changed: 26 additions & 11 deletions

File tree

benchmark/pinn.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ using Plots
66
const input = 2
77
const hidden = 16
88

9-
model = Chain(Dense(input => hidden, sin),
10-
Dense(hidden => hidden, sin),
11-
Dense(hidden => 1),
12-
first)
13-
trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
9+
# model = Chain(Dense(input => hidden, exp),
10+
# Dense(hidden => hidden, exp),
11+
# Dense(hidden => 1),
12+
# first)
13+
# trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
14+
15+
model = Chain(Dense(input => 1, exp), first)
16+
trial(model, x) = model(x)
1417

1518
M = 100
1619
data = [rand(Float32, input) for _ in 1:M]
@@ -25,7 +28,7 @@ function loss_by_finitediff(model, x)
2528
end
2629
function loss_by_taylordiff(model, x)
2730
f(x) = trial(model, x)
28-
error = derivative(f, x, Float32[1, 0], 2) + derivative(f, x, Float32[0, 1], 2) +
31+
error = derivative(f, x, Float32[1, 0], Val(3)) + derivative(f, x, Float32[0, 1], Val(3)) +
2932
sin* x[1]) * sin* x[2])
3033
abs2(error)
3134
end

src/chainrules.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import ChainRulesCore: rrule, RuleConfig, ProjectTo
1+
import ChainRulesCore: rrule, RuleConfig, ProjectTo, backing
22
using ZygoteRules: @adjoint
33

44
function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
@@ -39,7 +39,8 @@ end
3939
function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
4040
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(v̄)
4141
# for structural tangent, convert to tuple
42-
value_pullback(v̄) = NoTangent(), TaylorScalar(map(x -> convert(T, x), Tuple(v̄)))
42+
value_pullback(v̄::Tangent{P, NTuple{N, T}}) where P = NoTangent(), TaylorScalar{T, N}(backing(v̄))
43+
value_pullback(v̄) = NoTangent(), TaylorScalar{T, N}(map(x -> convert(T, x), Tuple(v̄)))
4344
return value(t), value_pullback
4445
end
4546

src/derivative.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ end
2828
return extract_derivative(f(t), N)
2929
end
3030

31+
# Need to rewrite like this to help Zygote infer types
32+
make_taylor(t0::T, t1::T, ::Val{N}) where {T, N} = TaylorScalar{T, N}(t0, t1)
33+
3134
@inline function derivative(f, x::Vector{T}, l::Vector{T},
32-
::Val{N}) where {T <: Number, N}
33-
t = map(TaylorScalar{T, N}, x, l)
35+
vN::Val{N}) where {T <: Number, N}
36+
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l) # i.e. map(TaylorScalar{T, N}, x, l)
3437
return extract_derivative(f(t), N)
3538
end

src/primitive.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ import Base: hypot, max, min
1515
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
1616
@inline inv(t::TaylorScalar) = one(t) / t
1717

18+
exp(t::TaylorScalar{T, 2}) where T = let v = value(t), e1 = exp(v[1])
19+
TaylorScalar{T, 2}((e1, e1 * v[2]))
20+
end
21+
22+
exp(t::TaylorScalar{T, 3}) where T = let v = value(t), e1 = exp(v[1])
23+
TaylorScalar{T, 3}((e1, e1 * v[2], e1 * v[3] + e1 * v[2] * v[2]))
24+
end
25+
1826
for func in (:exp, :expm1, :exp2, :exp10)
1927
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}
2028
ex = quote
@@ -37,7 +45,7 @@ for func in (:exp, :expm1, :exp2, :exp10)
3745
if $(QuoteNode(func)) == :expm1
3846
ex = :($ex; v1 = expm1(v[1]))
3947
end
40-
ex = :($ex; TaylorScalar($([Symbol('v', i) for i in 1:N]...)))
48+
ex = :($ex; TaylorScalar{T, N}(tuple($([Symbol('v', i) for i in 1:N]...))))
4149
return :(@inbounds $ex)
4250
end
4351
end

0 commit comments

Comments
 (0)