We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9a1f965 commit a3e588aCopy full SHA for a3e588a
1 file changed
benchmark/pinn.jl
@@ -1,5 +1,4 @@
1
using Flux
2
-using ChainRulesCore: @opt_out
3
using TaylorDiff
4
using Zygote
5
using Plots
@@ -13,15 +12,14 @@ model = Chain(
13
12
Dense(hidden => 1),
14
first
15
)
16
-trial(model, x) = model(x)
17
-
18
-ε = cbrt(eps(Float32))
19
-ε₁ = [ε, 0]
20
-ε₂ = [0, ε]
+trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x)
21
22
M = 100
23
data = [rand(input) for _ in 1:M]
24
function loss_by_finitediff(model, x)
+ ε = cbrt(eps(Float32))
+ ε₁ = [ε, 0]
+ ε₂ = [0, ε]
25
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
26
trial(model, x - ε₂) - 4 * trial(model, x)) /
27
ε^2 + sin(π * x[1]) * sin(π * x[2])
0 commit comments