Skip to content

Commit 9a1f965

Browse files
committed
Add PINN example
1 parent 08c8279 commit 9a1f965

4 files changed

Lines changed: 52 additions & 20 deletions

File tree

benchmark/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
56
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
67
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
78
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
8-
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
99
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
1010
TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
1111
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

benchmark/pinn.jl

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,51 @@
1-
using TaylorDiff, Zygote
1+
using Flux
2+
using ChainRulesCore: @opt_out
3+
using TaylorDiff
4+
using Zygote
5+
using Plots
26

37
const input = 2
48
const hidden = 16
59

6-
struct PINN
7-
W₁
8-
b₁
9-
W₂
10-
b₂
10+
model = Chain(
11+
Dense(input => hidden, sin),
12+
Dense(hidden => hidden, sin),
13+
Dense(hidden => 1),
14+
first
15+
)
16+
trial(model, x) = model(x)
17+
18+
ε = cbrt(eps(Float32))
19+
ε₁ = [ε, 0]
20+
ε₂ = [0, ε]
21+
22+
M = 100
23+
data = [rand(input) for _ in 1:M]
24+
function loss_by_finitediff(model, x)
25+
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
26+
trial(model, x - ε₂) - 4 * trial(model, x)) /
27+
ε^2 + sin* x[1]) * sin* x[2])
28+
abs2(error)
29+
end
30+
function loss_by_taylordiff(model, x)
31+
f(x) = trial(model, x)
32+
error = derivative(f, x, [1., 0.], 2) + derivative(f, x, [0., 1.], 2) + sin* x[1]) * sin* x[2])
33+
abs2(error)
1134
end
1235

13-
(pinn::PINN)(x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * first(pinn.W₂ * exp.(pinn.W₁ * x + pinn.b₁) + pinn.b₂)
36+
opt = Flux.setup(Adam(), model)
1437

15-
dataset = [rand(input) for i in 1:10]
16-
function loss(pinn)
17-
out = 0.0
18-
for x in dataset
19-
out += derivative(pinn, x, [1., 0.], Val(2))
20-
end
21-
out
38+
allloss(model, loss) = sum([loss(model, x) for x in data])
39+
for epoch in 1:1000
40+
Flux.train!(loss_by_taylordiff, model, data, opt)
2241
end
2342

24-
myPINN = PINN(rand(hidden, input), rand(hidden), rand(1, hidden), rand(1))
43+
grid = 0:0.01:1
44+
solution(x, y) = (sin* x) * sin* y)) / (2π^2)
45+
u = [trial(model, [x, y]) for x in grid, y in grid]
46+
utrue = [solution(x, y) for x in grid, y in grid]
47+
diff_u = abs.(u .- utrue)
2548

26-
gradient(loss, myPINN)
49+
surface(u)
50+
surface(utrue)
51+
surface(diff_u)

src/chainrules.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import ChainRulesCore: rrule, RuleConfig
1+
import ChainRulesCore: rrule, RuleConfig, ProjectTo
22
using ZygoteRules: @adjoint
33

44
contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N} = mapreduce(*, +, value(a), value(b))
55

66
NONLINEAR_UNARY_FUNCTIONS = Function[
77
exp, exp2, exp10, expm1,
88
log, log2, log10, log1p,
9+
inv, sqrt, cbrt,
910
sin, cos, tan, cot, sec, csc,
1011
asin, acos, atan, acot, asec, acsc,
1112
sinh, cosh, tanh, coth, sech, csch,
@@ -52,8 +53,9 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
5253
return extract_derivative(t, i), extract_derivative_pullback
5354
end
5455

55-
function rrule(::typeof(*), A::Matrix{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
56-
gemv_pullback(x̄) = NoTangent(), contract.(x̄, transpose(t)), transpose(A) *
56+
function rrule(::typeof(*), A::Matrix{S}, t::Vector{TaylorScalar{T, N}}) where {N, S <: Number, T}
57+
project_A = ProjectTo(A)
58+
gemv_pullback(x̄) = NoTangent(), project_A(contract.(x̄, transpose(t))), transpose(A) *
5759
return A * t, gemv_pullback
5860
end
5961

@@ -70,3 +72,7 @@ end
7072
@adjoint +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number} = t + v, x̄ -> (x̄, map(primal, x̄))
7173

7274
@adjoint +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number} = v + t, x̄ -> (map(primal, x̄), x̄)
75+
76+
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
77+
78+
(project::ProjectTo{S})(dx::TaylorScalar{T, N}) where {N, T <: Number, S <: Real} = project(primal(dx))

src/primitive.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import Base: hypot, max, min
1414
@inline sqrt(t::TaylorScalar) = t^0.5
1515
@inline cbrt(t::TaylorScalar) = ^(t, 1 / 3)
1616
@inline inv(t::TaylorScalar) = 1 / t
17+
@inline abs(t::TaylorScalar) = primal(t) >= 0 ? t : -t
1718

1819
for func in (:exp, :expm1, :exp2, :exp10)
1920
@eval @generated function $func(t::TaylorScalar{T, N}) where {T, N}

0 commit comments

Comments
 (0)