Skip to content

Commit a3e588a

Browse files
committed
Fix trial function
1 parent 9a1f965 commit a3e588a

1 file changed

Lines changed: 4 additions & 6 deletions

File tree

benchmark/pinn.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Flux
2-
using ChainRulesCore: @opt_out
32
using TaylorDiff
43
using Zygote
54
using Plots
@@ -13,15 +12,14 @@ model = Chain(
1312
Dense(hidden => 1),
1413
first
1514
)
16-
trial(model, x) = model(x)
17-
18-
ε = cbrt(eps(Float32))
19-
ε₁ = [ε, 0]
20-
ε₂ = [0, ε]
15+
trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
2116

2217
M = 100
2318
data = [rand(input) for _ in 1:M]
2419
function loss_by_finitediff(model, x)
20+
ε = cbrt(eps(Float32))
21+
ε₁ = [ε, 0]
22+
ε₂ = [0, ε]
2523
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
2624
trial(model, x - ε₂) - 4 * trial(model, x)) /
2725
ε^2 + sin* x[1]) * sin* x[2])

0 commit comments

Comments
 (0)