Skip to content

Commit ce4b82d

Browse files
committed
Improve macro
1 parent aeeceef commit ce4b82d

7 files changed

Lines changed: 119 additions & 57 deletions

File tree

benchmark/Manifest.toml

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.11.1"
44
manifest_format = "2.0"
5-
project_hash = "448f3cae2dea59422644934cbad9b5490d5dc6eb"
5+
project_hash = "89d6a775281b4cbd649c4e44a66577335ae263f3"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081"
@@ -57,9 +57,9 @@ version = "0.1.38"
5757

5858
[[deps.Adapt]]
5959
deps = ["LinearAlgebra", "Requires"]
60-
git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a"
60+
git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
6161
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
62-
version = "4.1.0"
62+
version = "4.1.1"
6363
weakdeps = ["StaticArrays"]
6464

6565
[deps.Adapt.extensions]
@@ -221,6 +221,36 @@ deps = ["Artifacts", "Libdl"]
221221
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
222222
version = "1.1.1+0"
223223

224+
[[deps.ComponentArrays]]
225+
deps = ["ArrayInterface", "ChainRulesCore", "ForwardDiff", "Functors", "LinearAlgebra", "PackageExtensionCompat", "StaticArrayInterface", "StaticArraysCore"]
226+
git-tree-sha1 = "bc391f0c19fa242fb6f71794b949e256cfa3772c"
227+
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
228+
version = "0.15.17"
229+
230+
[deps.ComponentArrays.extensions]
231+
ComponentArraysAdaptExt = "Adapt"
232+
ComponentArraysConstructionBaseExt = "ConstructionBase"
233+
ComponentArraysGPUArraysExt = "GPUArrays"
234+
ComponentArraysOptimisersExt = "Optimisers"
235+
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
236+
ComponentArraysReverseDiffExt = "ReverseDiff"
237+
ComponentArraysSciMLBaseExt = "SciMLBase"
238+
ComponentArraysTrackerExt = "Tracker"
239+
ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces"
240+
ComponentArraysZygoteExt = "Zygote"
241+
242+
[deps.ComponentArrays.weakdeps]
243+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
244+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
245+
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
246+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
247+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
248+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
249+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
250+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
251+
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
252+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
253+
224254
[[deps.CompositeTypes]]
225255
git-tree-sha1 = "bce26c3dab336582805503bed209faab1c279768"
226256
uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657"
@@ -364,9 +394,9 @@ version = "1.0.4"
364394

365395
[[deps.Enzyme]]
366396
deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Preferences", "Printf", "Random", "SparseArrays"]
367-
git-tree-sha1 = "aba39bfce6e65ce740b29c8d9d0c8a6c5770e3c1"
397+
git-tree-sha1 = "abcbb722aafe8ed9cc667884b3a1e1d259c5e562"
368398
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
369-
version = "0.13.12"
399+
version = "0.13.13"
370400

371401
[deps.Enzyme.extensions]
372402
EnzymeBFloat16sExt = "BFloat16s"
@@ -383,19 +413,19 @@ version = "0.13.12"
383413
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
384414

385415
[[deps.EnzymeCore]]
386-
git-tree-sha1 = "9c3a42611e525352e9ad5e4134ddca5c692ff209"
416+
git-tree-sha1 = "04c777af6ef65530a96ab68f0a81a4608113aa1d"
387417
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
388-
version = "0.8.4"
418+
version = "0.8.5"
389419
weakdeps = ["Adapt"]
390420

391421
[deps.EnzymeCore.extensions]
392422
AdaptExt = "Adapt"
393423

394424
[[deps.Enzyme_jll]]
395425
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
396-
git-tree-sha1 = "c180391e0a09fedb2934e5c44455e13c38f859e6"
426+
git-tree-sha1 = "62cf2140d8daa3181e9f9d7a8b5e7b9493a57f21"
397427
uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
398-
version = "0.0.157+0"
428+
version = "0.0.159+0"
399429

400430
[[deps.ExceptionUnwrapping]]
401431
deps = ["Test"]
@@ -442,9 +472,9 @@ version = "1.3.7"
442472

443473
[[deps.ForwardDiff]]
444474
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
445-
git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
475+
git-tree-sha1 = "a9ce73d3c827adab2d70bf168aaece8cce196898"
446476
uuid = "f6369f11-7733-5829-9624-2563aa707210"
447-
version = "0.10.36"
477+
version = "0.10.37"
448478
weakdeps = ["StaticArrays"]
449479

450480
[deps.ForwardDiff.extensions]
@@ -986,6 +1016,12 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
9861016
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
9871017
version = "0.11.31"
9881018

1019+
[[deps.PackageExtensionCompat]]
1020+
git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518"
1021+
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1022+
version = "1.0.2"
1023+
weakdeps = ["Requires", "TOML"]
1024+
9891025
[[deps.Parsers]]
9901026
deps = ["Dates", "PrecompileTools", "UUIDs"]
9911027
git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821"
@@ -1088,9 +1124,9 @@ version = "1.3.4"
10881124

10891125
[[deps.RecursiveArrayTools]]
10901126
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
1091-
git-tree-sha1 = "43cdc0987135597867a37fc3e8e0fc9fdef6ac66"
1127+
git-tree-sha1 = "ed2514425d030d7c9054fa0f2275ada45681788d"
10921128
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
1093-
version = "3.27.1"
1129+
version = "3.27.2"
10941130

10951131
[deps.RecursiveArrayTools.extensions]
10961132
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
@@ -1152,9 +1188,9 @@ version = "0.1.0"
11521188

11531189
[[deps.SciMLBase]]
11541190
deps = ["ADTypes", "Accessors", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "Expronicon", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface"]
1155-
git-tree-sha1 = "86e1c491cddf233d77d8aadbe289005db44e8445"
1191+
git-tree-sha1 = "7a54136472ca0cb0f66ef22aa3f0ff198f379fa7"
11561192
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1157-
version = "2.57.2"
1193+
version = "2.58.0"
11581194

11591195
[deps.SciMLBase.extensions]
11601196
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
@@ -1374,9 +1410,9 @@ version = "3.7.2"
13741410

13751411
[[deps.Symbolics]]
13761412
deps = ["ADTypes", "ArrayInterface", "Bijections", "CommonWorldInvalidations", "ConstructionBase", "DataStructures", "DiffRules", "Distributions", "DocStringExtensions", "DomainSets", "DynamicPolynomials", "IfElse", "LaTeXStrings", "Latexify", "Libdl", "LinearAlgebra", "LogExpFunctions", "MacroTools", "Markdown", "NaNMath", "PrecompileTools", "Primes", "RecipesBase", "Reexport", "RuntimeGeneratedFunctions", "SciMLBase", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArraysCore", "SymbolicIndexingInterface", "SymbolicLimits", "SymbolicUtils", "TermInterface"]
1377-
git-tree-sha1 = "ef7532b95fbd529e1252cabb36bba64803020840"
1413+
git-tree-sha1 = "41852067b437d16a3ad4e01705ffc6e22925c42c"
13781414
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1379-
version = "6.16.0"
1415+
version = "6.17.0"
13801416

13811417
[deps.Symbolics.extensions]
13821418
SymbolicsForwardDiffExt = "ForwardDiff"

benchmark/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
34
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
56
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"

benchmark/groups/pinn.jl

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,56 @@
1-
using Lux, Zygote
1+
using Lux, Zygote, Enzyme, ComponentArrays
22

3-
const input = 2
4-
const hidden = 16
5-
6-
model = Chain(Dense(input => hidden, Lux.relu),
7-
Dense(hidden => hidden, Lux.relu),
8-
Dense(hidden => 1),
9-
first)
10-
11-
ps, st = Lux.setup(rng, model)
12-
13-
trial(model, x) = x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * model(x, ps, st)[1]
3+
function trial(model, x, ps, st)
4+
u, st = Lux.apply(model, x, ps, st)
5+
x[1] * (1 - x[1]) * x[2] * (1 - x[2]) * u
6+
end
147

15-
x = rand(Float32, input)
16-
trial(model, x)
17-
function loss_by_finitediff(model, x)
18-
ε = cbrt(eps(Float32))
19-
ε₁ = [ε, 0]
20-
ε₂ = [0, ε]
21-
error = (trial(model, x + ε₁) + trial(model, x - ε₁) + trial(model, x + ε₂) +
22-
trial(model, x - ε₂) - 4 * trial(model, x)) /
23-
ε^2 + sin* x[1]) * sin* x[2])
8+
function loss_by_finitediff(model, x, ps, st)
9+
T = eltype(x)
10+
ε = cbrt(eps(T))
11+
ε₁ = [ε, zero(T)]
12+
ε₂ = [zero(T), ε]
13+
f(x) = trial(model, x, ps, st)
14+
error = (f(x + ε₁) + f(x - ε₁) + f(x + ε₂) + f(x - ε₂) - 4 * f(x)) / ε^2 +
15+
sin* x[1]) * sin* x[2])
2416
abs2(error)
2517
end
26-
function loss_by_taylordiff(model, x)
27-
f(x) = trial(model, x)
18+
function loss_by_taylordiff(model, x, ps, st)
19+
f(x) = trial(model, x, ps, st)
2820
error = derivative(f, x, Float32[1, 0], Val(2)) +
2921
derivative(f, x, Float32[0, 1], Val(2)) +
3022
sin* x[1]) * sin* x[2])
3123
abs2(error)
3224
end
25+
function loss_by_forwarddiff(model, x, ps, st)
26+
f(x) = trial(model, x, ps, st)
27+
error = derivative(f, x, Float32[1, 0], Val(2)) +
28+
derivative(f, x, Float32[0, 1], Val(2)) +
29+
sin* x[1]) * sin* x[2])
30+
abs2(error)
31+
end
32+
33+
const input = 2
34+
const hidden = 16
35+
model = Chain(Dense(input => hidden, exp),
36+
Dense(hidden => hidden, exp),
37+
Dense(hidden => 1),
38+
first)
39+
x = rand(Float32, input)
40+
dx = deepcopy(x)
41+
ps, st = Lux.setup(rng, model)
42+
ps = ps |> ComponentArray
43+
dps = deepcopy(ps)
44+
dx .= 0;
45+
dps .= 0;
3346

34-
pinn_t = BenchmarkGroup("primal" => (@benchmarkable loss_by_taylordiff($model, $x)),
47+
pinn_t = BenchmarkGroup(
48+
"primal" => (@benchmarkable loss_by_taylordiff($model, $x, $ps, $st)),
3549
"gradient" => (@benchmarkable gradient(loss_by_taylordiff, $model,
36-
$x)))
37-
pinn_f = BenchmarkGroup("primal" => (@benchmarkable loss_by_finitediff($model, $x)),
50+
$x, $ps, $st)))
51+
pinn_f = BenchmarkGroup(
52+
"primal" => (@benchmarkable loss_by_finitediff($model, $x, $ps, $st)),
3853
"gradient" => (@benchmarkable gradient($loss_by_finitediff, $model,
39-
$x)))
54+
$x, $ps, $st)))
4055
pinn = BenchmarkGroup(["vector", "physical"], "taylordiff" => pinn_t,
4156
"finitediff" => pinn_f)

src/array.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,8 @@ function find_taylor(a::Array{<:Tuple{TaylorScalar{T, P}, Any}, N}, rest) where
9696
TaylorArray{P}(zeros(T, size(a)))
9797
end
9898
find_taylor(::Any, rest) = find_taylor(rest)
99+
100+
# function Base.copyto!(dest::TaylorArray, bc::Broadcast.Broadcasted{<:TaylorArrayStyle, Axes}) where Axes
101+
# println("copyto!($(typeof(dest)), $(typeof(bc)))")
102+
# error("Not implemented")
103+
# end

src/chainrules.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,21 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
1717
end
1818

1919
function rrule(::typeof(partials), t::TaylorScalar{T, N}) where {N, T}
20-
value_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(0, v̄)
20+
z = zero(T)
21+
partials_pullback(v̄::NTuple{N, T}) = NoTangent(), TaylorScalar(z, v̄)
2122
# for structural tangent, convert to tuple
22-
function value_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
23-
NoTangent(), TaylorScalar(zero(T), backing(v̄))
23+
function partials_pullback(v̄::Tangent{P, NTuple{N, T}}) where {P}
24+
NoTangent(), TaylorScalar(z, backing(v̄))
2425
end
25-
function value_pullback(v̄)
26-
NoTangent(), TaylorScalar(zero(T), map(x -> convert(T, x), Tuple(v̄)))
26+
function partials_pullback(::ZeroTangent)
27+
NoTangent(), TaylorScalar(z, ntuple(j -> zero(T), Val(N)))
2728
end
28-
return partials(t), value_pullback
29+
return partials(t), partials_pullback
2930
end
3031

3132
function rrule(::typeof(partials), t::TaylorArray{T, N, A, P}) where {N, T, A, P}
32-
value_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
33-
return partials(t), value_pullback
33+
partials_pullback(v̄::NTuple{P, A}) = NoTangent(), TaylorArray(broadcast(zero, v̄[1]), v̄)
34+
return partials(t), partials_pullback
3435
end
3536

3637
function rrule(::typeof(extract_derivative), t::TaylorScalar{T, P},

src/primitive.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ end
4848
f = flatten(t)
4949
v[0] = exp(f[0])
5050
for i in 1:P
51-
v[i] = zero(T)
5251
for j in 0:(i - 1)
5352
v[i] += (i - j) * v[j] * f[i - j]
5453
end
@@ -62,8 +61,6 @@ for func in (:sin, :cos)
6261
f = flatten(t)
6362
s[0], c[0] = sincos(f[0])
6463
for i in 1:P
65-
s[i] = zero(T)
66-
c[i] = zero(T)
6764
for j in 0:(i - 1)
6865
s[i] += (i - j) * c[j] * f[i - j]
6966
c[i] -= (i - j) * s[j] * f[i - j]
@@ -107,7 +104,6 @@ end
107104
@immutable function *(a::TaylorScalar{T, P}, b::TaylorScalar{T, P}) where {T, P}
108105
va, vb = flatten(a), flatten(b)
109106
for i in 0:P
110-
v[i] = zero(T)
111107
for j in 0:i
112108
v[i] += va[j] * vb[i - j]
113109
end
@@ -140,7 +136,6 @@ for R in (Integer, Real)
140136
f = flatten(t)
141137
v[0] = f[0]^n
142138
for i in 1:P
143-
v[i] = zero(T)
144139
for j in 0:(i - 1)
145140
v[i] += (n * (i - j) - j) * v[j] * f[i - j]
146141
end
@@ -164,3 +159,9 @@ end
164159
@inline raise(f0, d::TaylorScalar, t) = integrate(differentiate(t) * d, f0)
165160
@inline raise(f0, d::Number, t) = d * t
166161
@inline raiseinv(f0, d, t) = integrate(differentiate(t) / d, f0)
162+
163+
# Array primitives
164+
165+
# Pass-through linear operators
166+
167+
# *(a::AbstractMatrix{T}, b::TaylorArray{T}) where {T} = TaylorArray(a * value(b), map(p -> a * p, partials(b)))

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,12 @@ function process(d, expr)
8686
end
8787
# Modify indices
8888
magic_names = (:v, :s, :c)
89+
known_names = Set()
8990
expr = postwalk(expr) do x
9091
@match x begin
9192
a_[idx_] => a in magic_names ? Symbol(a, idx) : :($a[begin + $idx])
93+
(a_ += b_) => a in known_names ? :($a += $b) : (push!(known_names, a); :($a = $b))
94+
(a_ -= b_) => a in known_names ? :($a -= $b) : (push!(known_names, a); :($a = -$b))
9295
TaylorScalar(v_) => :(TaylorScalar(tuple($([Symbol(v, idx) for idx in 0:d[:P]]...))))
9396
_ => x
9497
end

0 commit comments

Comments
 (0)