From b4ae4d7a6fc775f2d73066727fd34833abc854d2 Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 2 Oct 2025 13:25:12 -0400 Subject: [PATCH 01/64] Working BP Commit --- src/ITensorNetworksNext.jl | 3 +++ src/abstracttensornetwork.jl | 2 +- test/test_beliefpropagation.jl | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 test/test_beliefpropagation.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 19c4109..905d783 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,4 +9,7 @@ include("abstract_problem.jl") include("iterators.jl") include("adapters.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e566752..1ecbffa 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -254,4 +254,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..4b179fb --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,25 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, + partitionfunction +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + dims = (4, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 +end \ No newline at end of file From d77d0632e6e88a13ab817d9d8a99a90442d37efe Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 23 Oct 2025 18:23:27 -0400 Subject: [PATCH 02/64] BP Code --- .../abstractbeliefpropagationcache.jl | 151 +++++++++++ .../beliefpropagationcache.jl | 237 ++++++++++++++++++ test/test_beliefpropagation.jl | 20 +- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl create mode 100644 src/beliefpropagation/beliefpropagationcache.jl diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,151 @@ +abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end + +#Interface +factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() +setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() +messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() +function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) + return not_implemented() +end +function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +function rescale_messages( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... + ) + return not_implemented() +end +function rescale_vertices( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... + ) + return not_implemented() +end + +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return not_implemented() +end +function edge_scalar( + bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... + ) + return not_implemented() +end + +#Graph functionality needed +Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function NamedGraphs.GraphsExtensions.boundary_edges( + bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... + ) + return not_implemented() +end + +#Functions derived from the interface +function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) + for (e, m) in zip(edges) + setmessage!(bp_cache, e, m) + end + return +end + +function deletemessages!( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) + ) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +function vertex_scalars( + bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... + ) + return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +end + +function edge_scalars( + bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... + ) + return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] + ) + b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return incoming_messages(bp_cache, [vertex]; kwargs...) +end + +#Adapt interface for changing device +function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) + bp_cache = copy(bp_cache) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end +function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) + bp_cache = copy(bp_cache) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end +function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_messages(adapt(to), bp_cache, args...) +end +function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) + return map_factors(adapt(to), bp_cache, args...) +end + +function freenergy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + any(iszero, denominator_terms) && return -Inf + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end + +function partitionfunction(bp_cache::AbstractBeliefPropagationCache) + return exp(freenergy(bp_cache)) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return rescale_messages(bp_cache, [edge]) +end + +function rescale_messages(bp_cache::AbstractBeliefPropagationCache) + return rescale_messages(bp_cache, edges(bp_cache)) +end + +function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) + return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) +end + +function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) + return rescale_vertices(bpc, [vertex]; kwargs...) +end + +function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) + bpc = rescale_messages(bpc) + bpc = rescale_partitions(bpc, args...; kwargs...) + return bpc +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..295502a --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,237 @@ +using DiagonalArrays: delta +using Dictionaries: Dictionary, set!, delete! +using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim +using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network +default_messages() = Dictionary() + +BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) + return $f(network(bp_cache), args...; kwargs...) + end + end +end + +#TODO: Get subgraph working on an ITensorNetwork to overload this directly +function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) + return forest_cover_edge_sequence(underlying_graph(bp_cache)) +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + return t +end + +#Algorithmic defaults +default_update_alg(bp_cache::BeliefPropagationCache) = "bp" +default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" +default_normalize(::Algorithm"contract") = true +default_sequence_alg(::Algorithm"contract") = "optimal" +function set_default_kwargs(alg::Algorithm"contract") + normalize = get(alg, :normalize, default_normalize(alg)) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("contract"; normalize, sequence_alg) +end +function set_default_kwargs(alg::Algorithm"adapt_update") + _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) + return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) +end +default_verbose(::Algorithm"bp") = false +default_tol(::Algorithm"bp") = nothing +function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) + verbose = get(alg, :verbose, default_verbose(alg)) + maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) + edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) + tol = get(alg, :tol, default_tol(alg)) + message_update_alg = set_default_kwargs( + get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) + ) + return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function updated_message( + bp_cache::BeliefPropagationCache, + edge::AbstractEdge; + alg = default_message_update_alg(bpc), + kwargs..., + ) + return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) +end + +function update_message!( + message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!) = nothing, + ) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + update_message!(alg.message_update_alg, bpc, e) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:AbstractEdge}}; + (update_diff!) = nothing, + ) + new_mts = empty(messages(bpc)) + for edges in edge_groups + bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) + for e in edges + set!(new_mts, e, message(bpc_t, e)) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) + compute_error = !isnothing(alg.tol) + if isnothing(alg.maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:alg.maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) + if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol + if alg.verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) + return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return edges +end \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4b179fb..81ee722 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,11 +3,13 @@ using ITensorBase: Index using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, partitionfunction using Graphs: edges, vertices -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using Test: @test, @testset @testset "BeliefPropagation" begin + + #Chain of tensors dims = (4, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) @@ -17,6 +19,22 @@ using Test: @test, @testset return randn(Tuple(is)) end + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) From b80e36eaf6aac3a3702bd0403d7858603366b1e7 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:18:28 -0400 Subject: [PATCH 03/64] Express BP in terms of `SweepIterator` interface Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling. --- Project.toml | 2 + src/ITensorNetworksNext.jl | 1 + .../beliefpropagationcache.jl | 126 ++---------------- .../beliefpropagationproblem.jl | 85 ++++++++++++ 4 files changed, 101 insertions(+), 113 deletions(-) create mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e0aea23 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -39,6 +40,7 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 905d783..cca4b6d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -11,5 +11,6 @@ include("adapters.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 295502a..cdae651 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,9 +1,7 @@ -using DiagonalArrays: delta using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges using ITensorBase: ITensor, dim -using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: AbstractBeliefPropagationCache{V} @@ -13,9 +11,8 @@ end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages network(bp_cache::BeliefPropagationCache) = bp_cache.network -default_messages() = Dictionary() -BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) @@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end -function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) end -function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) return [message(bp_cache, e) for e in edges] end -default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing #Forward onto the network for f in [ :(Graphs.vertices), @@ -62,11 +58,6 @@ for f in [ end end -#TODO: Get subgraph working on an ITensorNetwork to overload this directly -function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) - return forest_cover_edge_sequence(underlying_graph(bp_cache)) -end - function factors(tn::AbstractTensorNetwork, vertex) return [tn[vertex]] end @@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) return t end -#Algorithmic defaults -default_update_alg(bp_cache::BeliefPropagationCache) = "bp" -default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" -default_normalize(::Algorithm"contract") = true -default_sequence_alg(::Algorithm"contract") = "optimal" -function set_default_kwargs(alg::Algorithm"contract") - normalize = get(alg, :normalize, default_normalize(alg)) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("contract"; normalize, sequence_alg) -end -function set_default_kwargs(alg::Algorithm"adapt_update") - _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) - return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) -end -default_verbose(::Algorithm"bp") = false -default_tol(::Algorithm"bp") = nothing -function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) - verbose = get(alg, :verbose, default_verbose(alg)) - maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) - edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) - tol = get(alg, :tol, default_tol(alg)) - message_update_alg = set_default_kwargs( - get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) - ) - return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) -end - #TODO: Update message etc should go here... function updated_message( alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge @@ -141,85 +105,21 @@ function updated_message( return updated_message end -function updated_message( - bp_cache::BeliefPropagationCache, - edge::AbstractEdge; - alg = default_message_update_alg(bpc), - kwargs..., +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" ) - return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) + return Algorithm("contract"; normalize, sequence_alg) end - -function update_message!( - message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") ) - return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) + return Algorithm("adapt_update"; adapt, alg) end -""" -Do a sequential update of the message tensors on `edges` -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edges::Vector; - (update_diff!) = nothing, - ) - bpc = copy(bpc) - for e in edges - prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing - update_message!(alg.message_update_alg, bpc, e) - if !isnothing(update_diff!) - update_diff![] += message_diff(message(bpc, e), prev_message) - end - end - return bpc -end - -""" -Do parallel updates between groups of edges of all message tensors -Currently we send the full message tensor data struct to update for each edge_group. But really we only need the -mts relevant to that group. -""" -function update_iteration( - alg::Algorithm"bp", - bpc::AbstractBeliefPropagationCache, - edge_groups::Vector{<:Vector{<:AbstractEdge}}; - (update_diff!) = nothing, +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge ) - new_mts = empty(messages(bpc)) - for edges in edge_groups - bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) - for e in edges - set!(new_mts, e, message(bpc_t, e)) - end - end - return set_messages(bpc, new_mts) -end - -""" -More generic interface for update, with default params -""" -function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) - compute_error = !isnothing(alg.tol) - if isnothing(alg.maxiter) - error("You need to specify a number of iterations for BP!") - end - for i in 1:alg.maxiter - diff = compute_error ? Ref(0.0) : nothing - bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) - if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol - if alg.verbose - println("BP converged to desired precision after $i iterations.") - end - break - end - end - return bpc -end - -function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) - return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end #Edge sequence stuff @@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root end end return edges -end \ No newline at end of file +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) + end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end From fe44b804f7461106caa3a8dbc6f0dad38ff67ede Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 31 Oct 2025 12:46:03 -0400 Subject: [PATCH 04/64] Add method for `setmessages!` that allows messages from one cache to be set from another cache --- src/beliefpropagation/beliefpropagationcache.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index cdae651..b3a32b1 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -30,6 +30,14 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end +function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) + ms_dst = messages(bpc_dst) + for e in edges + set!(ms_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) From 3ce08983b2a9feae9057dc10ca55491bddf08079 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 10 Nov 2025 14:03:59 -0500 Subject: [PATCH 05/64] Network is now passed to `forest_cover_edge_sequence` directly. --- src/beliefpropagation/beliefpropagationproblem.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a497363..967b454 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -9,7 +9,7 @@ function default_algorithm( bpc::BeliefPropagationCache; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + edge_sequence = forest_cover_edge_sequence(network(bpc)), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) @@ -44,7 +44,8 @@ end function region_plan( prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... ) - edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) plan = map(edges) do e return [e] => (; sweep_kwargs...) From f6e4fd0ea748f4a3da272dc1011a855fdaee7a9e Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:19:31 -0500 Subject: [PATCH 06/64] test file formatting --- test/test_beliefpropagation.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 81ee722..fc657e7 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,7 +1,17 @@ using Dictionaries: Dictionary using ITensorBase: Index -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, - partitionfunction +using ITensorNetworksNext: + BeliefPropagationCache, + ITensorNetworksNext, + TensorNetwork, + adapt_messages, + default_message, + default_messages, + edge_scalars, + factors, + messages, + partitionfunction, + setmessages! using Graphs: edges, vertices using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges @@ -15,15 +25,15 @@ using Test: @test, @testset l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 + @test abs(z_bp - z_exact) <= 1.0e-14 #Tree of tensors dims = (4, 3) @@ -31,13 +41,14 @@ using Test: @test, @testset l = Dict(e => Index(3) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) end bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1e-14 -end \ No newline at end of file + @test abs(z_bp - z_exact) <= 1.0e-14 +end + From 63840a90df869893d87c1ce6a6c58e06bb13973c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 11:25:31 -0500 Subject: [PATCH 07/64] Add `DataGraphsPartitionedGraphsExt` glue for `TensorNetwork` type Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure. --- src/tensornetwork.jl | 79 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 582eec6..11c2e88 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,10 +1,21 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary -using Graphs: AbstractSimpleGraph +using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: + AbstractPartitionedGraph, + PartitionedGraphs, + departition, + partitioned_vertices, + partitionedgraph, + quotient_graph, + quotient_graph_type +using .LazyNamedDimsArrays: lazy, Mul +using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -24,8 +35,14 @@ function _TensorNetwork(graph::AbstractGraph, tensors) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end +function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} + return _TensorNetwork(graph, Tensors()) +end + DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() +DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -70,7 +87,10 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - return tn + for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) + insert_trivial_link!(network, edge) + end + return network end # Determine the graph structure from the tensors. @@ -93,3 +113,56 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) + +Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) + +function Graphs.rem_edge!(tn::TensorNetwork, e) + if !has_edge(underlying_graph(tn), e) + return false + end + if !isempty(linkinds(tn, e)) + throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + end + rem_edge!(underlying_graph(tn), e) + return true +end + +function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) + DT = fieldtype(type, :tensors) + empty_dict = DT() + return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) +end + +## PartitionedGraphs +function PartitionedGraphs.quotient_graph(tn::TensorNetwork) + ug = quotient_graph(underlying_graph(tn)) + return TensorNetwork(ug, vertex_data(QuotientView(tn))) +end +function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) + UG = quotient_graph_type(underlying_graph_type(type)) + VD = Vector{vertex_data_eltype(type)} + V = vertextype(UG) + return TensorNetwork{V, VD, UG, Dictionary{V, VD}} +end + +function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) + pg = partitionedgraph(underlying_graph(tn), parts) + return TensorNetwork(pg, vertex_data(tn)) +end + +PartitionedGraphs.departition(tn::TensorNetwork) = tn +function PartitionedGraphs.departition( + tn::TensorNetwork{<:Any, <:Any, <:AbstractPartitionedGraph} + ) + return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) +end + +function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) + return mapreduce(lazy, *, collect(last(data))) +end + +function PartitionedGraphs.quotientview(tn::TensorNetwork) + qview = QuotientView(underlying_graph(tn)) + tensors = vertex_data(QuotientView(tn)) + return TensorNetwork(qview, tensors) +end From ba22ab5b107d2b681a5bd1d29395c0f390f23d56 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:20 -0500 Subject: [PATCH 08/64] Make abstract tensor network interface more generic. --- src/abstracttensornetwork.jl | 106 ++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 1ecbffa..b02c789 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -9,19 +9,23 @@ using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree -using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, - rename_vertices, vertextype +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +using NamedGraphs.GraphsExtensions: + ⊔, + directed_graph, + incident_edges, + rem_edges!, + rename_vertices, + vertextype using SplitApplyCombine: flatten +using NamedGraphs.SimilarType: similar_type abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end -function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn -end +# Need to be careful about removing edges from tensor networks in case there is a bond +Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -# TODO: Define a generic fallback for `AbstractDataGraph`? -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork) end # Copy -Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") +Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) @@ -49,20 +53,11 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# Derived interface, may need to be overloaded -function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) -end - # AbstractDataGraphs overloads -function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end -function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") -end +DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() +DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) end @@ -81,40 +76,37 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) return map_vertex_data_preserve_graph(adapt(to), tn) end -function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) -end -function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) -end -function linkaxes(tn::AbstractTensorNetwork, edge::Pair) +linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge)) +linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) + +function linkaxes(tn::AbstractGraph, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) end -function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linkaxes(tn::AbstractGraph, edge::AbstractEdge) return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end -function linknames(tn::AbstractTensorNetwork, edge::Pair) +function linknames(tn::AbstractGraph, edge::Pair) return linknames(tn, edgetype(tn)(edge)) end -function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function linknames(tn::AbstractGraph, edge::AbstractEdge) return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end -function siteinds(tn::AbstractTensorNetwork, v) +function siteinds(tn::AbstractGraph, v) s = inds(tn[v]) for v′ in neighbors(tn, v) s = setdiff(s, inds(tn[v′])) end return s end -function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) +function siteaxes(tn::AbstractGraph, edge::AbstractEdge) s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, axes(tn[v′])) end return s end -function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) +function sitenames(tn::AbstractGraph, edge::AbstractEdge) s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) for v′ in neighbors(tn, v) s = setdiff(s, dimnames(tn[v′])) @@ -122,8 +114,8 @@ function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) return s end -function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value +function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) + set!(vertex_data(tn), vertex, value) return tn end @@ -153,7 +145,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork) +function add_missing_edges!(tn::AbstractGraph) foreach(v -> add_missing_edges!(tn, v), vertices(tn)) return tn end @@ -161,7 +153,7 @@ end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. -function add_missing_edges!(tn::AbstractTensorNetwork, v) +function add_missing_edges!(tn::AbstractGraph, v) for v′ in vertices(tn) if v ≠ v′ e = v => v′ @@ -175,13 +167,13 @@ end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. -function fix_edges!(tn::AbstractTensorNetwork) +function fix_edges!(tn::AbstractGraph) foreach(v -> fix_edges!(tn, v), vertices(tn)) return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. -function fix_edges!(tn::AbstractTensorNetwork, v) +function fix_edges!(tn::AbstractGraph, v) rem_edges!(tn, incident_edges(tn, v)) add_missing_edges!(tn, v) return tn @@ -215,28 +207,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v) fix_edges!(tn, v) return tn end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) graph[vertices(graph)[vertex]] = value return graph end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") -end -# Fix ambiguity error. -function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") -end -using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented() +Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() # Fix ambiguity error. function Base.setindex!( tn::AbstractTensorNetwork, value, edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, ) - return error("No edge data.") + return not_implemented() end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) @@ -254,4 +238,22 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) + +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} + return tensornetwork_induced_subgraph(graph, subvertices) +end +function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + subgraph = similar_type(graph)(underlying_subgraph) + for v in vertices(subgraph) + if isassigned(graph, v) + set!(vertex_data(subgraph), v, graph[v]) + end + end + return subgraph, vlist +end From 49b087015955f1865cc7b333e43f35b47e704751 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:27:50 -0500 Subject: [PATCH 09/64] BP Caching overhauls --- .../abstractbeliefpropagationcache.jl | 184 ++++++++---------- .../beliefpropagationcache.jl | 178 ++++++----------- .../beliefpropagationproblem.jl | 109 ++++++++--- 3 files changed, 226 insertions(+), 245 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 5eae283..8c6b3dd 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,117 +1,124 @@ -abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end +using Graphs: AbstractGraph, AbstractEdge +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -#Interface -factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() -setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() -messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() -function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() -end -default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) - return not_implemented() +messages(::AbstractGraph) = not_implemented() +messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache end -function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return not_implemented() + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache end -function rescale_messages( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... - ) - return not_implemented() + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + ms = messages(bp_cache) + set!(ms, edge, message) + return bp_cache end -function rescale_vertices( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... - ) - return not_implemented() +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache end -function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return not_implemented() +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache end -function edge_scalar( - bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... - ) - return not_implemented() +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst end -#Graph functionality needed -Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() -function NamedGraphs.GraphsExtensions.boundary_edges( - bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... - ) - return not_implemented() +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + set!(fs, vertex, factor) + return bpc end -#Functions derived from the interface -function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages) - for (e, m) in zip(edges) - setmessage!(bp_cache, e, m) - end - return +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return message(bp_cache, edge) * message(bp_cache, reverse(edge)) end -function deletemessages!( - bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache) - ) - for e in edges - deletemessage!(bp_cache, e) - end - return bp_cache +function region_scalar(bp_cache::AbstractGraph, vertex) + + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return reduce(*, messages) * reduce(*, state) end -function vertex_scalars( - bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs... - ) - return map(v -> region_scalar(bp_cache, v; kwargs...), vertices) +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars( - bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs... - ) - return map(e -> region_scalar(bp_cache, e; kwargs...), edges) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) + return map(e -> region_scalar(bp_cache, e), edges) end -function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache) +function scalar_factors_quotient(bp_cache::AbstractGraph) return vertex_scalars(bp_cache), edge_scalars(bp_cache) end -function incoming_messages( - bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = [] - ) - b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in) +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges return messages(bp_cache, b_edges) end -function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) - return incoming_messages(bp_cache, [vertex]; kwargs...) -end +default_messages(::AbstractGraph) = not_implemented() #Adapt interface for changing device -function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache)) - bp_cache = copy(bp_cache) +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) for e in es setmessage!(bp_cache, e, f(message(bp_cache, e))) end return bp_cache end -function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache)) - bp_cache = copy(bp_cache) + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) for v in vs setfactor!(bp_cache, v, f(factor(bp_cache, v))) end return bp_cache end -function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_messages(adapt(to), bp_cache, args...) -end -function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...) - return map_factors(adapt(to), bp_cache, args...) -end -function freenergy(bp_cache::AbstractBeliefPropagationCache) +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end + +function free_energy(bp_cache::AbstractBeliefPropagationCache) numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) @@ -123,29 +130,4 @@ function freenergy(bp_cache::AbstractBeliefPropagationCache) any(iszero, denominator_terms) && return -Inf return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end - -function partitionfunction(bp_cache::AbstractBeliefPropagationCache) - return exp(freenergy(bp_cache)) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) - return rescale_messages(bp_cache, [edge]) -end - -function rescale_messages(bp_cache::AbstractBeliefPropagationCache) - return rescale_messages(bp_cache, edges(bp_cache)) -end - -function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...) - return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...) -end - -function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...) - return rescale_vertices(bpc, [vertex]; kwargs...) -end - -function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...) - bpc = rescale_messages(bpc) - bpc = rescale_partitions(bpc, args...; kwargs...) - return bpc -end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index b3a32b1..4e441fb 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,145 +1,93 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using NamedGraphs: convert_vertextype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph using ITensorBase: ITensor, dim +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph -struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: - AbstractBeliefPropagationCache{V} +struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: + AbstractBeliefPropagationCache{V, MT} network::N - messages::Dictionary + messages::Dictionary{ET, MT} end -messages(bp_cache::BeliefPropagationCache) = bp_cache.messages -network(bp_cache::BeliefPropagationCache) = bp_cache.network +network(bp_cache) = underlying_graph(bp_cache) -BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) +DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) +DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) +function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) + return fieldtype(type, :network) end -function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) - ms = messages(bp_cache) - delete!(ms, e) - return bp_cache -end +message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT -function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) - ms = messages(bp_cache) - set!(ms, e, message) - return bp_cache +function BeliefPropagationCache(alg, network::AbstractGraph) + es = collect(edges(network)) + es = vcat(es, reverse.(es)) + messages = map(edge -> default_message(alg, network, edge), es) + return BeliefPropagationCache(network, Dictionary(es, messages)) end -function setmessages!(bpc_dst::BeliefPropagationCache, bpc_src::BeliefPropagationCache, edges) - ms_dst = messages(bpc_dst) - for e in edges - set!(ms_dst, e, message(bpc_src, e)) - end - return bpc_dst +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) - ms = messages(bp_cache) - return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +# TODO: This needs to go in DataGraphsGraphsExtensionsExt +# +# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges +# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, +# hence we just strip off any `AbstractDataGraph` data to avoid this. +function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) + return forest_cover_edge_sequence(underlying_graph(g); kwargs...) end - -function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) - return [message(bp_cache, e) for e in edges] +# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt +# +# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the +# data of a data graph to be removed using the above method if `parent_type(g)` is an +# `AbstractDataGraph`. +function forest_cover_edge_sequence(g::QuotientView; kwargs...) + return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) end - -#Forward onto the network -for f in [ - :(Graphs.vertices), - :(Graphs.edges), - :(Graphs.is_tree), - :(NamedGraphs.GraphsExtensions.boundary_edges), - :(factors), - :(default_bp_maxiter), - :(ITensorNetworksNext.setfactor!), - :(ITensorNetworksNext.linkinds), - :(ITensorNetworksNext.underlying_graph), - ] - @eval begin - function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) - return $f(network(bp_cache), args...; kwargs...) +# TODO: This needs to go in GraphsExtensions +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + add_edges!(g, edges(g)) + forests = forest_cover(g) + rv = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) end end + return rv end -function factors(tn::AbstractTensorNetwork, vertex) - return [tn[vertex]] -end - -function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] -end - -function region_scalar(bp_cache::BeliefPropagationCache, vertex) - incoming_ms = incoming_messages(bp_cache, vertex) - state = factors(bp_cache, vertex) - return (reduce(*, incoming_ms) * reduce(*, state))[] -end - -function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) - return default_message(network(bp_cache), edge::AbstractEdge) -end - -function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) - t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - return t -end - -#TODO: Update message etc should go here... -function updated_message( - alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge - ) - vertex = src(edge) - incoming_ms = incoming_messages( - bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] - ) - state = factors(bp_cache, vertex) - #contract_list = ITensor[incoming_ms; state] - #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) - #updated_messages = contract(contract_list; sequence) - updated_message = - !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) - if !iszero(message_norm) - updated_message /= message_norm +function bpcache_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) + subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + for e in edges(subgraph) + if isassigned(graph, e) + set!(edge_data(subgraph), e, graph[e]) end end - return updated_message + return subgraph, vlist end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" - ) - return Algorithm("contract"; normalize, sequence_alg) +function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) + return bpcache_induced_subgraph(graph, subvertices) end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) +# For method ambiguity +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} + return bpcache_induced_subgraph(graph, subvertices) end -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end +## PartitionedGraphs -#Edge sequence stuff -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - forests = forest_cover(g) - edges = edgetype(g)[] - for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return edges +function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) + qview = QuotientView(network(bpc)) + messages = edge_data(QuotientView(bpc)) + return BeliefPropagationCache(qview, messages) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 967b454..a05c97a 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,70 +1,121 @@ -mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: - AbstractProblem +using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers +using Graphs: SimpleGraph, vertices, edges, has_edge +using NamedGraphs: AbstractNamedGraph, position_graph +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices + +abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end + +mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} + const alg::Alg const cache::Cache diff::Union{Nothing, Float64} end +BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) + function default_algorithm( ::Type{<:Algorithm"bp"}, - bpc::BeliefPropagationCache; + bpc; verbose = false, tol = nothing, - edge_sequence = forest_cover_edge_sequence(network(bpc)), + edge_sequence = forest_cover_edge_sequence(bpc), message_update_alg = default_algorithm(Algorithm"contract"), maxiter = is_tree(bpc) ? 1 : nothing, ) return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) - prob = iter.problem +function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) + edges = prob.alg.edge_sequence - edge_group, kwargs = current_region_plan(iter) + plan = map(edges) do e + return e => (; sweep_kwargs...) + end - new_message_tensors = map(edge_group) do edge - old_message = message(prob.cache, edge) + return plan +end - new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) +function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) + prob = iter.problem - if !isnothing(prob.diff) - # TODO: Define `message_diff` - prob.diff += message_diff(new_message, old_message) - end + edge, _ = current_region_plan(iter) + new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) + setmessage!(prob.cache, edge, new_message) - return new_message - end + return iter +end - foreach(edge_group, new_message_tensors) do edge, new_message - setmessage!(prob.cache, edge, new_message) - end +default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - return iter +default_message(::Type{<:Algorithm}, network, edge) = not_implemented() +function default_message(::Type{<:Algorithm"bp"}, network, edge) + + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + links = linkinds(network, edge) + data = ones(dim.(links)...) + + t = ITensor(data, links) + return t end -function region_plan( - prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... +updated_message(alg, bpc, edge) = not_implemented() +function updated_message(alg::Algorithm"contract", bpc, edge) + vertex = src(edge) + + incoming_ms = incoming_messages( + bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] ) - edges = forest_cover_edge_sequence(network(prob.cache); root_vertex) + updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) - plan = map(edges) do e - return [e] => (; sweep_kwargs...) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end end + return updated_message +end - return plan +contract_messages(alg, factors, messages) = not_implemented() +function contract_messages( + alg, + factors::Vector{<:AbstractArray}, + messages::Vector{<:AbstractArray}, + ) + return contract_network(alg, vcat(factors, messages)) +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ) + return Algorithm("contract"; normalize, contraction_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end function update(bpc::AbstractBeliefPropagationCache; kwargs...) return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) end -function update(alg::Algorithm"bp", bpc) + +function update(alg, bpc) compute_error = !isnothing(alg.tol) diff = compute_error ? 0.0 : nothing - prob = BeliefPropagationProblem(bpc, diff) + prob = BeliefPropagationProblem(alg, bpc, diff) - iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + iter = SweepIterator(prob, alg.maxiter; compute_error) for _ in iter if compute_error && prob.diff <= alg.tol From db46c04214ed93c05a6bbcc7d88b06c2745f9c34 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 12:47:19 -0500 Subject: [PATCH 10/64] Remove dead deps --- src/beliefpropagation/beliefpropagationproblem.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index a05c97a..f487ccc 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,4 +1,3 @@ -using Distributed: WorkerPool, @everywhere, remotecall, myid, waitall, workers using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices From 400e373b9fbb7205359bfe5914ba8d6e0763cd16 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:05:45 -0500 Subject: [PATCH 11/64] Fix merge --- src/beliefpropagation/beliefpropagationproblem.jl | 2 +- src/tensornetwork.jl | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index f487ccc..61c97df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -87,7 +87,7 @@ function contract_messages( end function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = default_algorithm(Algorithm"exact") + ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") ) return Algorithm("contract"; normalize, contraction_alg) end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 11c2e88..44b883a 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -4,7 +4,7 @@ using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, @@ -87,10 +87,7 @@ function fix_links!(tn::AbstractTensorNetwork) for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end - for edge in setdiff(arranged_edges(graph), arranged_edges(graph_structure)) - insert_trivial_link!(network, edge) - end - return network + return tn end # Determine the graph structure from the tensors. From b9aafe890f235c0543d7b209a46fbb86ce9f3b70 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 13:12:01 -0500 Subject: [PATCH 12/64] Fix type inference in TensorNetwork construction --- src/tensornetwork.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 44b883a..0681da5 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -66,8 +66,7 @@ end tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) - tensors = Dictionary(vertices(graph), f.(vertices(graph))) - return TensorNetwork(graph, tensors) + return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end function TensorNetwork(graph::AbstractGraph, tensors) tn = _TensorNetwork(graph, tensors) From 4090e61f0069084ffd64ff53f65095ea3d05353c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:16:04 +0000 Subject: [PATCH 13/64] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_beliefpropagation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index fc657e7..a39e1a6 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -51,4 +51,3 @@ using Test: @test, @testset z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test abs(z_bp - z_exact) <= 1.0e-14 end - From be0750ee8f0ea1323eb94de8c14eec4490ef1995 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 25 Nov 2025 16:45:45 -0500 Subject: [PATCH 14/64] Remove `ITensorBase` dep --- Project.toml | 2 -- src/beliefpropagation/beliefpropagationcache.jl | 1 - src/beliefpropagation/beliefpropagationproblem.jl | 6 ++---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index e0aea23..95b8be0 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,6 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" -ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -40,7 +39,6 @@ DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8, 0.9" diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 4e441fb..5d8fa35 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -3,7 +3,6 @@ using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs: convert_vertextype using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using ITensorBase: ITensor, dim using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 61c97df..49d0ef8 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -52,10 +52,8 @@ function default_message(::Type{<:Algorithm"bp"}, network, edge) #TODO: Get datatype working on tensornetworks so we can support GPU, etc... links = linkinds(network, edge) - data = ones(dim.(links)...) - - t = ITensor(data, links) - return t + data = ones(Tuple(links)) + return data end updated_message(alg, bpc, edge) = not_implemented() From b971b89a91954d4175160c9788e2974267dc6fdc Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Dec 2025 17:24:09 -0500 Subject: [PATCH 15/64] `forest_cover_edge_sequence` now constructs a temporary `NamedGraph` instead of trying to operate on existing graphs The reason for this is: - One only cares about the edges of the input graph - A simple graph cannot be used as it "forgets" its edge names resulting in recursion - As shown with `TensorNetwork`, removing edges may not always be defined. --- .../beliefpropagationcache.jl | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 5d8fa35..994f480 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -33,25 +33,11 @@ function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) end -# TODO: This needs to go in DataGraphsGraphsExtensionsExt -# -# This function is problematic when `ng isa TensorNetwork` as it relies on deleting edges -# and taking subgraphs, which is not always well-defined for the `TensorNetwork` type, -# hence we just strip off any `AbstractDataGraph` data to avoid this. -function forest_cover_edge_sequence(g::AbstractDataGraph; kwargs...) - return forest_cover_edge_sequence(underlying_graph(g); kwargs...) -end -# TODO: This needs to go in PartitionedGraphsGraphsExtensionsExt -# -# While it is not at all necessary to explictly instantiate the `QuotientView`, it allows the -# data of a data graph to be removed using the above method if `parent_type(g)` is an -# `AbstractDataGraph`. -function forest_cover_edge_sequence(g::QuotientView; kwargs...) - return forest_cover_edge_sequence(quotient_graph(parent(g)); kwargs...) -end # TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) - add_edges!(g, edges(g)) +function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) + # All we care about are the edges so the type of the graph doesnt matter + g = NamedGraph(vertices(gi)) + add_edges!(g, edges(gi)) forests = forest_cover(g) rv = edgetype(g)[] for forest in forests From 9ebf0310c19fdf661cf6afd39c294710f167918b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:42:36 -0500 Subject: [PATCH 16/64] [LazyNamedDimsArrays] Fix `parenttype` method --- src/LazyNamedDimsArrays/lazynameddimsarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index b0ed86a..c269902 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end -parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray From 16fe303b73ab7f9ab3f5a1c46118319063a7af4a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:08 -0500 Subject: [PATCH 17/64] BP Cache now uses new `DataGraphs`interface --- .../abstractbeliefpropagationcache.jl | 13 +-- .../beliefpropagationcache.jl | 101 +++++++++++++----- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8c6b3dd..0cae3fa 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -3,11 +3,13 @@ using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent -messages(::AbstractGraph) = not_implemented() -messages(bp_cache::AbstractDataGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] +function message(bp_cache::AbstractGraph, edge::AbstractEdge) + ms = messages(bp_cache) + return get!(ms, edge, default_message(bp_cache, edge)) +end deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -25,8 +27,7 @@ end setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() function setmessage!(bp_cache::AbstractDataGraph, edge, message) - ms = messages(bp_cache) - set!(ms, edge, message) + setindex!(bp_cache, message, edge) return bp_cache end function setmessage!(bp_cache::QuotientView, edge, message) @@ -56,7 +57,7 @@ factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) fs = factors(bpc) - set!(fs, vertex, factor) + setindex!(fs, vertex, factor) return bpc end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 994f480..c9793e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,32 +1,85 @@ -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs: + DataGraphs, + AbstractDataGraph, + DataGraph, + has_edge_data, + get_vertex_data, + get_edge_data, + set_vertex_data!, + set_edge_data!, + unset_vertex_data!, + unset_edge_data!, + vertex_data_eltype, + edge_data_eltype, + underlying_graph, + underlying_graph_type using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components -using NamedGraphs: convert_vertextype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, is_path_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, quotient_graph +using Graphs: AbstractGraph, is_tree, connected_components, is_directed +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs.GraphsExtensions: default_root_vertex, + forest_cover, + post_order_dfs_edges, + vertextype, + is_path_graph, + undirected_graph +using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges -struct BeliefPropagationCache{V, N <: AbstractGraph{V}, ET, MT} <: +struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} + underlying_graph::G # we only use this for the edges. network::N messages::Dictionary{ET, MT} + function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + + V = vertextype(network) + N = typeof(network) + ET = keytype(messages) + MT = eltype(messages) + + # Construct a directed graph version of the underlying graph of the tensor network. + digraph = directed_graph(underlying_graph(network)) + + bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + + for edge in edges(bpc) + get!(() -> default_message(bpc, edge), messages, edge) + end + return bpc + end end -network(bp_cache) = underlying_graph(bp_cache) +network(bp_cache) = getfield(bp_cache, :network) + +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) + +DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) +DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :network) -DataGraphs.edge_data(bpc::BeliefPropagationCache) = getfield(bpc, :messages) -DataGraphs.vertex_data(bpc::BeliefPropagationCache) = vertex_data(network(bpc)) -function DataGraphs.underlying_graph_type(type::Type{<:BeliefPropagationCache}) - return fieldtype(type, :network) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] + +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) + +DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) +DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) + +function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) + return vertex_data_eltype(fieldtype(T, :network)) +end +function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) + return eltype(fieldtype(T, :messages)) end -message_type(::Type{<:BeliefPropagationCache{V, N, ET, MT}}) where {V, N, ET, MT} = MT +message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) -function BeliefPropagationCache(alg, network::AbstractGraph) - es = collect(edges(network)) - es = vcat(es, reverse.(es)) - messages = map(edge -> default_message(alg, network, edge), es) - return BeliefPropagationCache(network, Dictionary(es, messages)) +function BeliefPropagationCache(network::AbstractGraph) + MT = vertex_data_eltype(typeof(network)) + return BeliefPropagationCache(MT, network) +end +function BeliefPropagationCache(MT::Type, network::AbstractGraph) + dict = Dictionary{edgetype(network), MT}() + return BeliefPropagationCache(network, dict) end function Base.copy(bp_cache::BeliefPropagationCache) @@ -61,18 +114,14 @@ function bpcache_induced_subgraph(graph, subvertices) return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) -end -# For method ambiguity -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::AbstractVector{V}) where {V <: Int} +function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} return bpcache_induced_subgraph(graph, subvertices) end ## PartitionedGraphs function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - qview = QuotientView(network(bpc)) - messages = edge_data(QuotientView(bpc)) - return BeliefPropagationCache(qview, messages) + inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) + data = map(e -> bpc[QuotientEdge(e)], inds) + return BeliefPropagationCache(QuotientView(network(bpc)), data) end From 24a4335f61699a2d818f8b75a8b2867f7a16b3b5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:46:49 -0500 Subject: [PATCH 18/64] Adjust `default_message` to take a `message` type as its first argument --- .../beliefpropagationproblem.jl | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 49d0ef8..24b024d 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -2,6 +2,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using NamedDimsArrays: AbstractNamedDimsArray +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy + abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end @@ -45,15 +48,16 @@ function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp return iter end -default_message(alg, network, edge) = default_message(typeof(alg), network, edge) - -default_message(::Type{<:Algorithm}, network, edge) = not_implemented() -function default_message(::Type{<:Algorithm"bp"}, network, edge) - - #TODO: Get datatype working on tensornetworks so we can support GPU, etc... - links = linkinds(network, edge) - data = ones(Tuple(links)) - return data +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) end updated_message(alg, bpc, edge) = not_implemented() From c43884ecb5185386ab5acc6c08f4344c0d566811 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:47:44 -0500 Subject: [PATCH 19/64] Remove unnecessary code and fix ambiguities in `AbstractTensorNetwork` --- src/abstracttensornetwork.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index b02c789..b820867 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -53,10 +53,6 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) # Overload if needed Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false -# AbstractDataGraphs overloads -DataGraphs.vertex_data(::AbstractTensorNetwork) = not_implemented() -DataGraphs.edge_data(::AbstractTensorNetwork) = not_implemented() - DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented() function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) return NamedGraphs.vertex_positions(underlying_graph(tn)) @@ -240,10 +236,7 @@ end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices::AbstractVector{V}) where {V <: Int} - return tensornetwork_induced_subgraph(graph, subvertices) -end -function Graphs.induced_subgraph(graph::AbstractTensorNetwork, subvertices) +function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} return tensornetwork_induced_subgraph(graph, subvertices) end From dd6f6454f01380e03e609cd60b1d4bfdf5499718 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 6 Jan 2026 09:48:10 -0500 Subject: [PATCH 20/64] `TensorNetwork` type now uses new DataGraphs interface --- src/tensornetwork.jl | 50 +++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0681da5..16c80e3 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,9 +1,9 @@ using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph -using Dictionaries: AbstractDictionary, Indices, dictionary +using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, @@ -12,9 +12,13 @@ using NamedGraphs.PartitionedGraphs: partitioned_vertices, partitionedgraph, quotient_graph, - quotient_graph_type + quotient_graph_type, + QuotientVertex, + QuotientVertices, + QuotientVertexVertices, + quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data +using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -31,7 +35,7 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. -function _TensorNetwork(graph::AbstractGraph, tensors) +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) end @@ -39,10 +43,18 @@ function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: Abstra return _TensorNetwork(graph, Tensors()) end -DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) -DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) -DataGraphs.edge_data(tn::TensorNetwork) = Dictionary{edgetype(tn), Nothing}() -DataGraphs.vertex_data_eltype(T::Type{<:TensorNetwork}) = eltype(fieldtype(T, :tensors)) +# DataGraphs interface + +DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph + +DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.has_edge_data(tn::TensorNetwork, e) = false + +DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] + +DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) +DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) + function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end @@ -123,17 +135,23 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.graph_from_vertices(type::Type{<:TensorNetwork}, vertices) +function GraphsExtensions.similar(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() - return TensorNetwork(similar_graph(underlying_graph_type(type), vertices), empty_dict) + return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) ug = quotient_graph(underlying_graph(tn)) - return TensorNetwork(ug, vertex_data(QuotientView(tn))) + + inds = Indices(parent_graph_indices(QuotientVertices(tn))) + data = map(v -> tn[QuotientVertex(v)], inds) + + return TensorNetwork(ug, data) end +# TODO: This method should not be required with a better interface with a better +# DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) VD = Vector{vertex_data_eltype(type)} @@ -141,9 +159,10 @@ function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end +# Partition the underlying graph of the tensor network; does not affect the data. function PartitionedGraphs.partitionedgraph(tn::TensorNetwork, parts) pg = partitionedgraph(underlying_graph(tn), parts) - return TensorNetwork(pg, vertex_data(tn)) + return TensorNetwork(pg, copy(vertex_data(tn))) end PartitionedGraphs.departition(tn::TensorNetwork) = tn @@ -153,8 +172,9 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end -function DataGraphsPartitionedGraphsExt.to_quotient_vertex_data(::TensorNetwork, data) - return mapreduce(lazy, *, collect(last(data))) +function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) + data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) + return mapreduce(lazy, *, data) end function PartitionedGraphs.quotientview(tn::TensorNetwork) From 7bb579c7037c93e591a09a0c88e3aa489ef39c5d Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 19 Dec 2025 16:37:59 -0500 Subject: [PATCH 21/64] Sweeping algorithms based on AlgorithmsInterface.jl (#30) --- Project.toml | 4 +- docs/Project.toml | 2 +- examples/Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 306 ++++++++++++ src/ITensorNetworksNext.jl | 6 +- src/abstract_problem.jl | 1 - src/adapters.jl | 45 -- src/iterators.jl | 170 ------- src/sweeping/eigenproblem.jl | 44 ++ src/sweeping/utils.jl | 12 + test/Project.toml | 3 +- test/test_algorithmsinterfaceextensions.jl | 472 ++++++++++++++++++ test/test_aqua.jl | 2 +- test/test_dmrg.jl | 34 ++ test/test_iterators.jl | 221 -------- test/test_sweeping.jl | 65 +++ 16 files changed, 944 insertions(+), 445 deletions(-) create mode 100644 src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl delete mode 100644 src/abstract_problem.jl delete mode 100644 src/adapters.jl delete mode 100644 src/iterators.jl create mode 100644 src/sweeping/eigenproblem.jl create mode 100644 src/sweeping/utils.jl create mode 100644 test/test_algorithmsinterfaceextensions.jl create mode 100644 test/test_dmrg.jl delete mode 100644 test/test_iterators.jl create mode 100644 test/test_sweeping.jl diff --git a/Project.toml b/Project.toml index 95b8be0..e6919fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.2.4" +version = "0.3.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -32,6 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" +AlgorithmsInterface = "0.1.0" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" diff --git a/docs/Project.toml b/docs/Project.toml index 15d156a..9e273b0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,5 +8,5 @@ ITensorNetworksNext = {path = ".."} [compat] Documenter = "1" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" Literate = "2" diff --git a/examples/Project.toml b/examples/Project.toml index a9cd21b..bd688e9 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,4 +5,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" ITensorNetworksNext = {path = ".."} [compat] -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl new file mode 100644 index 0000000..a8c814e --- /dev/null +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -0,0 +1,306 @@ +module AlgorithmsInterfaceExtensions + +import AlgorithmsInterface as AI + +#========================== Patches for AlgorithmsInterface.jl ============================# + +abstract type Problem <: AI.Problem end +abstract type Algorithm <: AI.Algorithm end +abstract type State <: AI.State end + +function AI.initialize_state!( + problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.initialize_state( + problem::Problem, algorithm::Algorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultState(; stopping_criterion_state, kwargs...) +end + +#============================ DefaultState ================================================# + +@kwdef mutable struct DefaultState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ increment! ==================================================# + +# Custom version of `increment!` that also takes the problem and algorithm as arguments. +function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) + return AI.increment!(state) +end + +#============================ solve! ======================================================# + +# Custom version of `solve!` that allows specifying the logger and also overloads +# `increment!` on the problem and algorithm. +function basetypenameof(x) + return Symbol(last(split(String(Symbol(Base.typename(typeof(x)).wrapper)), "."))) +end +default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) +function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) + return Symbol( + default_logging_context_prefix(problem), + default_logging_context_prefix(algorithm), + ) +end +function AI.solve!( + problem::Problem, algorithm::Algorithm, state::State; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + logger = AI.algorithm_logger() + + context_suffixes = [:Start, :PreStep, :PostStep, :Stop] + contexts = Dict(context_suffixes .=> Symbol.(logging_context_prefix, context_suffixes)) + + # initialize the state and emit message + AI.initialize_state!(problem, algorithm, state; kwargs...) + AI.emit_message(logger, problem, algorithm, state, contexts[:Start]) + + # main body of the algorithm + while !AI.is_finished!(problem, algorithm, state) + AI.increment!(problem, algorithm, state) + + # logging event between convergence check and algorithm step + AI.emit_message(logger, problem, algorithm, state, contexts[:PreStep]) + + # algorithm step + AI.step!(problem, algorithm, state; logging_context_prefix) + + # logging event between algorithm step and convergence check + AI.emit_message(logger, problem, algorithm, state, contexts[:PostStep]) + end + + # emit message about finished state + AI.emit_message(logger, problem, algorithm, state, contexts[:Stop]) + return state +end + +function AI.solve( + problem::Problem, algorithm::Algorithm; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + kwargs..., + ) + state = AI.initialize_state(problem, algorithm; kwargs...) + return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) +end + +#============================ AlgorithmIterator ===========================================# + +abstract type AlgorithmIterator end + +function algorithm_iterator( + problem::Problem, algorithm::Algorithm, state::State + ) + return DefaultAlgorithmIterator(problem, algorithm, state) +end + +function AI.is_finished!(iterator::AlgorithmIterator) + return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.is_finished(iterator::AlgorithmIterator) + return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.increment!(iterator::AlgorithmIterator) + return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) +end +function AI.step!(iterator::AlgorithmIterator) + return AI.step!(iterator.problem, iterator.algorithm, iterator.state) +end +function Base.iterate(iterator::AlgorithmIterator, init = nothing) + AI.is_finished!(iterator) && return nothing + AI.increment!(iterator) + AI.step!(iterator) + return iterator.state, nothing +end + +struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator + problem::Problem + algorithm::Algorithm + state::State +end + +#============================ with_algorithmlogger ========================================# + +# Allow passing functions, not just CallbackActions. +@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) + return AI.with_algorithmlogger(f, args...) +end +@inline function with_algorithmlogger(f, args::Pair{Symbol}...) + return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) +end + +#============================ NestedAlgorithm =============================================# + +abstract type NestedAlgorithm <: Algorithm end + +function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +end + +max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) + +function get_subproblem( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function set_substate!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State + ) + state.iterate = substate.iterate + return state +end + +function AI.step!( + problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State; + logging_context_prefix = Symbol() + ) + # Get the subproblem, subalgorithm, and substate. + subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) + + # Solve the subproblem with the subalgorithm. + logging_context_prefix = Symbol( + logging_context_prefix, default_logging_context_prefix(subalgorithm) + ) + AI.solve!(subproblem, subalgorithm, substate; logging_context_prefix) + + # Update the state with the substate. + set_substate!(problem, algorithm, state, substate) + + return state +end + +#= + DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) + +An algorithm that consists of running an algorithm at each iteration +from a list of stored algorithms. +=# +@kwdef struct DefaultNestedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end +function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +#============================ FlattenedAlgorithm ==========================================# + +# Flatten a nested algorithm. +abstract type FlattenedAlgorithm <: Algorithm end +abstract type FlattenedAlgorithmState <: State end + +function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) +end + +function AI.initialize_state( + problem::Problem, algorithm::FlattenedAlgorithm; kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) +end +function AI.increment!( + problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState + ) + # Increment the total iteration count. + state.iteration += 1 + # TODO: Use `is_finished!` instead? + if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) + # We're on the last iteration of the child algorithm, so move to the next + # child algorithm. + state.parent_iteration += 1 + state.child_iteration = 1 + else + # Iterate the child algorithm. + state.child_iteration += 1 + end + return state +end +function AI.step!( + problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState; + logging_context_prefix = Symbol() + ) + algorithm_sweep = algorithm.algorithms[state.parent_iteration] + state_sweep = AI.initialize_state( + problem, algorithm_sweep; + state.iterate, iteration = state.child_iteration + ) + AI.step!(problem, algorithm_sweep, state_sweep; logging_context_prefix) + state.iterate = state_sweep.iterate + return state +end + +@kwdef struct DefaultFlattenedAlgorithm{ + ChildAlgorithm <: Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: FlattenedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = + AI.StopAfterIteration(sum(max_iterations, algorithms)) +end +function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) + return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) +end + +@kwdef mutable struct DefaultFlattenedAlgorithmState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: FlattenedAlgorithmState + iterate::Iterate + iteration::Int = 0 + parent_iteration::Int = 1 + child_iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +#============================ NonIterativeAlgorithm =======================================# + +# Algorithm that only performs a single step. +abstract type NonIterativeAlgorithm <: Algorithm end +abstract type NonIterativeAlgorithmState <: State end + +function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) + return DefaultNonIterativeAlgorithmState(; kwargs...) +end +function AI.solve!( + problem::Problem, algorithm::NonIterativeAlgorithm, state::State; kwargs... + ) + return throw(MethodError(AI.solve!, (problem, algorithm, state))) +end + +@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: + NonIterativeAlgorithmState + iterate::Iterate +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index cca4b6d..d3c5c21 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,13 +1,13 @@ module ITensorNetworksNext +include("AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl") include("LazyNamedDimsArrays/LazyNamedDimsArrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("abstract_problem.jl") -include("iterators.jl") -include("adapters.jl") +include("sweeping/utils.jl") +include("sweeping/eigenproblem.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl deleted file mode 100644 index 5a65e0a..0000000 --- a/src/abstract_problem.jl +++ /dev/null @@ -1 +0,0 @@ -abstract type AbstractProblem end diff --git a/src/adapters.jl b/src/adapters.jl deleted file mode 100644 index 28318fb..0000000 --- a/src/adapters.jl +++ /dev/null @@ -1,45 +0,0 @@ -""" - struct IncrementOnly{S<:AbstractNetworkIterator} <: AbstractNetworkIterator - -Iterator wrapper whos `compute!` function simply returns itself, doing nothing in the -process. This allows one to manually call a custom `compute!` or insert their own code it in -the loop body in place of `compute!`. -""" -struct IncrementOnly{S <: AbstractNetworkIterator} <: AbstractNetworkIterator - parent::S -end - -islaststep(adapter::IncrementOnly) = islaststep(adapter.parent) -increment!(adapter::IncrementOnly) = increment!(adapter.parent) -compute!(adapter::IncrementOnly) = adapter - -IncrementOnly(adapter::IncrementOnly) = adapter - -""" - struct EachRegion{SweepIterator} <: AbstractNetworkIterator - -Adapter that flattens each region iterator in the parent sweep iterator into a single -iterator. -""" -struct EachRegion{SI <: SweepIterator} <: AbstractNetworkIterator - parent::SI -end - -# In keeping with Julia convention. -eachregion(iter::SweepIterator) = EachRegion(iter) - -# Essential definitions -function islaststep(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - return islaststep(adapter.parent) && islaststep(region_iter) -end -function increment!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - islaststep(region_iter) ? increment!(adapter.parent) : increment!(region_iter) - return adapter -end -function compute!(adapter::EachRegion) - region_iter = region_iterator(adapter.parent) - compute!(region_iter) - return adapter -end diff --git a/src/iterators.jl b/src/iterators.jl deleted file mode 100644 index 62d5b21..0000000 --- a/src/iterators.jl +++ /dev/null @@ -1,170 +0,0 @@ -""" - abstract type AbstractNetworkIterator - -A stateful iterator with two states: `increment!` and `compute!`. Each iteration begins -with a call to `increment!` before executing `compute!`, however the initial call to -`iterate` skips the `increment!` call as it is assumed the iterator is initalized such that -this call is implict. Termination of the iterator is controlled by the function `done`. -""" -abstract type AbstractNetworkIterator end - -# We use greater than or equals here as we increment the state at the start of the iteration -islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) - -function Base.iterate(iterator::AbstractNetworkIterator, init = true) - # The assumption is that first "increment!" is implicit, therefore we must skip the - # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not - # defined when length < 1, - init || islaststep(iterator) && return nothing - # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* - # define a method for increment! This way we avoid cases where one may wish to nest - # calls to different step! methods accidentaly incrementing multiple times. - init || increment!(iterator) - rv = compute!(iterator) - return rv, false -end - -increment!(iterator::AbstractNetworkIterator) = throw(MethodError(increment!, Tuple{typeof(iterator)})) -compute!(iterator::AbstractNetworkIterator) = iterator - -step!(iterator::AbstractNetworkIterator) = step!(identity, iterator) -function step!(f, iterator::AbstractNetworkIterator) - compute!(iterator) - f(iterator) - increment!(iterator) - return iterator -end - -# -# RegionIterator -# -""" - struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator -""" -mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator - problem::Problem - region_plan::RegionPlan - which_region::Int - const which_sweep::Int - function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} - if isempty(region_plan) - throw(ArgumentError("Cannot construct a region iterator with 0 elements.")) - end - return new{P, R}(problem, region_plan, 1, sweep) - end -end - -function RegionIterator(problem; sweep, sweep_kwargs...) - plan = region_plan(problem; sweep_kwargs...) - return RegionIterator(problem, plan, sweep) -end - -state(region_iter::RegionIterator) = region_iter.which_region -Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) - -problem(region_iter::RegionIterator) = region_iter.problem - -function current_region_plan(region_iter::RegionIterator) - return region_iter.region_plan[region_iter.which_region] -end - -function current_region(region_iter::RegionIterator) - region, _ = current_region_plan(region_iter) - return region -end - -function region_kwargs(region_iter::RegionIterator) - _, kwargs = current_region_plan(region_iter) - return kwargs -end -function region_kwargs(f::Function, iter::RegionIterator) - return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) -end - -function prev_region(region_iter::RegionIterator) - state(region_iter) <= 1 && return nothing - prev, _ = region_iter.region_plan[region_iter.which_region - 1] - return prev -end - -function next_region(region_iter::RegionIterator) - islaststep(region_iter) && return nothing - next, _ = region_iter.region_plan[region_iter.which_region + 1] - return next -end - -# -# Functions associated with RegionIterator -# -function increment!(region_iter::RegionIterator) - region_iter.which_region += 1 - return region_iter -end - -function compute!(iter::RegionIterator) - extract!(iter; region_kwargs(extract!, iter)...) - update!(iter; region_kwargs(update!, iter)...) - insert!(iter; region_kwargs(insert!, iter)...) - - return iter -end - -region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) - -# -# SweepIterator -# - -mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator - region_iter::RegionIterator{Problem} - sweep_kwargs::Iterators.Stateful{Iter} - which_sweep::Int - function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} - stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - - first_state = Iterators.peel(stateful_sweep_kwargs) - - if isnothing(first_state) - throw(ArgumentError("Cannot construct a sweep iterator with 0 elements.")) - end - - first_kwargs, _ = first_state - region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) - - return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) - end -end - -islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) - -region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter -problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) - -state(sweep_iter::SweepIterator) = sweep_iter.which_sweep -Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) -function increment!(sweep_iter::SweepIterator) - sweep_iter.which_sweep += 1 - sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) - update_region_iterator!(sweep_iter; sweep_kwargs...) - return sweep_iter -end - -function update_region_iterator!(iterator::SweepIterator; kwargs...) - sweep = state(iterator) - iterator.region_iter = RegionIterator(problem(iterator); sweep, kwargs...) - return iterator -end - -function compute!(sweep_iter::SweepIterator) - for _ in sweep_iter.region_iter - # TODO: Is it sensible to execute the default region callback function? - end - return -end - -# More basic constructor where sweep_kwargs are constant throughout sweeps -function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) - # Initialize this to an empty RegionIterator - sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) - return SweepIterator(problem, sweep_kwargs_iter) -end diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl new file mode 100644 index 0000000..36978b2 --- /dev/null +++ b/src/sweeping/eigenproblem.jl @@ -0,0 +1,44 @@ +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE + +function dmrg(operator, algorithm, state) + problem = EigenProblem(operator) + return AI.solve(problem, algorithm; iterate = state).iterate +end +function dmrg(operator, state; kwargs...) + problem = EigenProblem(operator) + algorithm = select_algorithm(dmrg, operator, state; kwargs...) + return AI.solve(problem, algorithm; iterate = state).iterate +end + +# TODO: Allow specifying the region algorithm type? +function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) + extended_kwargs = extend_columns((; kwargs...), nsweeps) + region_kwargs = rows(extended_kwargs) + return AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion(regions[j]; region_kwargs[i]...) + end + end +end +#= + EigenProblem(operator) + +Represents the problem we are trying to solve and minimal algorithm-independent +information, so for an eigenproblem it is the operator we want the eigenvector of. +=# +struct EigenProblem{Operator} <: AIE.Problem + operator::Operator +end + +struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) + +function AI.solve!( + problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... + ) + return error("EigsolveRegion step for EigenProblem not implemented yet.") +end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl new file mode 100644 index 0000000..39e09e4 --- /dev/null +++ b/src/sweeping/utils.jl @@ -0,0 +1,12 @@ +# Utility functions for processing keyword arguments. +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] +end diff --git a/test/Project.toml b/test/Project.toml index 4b7dc81..e71e7a4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" @@ -26,7 +27,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.3" -ITensorNetworksNext = "0.2" +ITensorNetworksNext = "0.3" NamedDimsArrays = "0.8, 0.9" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl new file mode 100644 index 0000000..8e0665c --- /dev/null +++ b/test/test_algorithmsinterfaceextensions.jl @@ -0,0 +1,472 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +# Define test problems, algorithms, and states for testing +struct TestProblem <: AIE.Problem + data::Vector{Float64} +end + +@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +end + +@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState; + logging_context_prefix = Symbol() + ) + state.iterate .+= 1 # Simple increment step + return state +end + +function AI.step!( + problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState; + kwargs... + ) + state.iterate .+= 2 # Different increment step + return state +end + +@testset "AlgorithmsInterfaceExtensions" begin + @testset "DefaultState" begin + # Test DefaultState construction + iterate = [1.0, 2.0, 3.0] + stopping_criterion_state = AI.initialize_state( + TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion + ) + state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) + @test state.iterate == iterate + @test state.iteration == 0 + @test state.stopping_criterion_state isa AI.StoppingCriterionState + + # Test DefaultState with custom iteration + state.iteration = 5 + @test state.iteration == 5 + end + + @testset "initialize_state!" begin + # Test initialize_state! with iterate kwarg + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; + iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state + ) + AI.initialize_state!(problem, algorithm, state) + @test state.iterate == [0.0, 0.0] + @test state.iteration == 0 + @test state.stopping_criterion_state == stopping_criterion_state + end + + @testset "initialize_state" begin + # Test initialize_state without exclamation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + + state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) + @test state isa AIE.DefaultState + @test state.iteration == 0 + end + + @testset "increment!" begin + # Test increment! with problem and algorithm + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm() + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + # Increment and verify iteration counter increases + AI.increment!(problem, algorithm, state) + @test state.iteration == 1 + + AI.increment!(problem, algorithm, state) + @test state.iteration == 2 + end + + @testset "solve! and solve" begin + # Test solve! with simple problem + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + + initial_iterate = [10.0, 20.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + + # Solve with custom initial iterate + initial_iterate = [5.0, 10.0] + final_state = AI.solve!( + problem, algorithm, state; iterate = copy(initial_iterate) + ) + + @test final_state.iteration == 3 + # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] + @test final_state.iterate ≈ [8.0, 13.0] + + # Test solve without exclamation + problem2 = TestProblem([1.0, 2.0]) + algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate2 = [5.0, 10.0] + + final_state2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + @test final_state2.iteration == 2 + @test final_state2.iterate ≈ [7.0, 12.0] + end + + @testset "DefaultAlgorithmIterator" begin + # Test algorithm iterator creation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + @test iterator isa AIE.DefaultAlgorithmIterator + @test iterator.problem === problem + @test iterator.algorithm === algorithm + @test iterator.state === state + + # Test iteration interface + @test !AI.is_finished!(iterator) + + # Step through iterator + state_out, _ = iterate(iterator) + @test state_out.iteration == 1 + @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! + + state_out, _ = iterate(iterator) + @test state_out.iteration == 2 + + @test AI.is_finished!(iterator) + end + + @testset "with_algorithmlogger" begin + # Test with_algorithmlogger with functions + results = [] + function callback1(problem, algorithm, state) + push!(results, :callback1) + return nothing + end + function callback2(problem, algorithm, state) + push!(results, :callback2) + return nothing + end + + problem = TestProblem([1.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + + # Test with CallbackAction (wrapped functions) + state = AIE.with_algorithmlogger( + :TestProblem_TestAlgorithm_PreStep => callback1, + :TestProblem_TestAlgorithm_PostStep => callback2, + ) do + return AI.solve(problem, algorithm; iterate = [0.0]) + end + @test results == [:callback1, :callback2] + end + + @testset "DefaultNestedAlgorithm" begin + # Test creating nested algorithm with function + nested_alg = AIE.nested_algorithm(3) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + @test nested_alg isa AIE.DefaultNestedAlgorithm + @test length(nested_alg.algorithms) == 3 + @test AIE.max_iterations(nested_alg) == 3 + + # Test stepping through nested algorithm + problem = TestProblem([1.0, 2.0]) + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) + + initial_iterate = [0.0, 0.0] + AI.solve!( + problem, nested_alg, state; iterate = copy(initial_iterate) + ) + + @test state.iteration == 3 + # Each nested algorithm runs once with 2 steps, incrementing by 2 + # Total: 3 algorithms × 2 iterations × 2 increment = 12 + @test state.iterate ≈ [12.0, 12.0] + end + + @testset "NestedAlgorithm basic tests" begin + # Test basic nested algorithm functionality + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + problem = TestProblem([1.0, 2.0]) + + # Test state initialization + state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + @test state_nested isa AIE.DefaultState + @test state_nested.iteration == 0 + @test AIE.max_iterations(nested_alg) == 2 + end + + @testset "increment! for nested algorithms" begin + # Test increment! logic for nested algorithm state + problem = TestProblem([1.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test progression through iterations + @test state.iteration == 0 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 1 + + AI.increment!(problem, nested_alg, state) + @test state.iteration == 2 + end + + @testset "get_subproblem and set_substate!" begin + # Test get_subproblem + problem = TestProblem([1.0, 2.0]) + nested_alg = AIE.nested_algorithm(2) do i + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) + end + + stopping_criterion_state = AI.initialize_state( + problem, nested_alg, nested_alg.stopping_criterion + ) + state = AIE.DefaultState(; + iterate = [5.0, 10.0], + iteration = 1, + stopping_criterion_state, + ) + + subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + @test subproblem === problem + @test subalgorithm === nested_alg.algorithms[1] + @test substate.iterate ≈ [5.0, 10.0] + + # Test set_substate! + new_substate = AIE.DefaultState(; + iterate = [100.0, 200.0], + substate.stopping_criterion_state, + ) + AIE.set_substate!(problem, nested_alg, state, new_substate) + @test state.iterate ≈ [100.0, 200.0] + end + + @testset "basetypenameof and default_logging_context_prefix" begin + # Test basetypenameof utility + problem = TestProblem([1.0]) + algorithm = TestAlgorithm() + + prefix_problem = AIE.default_logging_context_prefix(problem) + prefix_algorithm = AIE.default_logging_context_prefix(algorithm) + prefix_combined = AIE.default_logging_context_prefix(problem, algorithm) + + @test prefix_problem isa Symbol + @test prefix_algorithm isa Symbol + @test prefix_combined isa Symbol + @test contains(String(prefix_combined), String(prefix_problem)) + end + + @testset "DefaultFlattenedAlgorithm" begin + # Create nested algorithms that support max_iterations + nested_algs = map(1:3) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each + ) + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 3 + + # Test state initialization + problem = TestProblem([1.0, 2.0]) + state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + @test state_flat isa AIE.DefaultFlattenedAlgorithmState + @test state_flat.iteration == 0 + @test state_flat.parent_iteration == 1 + @test state_flat.child_iteration == 0 + end + + @testset "DefaultFlattenedAlgorithmState increment!" begin + # Create nested algorithms for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4), + ) + + problem = TestProblem([1.0]) + stopping_criterion_state = AI.initialize_state( + problem, flattened_alg, flattened_alg.stopping_criterion + ) + state = AIE.DefaultFlattenedAlgorithmState(; + iterate = [0.0], + stopping_criterion_state = stopping_criterion_state, + ) + + # Test initial state + @test state.iteration == 0 + @test state.parent_iteration == 1 + @test state.child_iteration == 0 + + # First increment - should increment child_iteration + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 1 + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + # Second increment - should increment child_iteration again + AI.increment!(problem, flattened_alg, state) + @test state.iteration == 2 + @test state.parent_iteration == 2 # Should move to next parent + @test state.child_iteration == 1 + end + + @testset "FlattenedAlgorithm step!" begin + # Test individual step! calls for flattened algorithm + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + flattened_alg = AIE.DefaultFlattenedAlgorithm(; + algorithms = nested_algs, + stopping_criterion = AI.StopAfterIteration(4) + ) + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) + + # Manually step through to test step! functionality + AI.increment!(problem, flattened_alg, state) + @test state.parent_iteration == 1 + @test state.child_iteration == 1 + + AI.step!(problem, flattened_alg, state) + # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 + @test state.iterate ≈ [4.0, 4.0] + end + + @testset "flattened_algorithm helper" begin + # Test the flattened_algorithm helper function + nested_algs = map(1:2) do i + return AIE.nested_algorithm(1) do j + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + # Using the helper function + flattened_alg = AIE.flattened_algorithm(2) do i + AIE.nested_algorithm(1) do j + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + end + end + + @test flattened_alg isa AIE.DefaultFlattenedAlgorithm + @test length(flattened_alg.algorithms) == 2 + end + + @testset "AlgorithmIterator is_finished (without !)" begin + # Test is_finished without mutation + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Before any iterations + @test !AI.is_finished(iterator) + + # Run the algorithm + AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) + + # After completion + @test AI.is_finished(iterator) + end + + @testset "AlgorithmIterator step!" begin + # Test step! method for iterator + problem = TestProblem([1.0, 2.0]) + algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) + initial_iterate = [0.0, 0.0] + state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) + iterator = AIE.algorithm_iterator(problem, algorithm, state) + + # Step the iterator + AI.step!(iterator) + @test iterator.state.iterate ≈ [1.0, 1.0] + + AI.step!(iterator) + @test iterator.state.iterate ≈ [2.0, 2.0] + end + + @testset "NestedAlgorithm with different sub-algorithms" begin + # Test nested algorithm with varying sub-algorithms + nested_alg = AIE.DefaultNestedAlgorithm(; + algorithms = [ + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + ] + ) + + @test AIE.max_iterations(nested_alg) == 3 + @test length(nested_alg.algorithms) == 3 + + problem = TestProblem([1.0, 2.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) + + AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) + + # First algorithm: 1 iteration × 1 increment = 1 + # Second algorithm: 2 iterations × 2 increment = 4 + # Third algorithm: 1 iteration × 1 increment = 1 + # Total: 1 + 4 + 1 = 6 + @test state.iterate ≈ [6.0, 6.0] + @test state.iteration == 3 + end + + @testset "Edge cases" begin + # Test with single nested algorithm + nested_alg = AIE.nested_algorithm(1) do i + return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) + end + + problem = TestProblem([1.0]) + state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) + AI.solve!(problem, nested_alg, state; iterate = [0.0]) + + @test state.iterate ≈ [1.0] + @test state.iteration == 1 + end +end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 0afead5..a38563a 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(ITensorNetworksNext) + Aqua.test_all(ITensorNetworksNext; persistent_tasks = false) end diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl new file mode 100644 index 0000000..01f04ac --- /dev/null +++ b/test/test_dmrg.jl @@ -0,0 +1,34 @@ +import AlgorithmsInterface as AI +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +@testset "select_algorithm(dmrg, ...)" begin + operator = "operator" + init = "init" + nsweeps = 3 + regions = ["region1", "region2"] + maxdim = [10, 20] + cutoff = 1.0e-7 + algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) + @test algorithm isa AIE.NestedAlgorithm + @test length(algorithm.algorithms) == nsweeps + + maxdims = [10, 20, 20] + cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] + algorithm′ = AIE.nested_algorithm(nsweeps) do i + return AIE.nested_algorithm(length(regions)) do j + return EigsolveRegion( + regions[j]; + maxdim = maxdims[i], + cutoff = cutoffs[i], + ) + end + end + for i in 1:nsweeps + for j in 1:length(regions) + @test algorithm.algorithms[i].algorithms[j] == + algorithm′.algorithms[i].algorithms[j] + end + end +end diff --git a/test/test_iterators.jl b/test/test_iterators.jl deleted file mode 100644 index a17c7be..0000000 --- a/test/test_iterators.jl +++ /dev/null @@ -1,221 +0,0 @@ -using Test: @test, @testset, @test_throws -import ITensorNetworksNext as ITensorNetworks -using .ITensorNetworks: RegionIterator, SweepIterator, IncrementOnly, compute!, increment!, islaststep, state, eachregion - -module TestIteratorUtils - - import ITensorNetworksNext as ITensorNetworks - using .ITensorNetworks - - struct TestProblem <: ITensorNetworks.AbstractProblem - data::Vector{Int} - end - ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)] - function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem}) - kwargs = ITensorNetworks.region_kwargs(iter) - push!(ITensorNetworks.problem(iter).data, kwargs.val) - return iter - end - - - mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator - state::Int - max::Int - output::Vector{Int} - end - - ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1 - Base.length(TI::TestIterator) = TI.max - ITensorNetworks.state(TI::TestIterator) = TI.state - function ITensorNetworks.compute!(TI::TestIterator) - push!(TI.output, ITensorNetworks.state(TI)) - return TI - end - - mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator - parent::TestIterator - end - - Base.length(SA::SquareAdapter) = length(SA.parent) - ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent) - ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent) - function ITensorNetworks.compute!(SA::SquareAdapter) - ITensorNetworks.compute!(SA.parent) - return last(SA.parent.output)^2 - end - -end - -@testset "Iterators" begin - - import .TestIteratorUtils - - @testset "`AbstractNetworkIterator` Interface" begin - - @testset "Edge cases" begin - TI = TestIteratorUtils.TestIterator(1, 1, []) - cb = [] - @test islaststep(TI) - for _ in TI - @test islaststep(TI) - push!(cb, state(TI)) - end - @test length(cb) == 1 - @test length(TI.output) == 1 - @test only(cb) == 1 - - prob = TestIteratorUtils.TestProblem([]) - @test_throws ArgumentError SweepIterator(prob, 0) - @test_throws ArgumentError RegionIterator(prob, [], 1) - end - - TI = TestIteratorUtils.TestIterator(1, 4, []) - - @test !islaststep((TI)) - - # First iterator should compute only - rv, st = iterate(TI) - @test !islaststep((TI)) - @test !st - @test rv === TI - @test length(TI.output) == 1 - @test only(TI.output) == 1 - @test state(TI) == 1 - @test !st - - rv, st = iterate(TI, st) - @test !islaststep((TI)) - @test !st - @test length(TI.output) == 2 - @test state(TI) == 2 - @test TI.output == [1, 2] - - increment!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 2 - @test TI.output == [1, 2] - - compute!(TI) - @test !islaststep((TI)) - @test state(TI) == 3 - @test length(TI.output) == 3 - @test TI.output == [1, 2, 3] - - # Final Step - iterate(TI, false) - @test islaststep((TI)) - @test state(TI) == 4 - @test length(TI.output) == 4 - @test TI.output == [1, 2, 3, 4] - - @test iterate(TI, false) === nothing - - TI = TestIteratorUtils.TestIterator(1, 5, []) - - cb = [] - - for _ in TI - @test length(cb) == length(TI.output) - 1 - @test cb == (TI.output)[1:(end - 1)] - push!(cb, state(TI)) - @test cb == TI.output - end - - @test islaststep((TI)) - @test length(TI.output) == 5 - @test length(cb) == 5 - @test cb == TI.output - - - TI = TestIteratorUtils.TestIterator(1, 5, []) - end - - @testset "Adapters" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - @testset "Generic" begin - - i = 0 - for rv in SA - i += 1 - @test rv isa Int - @test rv == i^2 - @test state(SA) == i - end - - @test islaststep((SA)) - - TI = TestIteratorUtils.TestIterator(1, 5, []) - SA = TestIteratorUtils.SquareAdapter(TI) - - SA_c = collect(SA) - - @test SA_c isa Vector - @test length(SA_c) == 5 - @test SA_c == [1, 4, 9, 16, 25] - - end - - @testset "IncrementOnly" begin - TI = TestIteratorUtils.TestIterator(1, 5, []) - NI = IncrementOnly(TI) - - NI_c = [] - - for _ in IncrementOnly(TI) - push!(NI_c, state(TI)) - end - - @test length(NI_c) == 5 - @test isempty(TI.output) - end - - @testset "EachRegion" begin - prob = TestIteratorUtils.TestProblem([]) - prob_region = TestIteratorUtils.TestProblem([]) - - SI = SweepIterator(prob, 5) - SI_region = SweepIterator(prob_region, 5) - - callback = [] - callback_region = [] - - let i = 1 - for _ in SI - push!(callback, i) - i += 1 - end - end - - @test length(callback) == 5 - - let i = 1 - for _ in eachregion(SI_region) - push!(callback_region, i) - i += 1 - end - end - - @test length(callback_region) == 10 - - @test prob.data == prob_region.data - - @test prob.data[1:2:end] == fill(1, 5) - @test prob.data[2:2:end] == fill(2, 5) - - - let i = 1, prob = TestIteratorUtils.TestProblem([]) - SI = SweepIterator(prob, 1) - cb = [] - for _ in eachregion(SI) - push!(cb, i) - i += 1 - end - @test length(cb) == 2 - end - - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl new file mode 100644 index 0000000..215a8b8 --- /dev/null +++ b/test/test_sweeping.jl @@ -0,0 +1,65 @@ +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using Test: @test, @testset + +struct TestProblem <: AIE.Problem +end + +struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm + region::R + kwargs::Kwargs +end +TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) + +function AI.solve!(problem::TestProblem, algorithm::TestRegion, state::AIE.State; kwargs...) + new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) + state.iterate = [state.iterate; [new_iterate]] + return state +end + +@testset "Sweeping" begin + @testset "TestRegion" begin + algorithm = TestRegion("region"; foo = 1, bar = 2) + @test algorithm isa AIE.NonIterativeAlgorithm + @test algorithm isa AIE.Algorithm + @test algorithm isa AI.Algorithm + @test algorithm.region == "region" + @test algorithm.kwargs == (; foo = 1, bar = 2) + + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [(; region = "region", foo = 1, bar = 2)] + end + @testset "Sweep" begin + algorithm = AIE.nested_algorithm(3) do i + return TestRegion("region$i"; foo = i, bar = 2i) + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "region1", foo = 1, bar = 2), + (; region = "region2", foo = 2, bar = 4), + (; region = "region3", foo = 3, bar = 6), + ] + end + @testset "Sweeping" begin + algorithm = AIE.nested_algorithm(2) do i + AIE.nested_algorithm(3) do j + return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) + end + end + problem = TestProblem() + iterate = [] + state = AI.solve(problem, algorithm; iterate) + @test state.iterate == [ + (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), + (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), + (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), + (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), + (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), + (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), + ] + end +end From 032447a00de29e7a8fba27f76bb0ae6a8c193e26 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 23 Dec 2025 18:15:22 -0500 Subject: [PATCH 22/64] Upgrade to NamedDimsArrays.jl v0.11 (#38) --- Project.toml | 6 +++--- test/Project.toml | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index e6919fc..7b86558 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -33,7 +33,7 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] AbstractTrees = "0.4.5" Adapt = "4.3" -AlgorithmsInterface = "0.1.0" +AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" DataGraphs = "0.2.7" @@ -43,7 +43,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.9, 0.7, 0.8" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index e71e7a4..0e74eef 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,13 +22,14 @@ ITensorNetworksNext = {path = ".."} [compat] AbstractTrees = "0.4.5" +AlgorithmsInterface = "0.1" Aqua = "0.8.14" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.3" +ITensorBase = "0.3, 0.4" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.8, 0.9" +NamedDimsArrays = "0.8, 0.9, 0.10, 0.11" NamedGraphs = "0.6.8, 0.7, 0.8" QuadGK = "2.11.2" SafeTestsets = "0.1" From b256d79f250cc5f06b83885381879b8f0fa41f10 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:34:38 -0500 Subject: [PATCH 23/64] [LazyNamedDimsArrays] New `symnameddims` method that pulls out indices from an array. --- src/LazyNamedDimsArrays/symbolicnameddimsarray.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index a215319..628baf3 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(name, dims) return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims)) end +function symnameddims(name, ndarray::AbstractNamedDimsArray) + return symnameddims(name, Tuple(inds(ndarray))) +end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) From b2da9d80a35da7ea5a2b51fb791a1115342cd8ca Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:35:32 -0500 Subject: [PATCH 24/64] The function `region_scalar` should now return a scalar, rather than a order-0 array --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0cae3fa..3545b53 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -62,7 +62,7 @@ function setfactor!(bpc::AbstractDataGraph, vertex, factor) end function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) - return message(bp_cache, edge) * message(bp_cache, reverse(edge)) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] end function region_scalar(bp_cache::AbstractGraph, vertex) @@ -70,7 +70,7 @@ function region_scalar(bp_cache::AbstractGraph, vertex) messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) - return reduce(*, messages) * reduce(*, state) + return (reduce(*, messages) * reduce(*, state))[] end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) From 8506e26a3d8814e3e51487a48469f27c9cd64a8f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:43 -0500 Subject: [PATCH 25/64] Fix double counting in `edge_scalars` function This was caused by the change to the `cache` being backed by a directed graph. --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 3545b53..8e7185e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -81,7 +81,7 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(bp_cache)) +function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) return map(e -> region_scalar(bp_cache, e), edges) end @@ -120,7 +120,9 @@ adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + if any(t -> real(t) < 0, numerator_terms) numerator_terms = complex.(numerator_terms) end From 938180af0e35b3e091aa39bfa405a0dd5842d523 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:37:59 -0500 Subject: [PATCH 26/64] Minor code formatting --- src/beliefpropagation/abstractbeliefpropagationcache.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 8e7185e..0efc95d 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -130,7 +130,10 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) denominator_terms = complex.(denominator_terms) end - any(iszero, denominator_terms) && return -Inf + if any(iszero, denominator_terms) + return -Inf + end + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) From 44619673fedaf47c59bd2557222086807f12a2ec Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:39:43 -0500 Subject: [PATCH 27/64] Expressed belief propagation in terms of AlgorithmsInterface --- .../beliefpropagationcache.jl | 13 + .../beliefpropagationproblem.jl | 279 +++++++++++++----- src/sweeping/utils.jl | 8 +- test/test_beliefpropagation.jl | 10 +- 4 files changed, 222 insertions(+), 88 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index c9793e6..27a580d 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -23,6 +23,7 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, is_path_graph, undirected_graph using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: AbstractBeliefPropagationCache{V, MT} @@ -125,3 +126,15 @@ function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) data = map(e -> bpc[QuotientEdge(e)], inds) return BeliefPropagationCache(QuotientView(network(bpc)), data) end + +function default_message(bpc::BeliefPropagationCache, edge) + return default_message(message_type(bpc), network(bpc), edge) +end +function default_message(T::Type, network, edge) + array = ones(Tuple(linkinds(network, edge))) + return convert(T, array) +end +function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) + message = default_message(parenttype(T), network, edge) + return convert(T, lazy(message)) +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 24b024d..0d997ee 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,82 +1,200 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge +using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, parenttype, lazy +using DataGraphs: edge_data +import AlgorithmsInterface as AI +import .AlgorithmsInterfaceExtensions as AIE -abstract type AbstractBeliefPropagationProblem{Alg} <: AbstractProblem end +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 = 0.0 +end -mutable struct BeliefPropagationProblem{Alg, Cache} <: AbstractBeliefPropagationProblem{Alg} - const alg::Alg - const cache::Cache - diff::Union{Nothing, Float64} +@kwdef mutable struct StopWhenConvergedState <: AI.StoppingCriterionState + delta::Float64 = Inf end -BeliefPropagationProblem(alg, cache) = BeliefPropagationProblem(alg, cache, nothing) +function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) + return StopWhenConvergedState() +end -function default_algorithm( - ::Type{<:Algorithm"bp"}, - bpc; - verbose = false, - tol = nothing, - edge_sequence = forest_cover_edge_sequence(bpc), - message_update_alg = default_algorithm(Algorithm"contract"), - maxiter = is_tree(bpc) ? 1 : nothing, +function AI.initialize_state!( + ::AIE.Problem, + ::AIE.Algorithm, + ::StopWhenConverged, + st::StopWhenConvergedState, ) - return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) + st.delta = Inf + return st end -function region_plan(prob::BeliefPropagationProblem{<:Algorithm"bp"}; sweep_kwargs...) - edges = prob.alg.edge_sequence +function AI.is_finished!( + ::AIE.Problem, + ::AIE.Algorithm, + state::AIE.State, + c::StopWhenConverged, + st::StopWhenConvergedState, + ) - plan = map(edges) do e - return e => (; sweep_kwargs...) + # maxdiff = 0.0 initially, so skip this the first time. + if state.iteration > 0 + st.delta = state.iterate.maxdiff end - return plan + return st.delta < c.tol +end + +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end + +@kwdef mutable struct BeliefPropagationState{ + Iterate <: BeliefPropagationCache, + Diffs, + } <: AIE.NonIterativeAlgorithmState + iterate::Iterate + diffs::Diffs = similar(edge_data(iterate), Float64) + maxdiff::Float64 = 0.0 +end + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) end -function compute!(iter::RegionIterator{<:BeliefPropagationProblem{<:Algorithm"bp"}}) - prob = iter.problem +# This gets called at the start of every sweep. +function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + ) + state.iterate.maxdiff = 0.0 + return state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::AIE.NestedAlgorithm, + state::AIE.State, + substate::BeliefPropagationState + ) + + state.iterate = substate + + return state +end - edge, _ = current_region_plan(iter) - new_message = updated_message(prob.alg.message_update_alg, prob.cache, edge) - setmessage!(prob.cache, edge, new_message) +abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end - return iter +struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate + edge::E + kwargs::Kwargs end -function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) +function SimpleMessageUpdate( + edge; + normalize = false, + contraction_alg = "eager", + compute_diff = false, + kwargs... + ) + return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) - return convert(T, array) + +function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) + if name in (:edge, :kwargs) + return getfield(alg, name) + else + return getproperty(getfield(alg, :kwargs), name) + end end -function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) - return convert(T, lazy(message)) + +struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem + messages::Messages + factors::Factors end -updated_message(alg, bpc, edge) = not_implemented() -function updated_message(alg::Algorithm"contract", bpc, edge) +function AI.solve!( + problem::BeliefPropagationProblem, + algorithm::AbstractMessageUpdate, + state::BeliefPropagationState; + logging_context_prefix = default_logging_context_prefix(problem, algorithm), + ) + + logger = AI.algorithm_logger() + + cache = state.iterate + edge = algorithm.edge + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + new_message = updated_message(algorithm, cache) + + if algorithm.compute_diff + diff = message_diff(new_message, cache[edge]) + + if diff > state.maxdiff + state.maxdiff = diff + end + + state.diffs[edge] = diff + end + + setmessage!(cache, edge, new_message) + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) + ) + + return state +end + +message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) + +function updated_message(algorithm, cache) + edge = algorithm.edge + vertex = src(edge) + messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) + + update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + + message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) - incoming_ms = incoming_messages( - bpc, vertex; ignore_edges = typeof(edge)[reverse(edge)] + return message_state.iterate +end + +function AI.solve!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdate, + state::AIE.NonIterativeAlgorithmState; + logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + kwargs... ) - updated_message = contract_messages(alg.contraction_alg, factors(bpc, vertex), incoming_ms) + # TODO: logging... - if alg.normalize - message_norm = LinearAlgebra.norm(updated_message) + state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + + if algorithm.normalize + # TODO: use `sum` not `norm` + message_norm = LinearAlgebra.norm(state.iterate) if !iszero(message_norm) - updated_message /= message_norm + state.iterate /= message_norm end end - return updated_message + + return state end contract_messages(alg, factors, messages) = not_implemented() @@ -85,54 +203,51 @@ function contract_messages( factors::Vector{<:AbstractArray}, messages::Vector{<:AbstractArray}, ) - return contract_network(alg, vcat(factors, messages)) + return contract_network(vcat(factors, messages); alg) end -function default_algorithm( - ::Type{<:Algorithm"contract"}; normalize = true, contraction_alg = Algorithm("exact") - ) - return Algorithm("contract"; normalize, contraction_alg) -end -function default_algorithm( - ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") - ) - return Algorithm("adapt_update"; adapt, alg) -end +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) -function update_message!( - message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge - ) - return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) -end + problem = BeliefPropagationProblem(network(cache)) -function update(bpc::AbstractBeliefPropagationCache; kwargs...) - return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) -end + algorithm = select_algorithm(beliefpropagation, cache; kwargs...) -function update(alg, bpc) - compute_error = !isnothing(alg.tol) + # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) - diff = compute_error ? 0.0 : nothing + state = AI.solve(problem, algorithm; iterate = base_state) - prob = BeliefPropagationProblem(alg, bpc, diff) + return state.iterate.iterate +end - iter = SweepIterator(prob, alg.maxiter; compute_error) +function select_algorithm( + ::typeof(beliefpropagation), + cache; + edges = forest_cover_edge_sequence(network(cache)), + maxiter = is_tree(network(cache)) ? 1 : nothing, + tol = 0.0, + kwargs... + ) - for _ in iter - if compute_error && prob.diff <= alg.tol - break - end + if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end - if alg.verbose && compute_error - if prob.diff <= alg.tol - println("BP converged to desired precision after $(iter.which_sweep) iterations.") - else - println( - "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", - ) - end + stopping_criterion = AI.StopAfterIteration(maxiter) + compute_diff = false + + if tol > 0.0 + stopping_criterion = stopping_criterion | StopWhenConverged(tol) + compute_diff = true end - return bpc + extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, len = maxiter) + + return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum + return AIE.nested_algorithm(length(edges)) do edgenum + return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) + end + end end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl index 39e09e4..9a39c9d 100644 --- a/src/sweeping/utils.jl +++ b/src/sweeping/utils.jl @@ -7,6 +7,12 @@ function extend_columns(nt::NamedTuple, len::Int) return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) end rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) +function rows(nt::NamedTuple; len = nothing) + if isnothing(len) + if isempty(nt) + throw(ArgumentError("Got empty named tuple; keyword `len` must be specified in this case.")) + end + len = rowlength(nt) + end return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index a39e1a6..8c7829b 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -20,7 +20,7 @@ using Test: @test, @testset @testset "BeliefPropagation" begin #Chain of tensors - dims = (4, 1) + dims = (2, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) @@ -30,10 +30,10 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-14 #Tree of tensors dims = (4, 3) @@ -46,8 +46,8 @@ using Test: @test, @testset end bpc = BeliefPropagationCache(tn) - bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test abs(z_bp - z_exact) <= 1.0e-14 + @test z_bp ≈ z_exact atol = 1.0e-12 end From d68860ae59092f2382fccfee87d03abe9a097b58 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:40:23 -0500 Subject: [PATCH 28/64] Fixes to TensorNetwork construction from tensor list --- src/abstracttensornetwork.jl | 4 ++-- src/tensornetwork.jl | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index b820867..08f86a1 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,7 +1,7 @@ using Adapt: Adapt, adapt, adapt_structure using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data + underlying_graph_type, vertex_data, set_vertex_data! using Dictionaries: Dictionary using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices @@ -111,7 +111,7 @@ function sitenames(tn::AbstractGraph, edge::AbstractEdge) end function setindex_preserve_graph!(tn::AbstractGraph, value, vertex) - set!(vertex_data(tn), vertex, value) + set_vertex_data!(tn, value, vertex) return tn end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 16c80e3..b811e2b 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -35,8 +35,13 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar end end # This assumes the tensor connectivity matches the graph structure. +function TensorNetwork(graph::AbstractGraph, tensors) + return TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn end function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} @@ -80,11 +85,6 @@ tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) function TensorNetwork(f::Base.Callable, graph::AbstractGraph) return TensorNetwork(graph, Dictionary(map(f, vertices(graph)))) end -function TensorNetwork(graph::AbstractGraph, tensors) - tn = _TensorNetwork(graph, tensors) - fix_links!(tn) - return tn -end # Insert trivial links for missing edges, and also check # the vertices and edges are consistent between the graph and tensors. @@ -172,6 +172,7 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +# When getting data according the quotient vertices, take a lazy contraction. function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) return mapreduce(lazy, *, data) From 2f5c783f4760d813777e392321c97028f05b3f99 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 9 Jan 2026 14:41:18 -0500 Subject: [PATCH 29/64] Minor simplifications to `contract_network` interface. --- src/contract_network.jl | 91 ++++++++++++++++------------------- test/test_contract_network.jl | 12 ++--- 2 files changed, 48 insertions(+), 55 deletions(-) diff --git a/src/contract_network.jl b/src/contract_network.jl index e89fa00..4511595 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,69 +1,62 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, +using NamedDimsArrays: inds +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, substitute, symnameddims -# This is related to `MatrixAlgebraKit.select_algorithm`. -# TODO: Define this in BackendSelection.jl. -backend_value(::Algorithm{alg}) where {alg} = alg -using BackendSelection: parameters -function merge_parameters(alg::Algorithm; kwargs...) - return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) + return contract_network(alg, tn) end -to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) -to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) -# `contract_network` -function contract_network(alg::Algorithm, tn) - return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) -end -function default_kwargs(::typeof(contract_network), tn) - return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) -end -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) - return contract_network(to_algorithm(alg; kwargs...), tn) +contract_network(alg::String, tn) = contract_network(Algorithm(alg), tn) + +default_kwargs(::typeof(contract_network), tn) = (; alg = "eager") + +function contract_network( + alg, + tensors, + ) + + order = contraction_expression(tensors; order = alg) + symbols_to_tensors = Dict( + symnameddims(i, tensors[i]) => lazy(tensors[i]) for i in keys(tensors) + ) + + return materialize(substitute(order, symbols_to_tensors)) end -# `contract_network(::Algorithm"exact", ...)` -function get_order(alg::Algorithm"exact", tn) - # Allow specifying either `order` or `order_alg`. - order = get(alg, :order, nothing) - order = if !isnothing(order) - order - else - default_order_alg = default_kwargs(contraction_order, tn).alg - order_alg = get(alg, :order_alg, default_order_alg) - # TODO: Capture other keyword arguments and pass them to `contraction_order`. - contraction_order(tn; alg = order_alg) - end +# `contraction_order` +function contraction_order end +default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") + +function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) + order = contraction_order(order, tensors) + # Contraction order may or may not have indices attached, canonicalize the format # by attaching indices. - subs = Dict(symnameddims(i) => symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)) + subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) + return substitute(order, subs) end -function contract_network(alg::Algorithm"exact", tn) - order = get_order(alg, tn) - syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in keys(tn)) - tn_expression = substitute(order, syms_to_ts) - return materialize(tn_expression) -end -# `contraction_order` -function contraction_order end -default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) -function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) - return contraction_order(to_algorithm(alg; kwargs...), tn) +contraction_order(order, tensors) = order +function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) + return contraction_order(Algorithm(order), tensors) end # Convert the tensor network to a flat symbolic multiplication expression. -function contraction_order(alg::Algorithm"flat", tn) +function contraction_order(::Algorithm"flat", tensors) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. - syms = vec([symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)]) + syms = vec([symnameddims(i, Tuple(inds(tensors[i]))) for i in keys(tensors)]) return lazy(Mul(syms)) end -function contraction_order(alg::Algorithm"left_associative", tn) - return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn)) +function contraction_order(::Algorithm"left_associative", tensors) + return prod(i -> symnameddims(i, Tuple(inds(tensors[i]))), keys(tensors)) end -function contraction_order(alg::Algorithm, tn) - s = contraction_order(Algorithm"flat"(), tn) - return optimize_evaluation_order(s; alg) + +function contraction_order( + order_algorithm::Algorithm, + tensors, + ) + order = contraction_order(tensors; order = "flat") + return optimize_evaluation_order(order; alg = order_algorithm) end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index c9abfdd..b5ff72e 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -14,9 +14,9 @@ using Test: @test, @testset C = ITensor([5.0, 1.0], j) D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k) - ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") + ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") + ABCD_2 = contract_network([A, B, C, D]; alg = "eager") + ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +31,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; order_alg = "left_associative")[] - z2 = contract_network(tn; order_alg = "eager")[] - z3 = contract_network(tn; order_alg = "optimal")[] + z1 = contract_network(tn; alg = "left_associative")[] + z2 = contract_network(tn; alg = "eager")[] + z3 = contract_network(tn; alg = "optimal")[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) From 4eec9b65e4917c3feb11926ccf61207773833e2b Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:50:00 -0500 Subject: [PATCH 30/64] Upgrade DataGraphs and NamedGraphs dependencies --- src/abstracttensornetwork.jl | 20 +----- .../abstractbeliefpropagationcache.jl | 19 +++--- .../beliefpropagationcache.jl | 63 ++++++++++--------- src/tensornetwork.jl | 40 +++++++++--- test/Project.toml | 4 +- 5 files changed, 79 insertions(+), 67 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 08f86a1..671ba3a 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -16,7 +16,8 @@ using NamedGraphs.GraphsExtensions: incident_edges, rem_edges!, rename_vertices, - vertextype + vertextype, + similar_graph using SplitApplyCombine: flatten using NamedGraphs.SimilarType: similar_type @@ -25,7 +26,7 @@ abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} # Need to be careful about removing edges from tensor networks in case there is a bond Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented() -DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented() +DataGraphs.edge_data_type(::Type{<:AbstractTensorNetwork}) = not_implemented() # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) @@ -235,18 +236,3 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) - -function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V} - return tensornetwork_induced_subgraph(graph, subvertices) -end - -function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) - subgraph = similar_type(graph)(underlying_subgraph) - for v in vertices(subgraph) - if isassigned(graph, v) - set!(vertex_data(subgraph), v, graph[v]) - end - end - return subgraph, vlist -end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 0efc95d..b77fb4e 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,15 +1,12 @@ using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype +using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] -function message(bp_cache::AbstractGraph, edge::AbstractEdge) - ms = messages(bp_cache) - return get!(ms, edge, default_message(bp_cache, edge)) -end +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() function deletemessage!(bp_cache::AbstractDataGraph, edge) @@ -52,7 +49,7 @@ factors(bpc::AbstractGraph) = vertex_data(bpc) factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) -factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex] +factor(bpc::AbstractGraph, vertex) = bpc[vertex] setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() function setfactor!(bpc::AbstractDataGraph, vertex, factor) @@ -75,7 +72,7 @@ end message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) -message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) @@ -117,7 +114,13 @@ end adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) -abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end +abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end + +factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD + +message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) +message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 27a580d..10ab586 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -2,20 +2,19 @@ using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, - has_edge_data, get_vertex_data, get_edge_data, set_vertex_data!, set_edge_data!, - unset_vertex_data!, - unset_edge_data!, - vertex_data_eltype, - edge_data_eltype, + vertex_data_type, + edge_data_type, underlying_graph, - underlying_graph_type + underlying_graph_type, + is_vertex_assigned, + is_edge_assigned using Dictionaries: Dictionary, set!, delete! using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices +using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, @@ -25,22 +24,23 @@ using NamedGraphs.GraphsExtensions: default_root_vertex, using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -struct BeliefPropagationCache{V, G <: AbstractGraph{V}, N <: AbstractGraph{V}, ET, MT} <: - AbstractBeliefPropagationCache{V, MT} +struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. network::N - messages::Dictionary{ET, MT} + messages::Dictionary{E, ED} function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) V = vertextype(network) + VD = vertex_data_type(network) N = typeof(network) ET = keytype(messages) - MT = eltype(messages) + ED = eltype(messages) # Construct a directed graph version of the underlying graph of the tensor network. digraph = directed_graph(underlying_graph(network)) - bpc = new{V, typeof(digraph), N, ET, MT}(digraph, network, messages) + bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -53,8 +53,8 @@ network(bp_cache) = getfield(bp_cache, :network) DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) -DataGraphs.has_vertex_data(bpc::BeliefPropagationCache, vertex) = has_vertex_data(network(bpc), vertex) -DataGraphs.has_edge_data(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] @@ -62,20 +62,8 @@ DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc. DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) -DataGraphs.unset_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = unset_vertex_data!(network(bpc), val, vertex) -DataGraphs.unset_edge_data!(bpc::BeliefPropagationCache, val, edge) = unset!(bpc.messages, edge, val) - -function DataGraphs.vertex_data_eltype(T::Type{<:BeliefPropagationCache}) - return vertex_data_eltype(fieldtype(T, :network)) -end -function DataGraphs.edge_data_eltype(T::Type{<:BeliefPropagationCache}) - return eltype(fieldtype(T, :messages)) -end - -message_type(T::Type{<:BeliefPropagationCache}) = edge_data_eltype(T) - function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_eltype(typeof(network)) + MT = vertex_data_type(typeof(network)) return BeliefPropagationCache(MT, network) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) @@ -95,7 +83,7 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo forests = forest_cover(g) rv = edgetype(g)[] for forest in forests - trees = [forest[vs] for vs in connected_components(forest)] + trees = [forest[Vertices(vs)] for vs in connected_components(forest)] for tree in trees tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) @@ -106,16 +94,19 @@ end function bpcache_induced_subgraph(graph, subvertices) underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) - subgraph = BeliefPropagationCache(underlying_subgraph, typeof(edge_data(graph))()) + + edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + + subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) for e in edges(subgraph) if isassigned(graph, e) - set!(edge_data(subgraph), e, graph[e]) + subgraph[e] = graph[e] end end return subgraph, vlist end -function Graphs.induced_subgraph(graph::BeliefPropagationCache{V}, subvertices::Vector{V}) where {V} +function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) return bpcache_induced_subgraph(graph, subvertices) end @@ -138,3 +129,13 @@ function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) message = default_message(parenttype(T), network, edge) return convert(T, lazy(message)) end + +NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex +# When getting data according the quotient vertices, take a lazy contraction. +function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) + return mapreduce(lazy, *, data) +end +function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index b811e2b..0d30970 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -18,7 +18,7 @@ using NamedGraphs.PartitionedGraphs: QuotientVertexVertices, quotientvertices using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_eltype, vertex_data, edge_data, get_vertices_data +using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data using DataGraphs.DataGraphsPartitionedGraphsExt function _TensorNetwork end @@ -52,13 +52,12 @@ end DataGraphs.underlying_graph(tn::TensorNetwork) = tn.underlying_graph -DataGraphs.has_vertex_data(tn::TensorNetwork, v) = haskey(tn.tensors, v) -DataGraphs.has_edge_data(tn::TensorNetwork, e) = false +DataGraphs.is_vertex_assigned(tn::TensorNetwork, v) = haskey(tn.tensors, v) +DataGraphs.is_edge_assigned(tn::TensorNetwork, e) = false DataGraphs.get_vertex_data(tn::TensorNetwork, v) = tn.tensors[v] DataGraphs.set_vertex_data!(tn::TensorNetwork, val, v) = set!(tn.tensors, v, val) -DataGraphs.unset_vertex_data!(tn::TensorNetwork, val, v) = unset!(tn.tensors, v, val) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) @@ -135,11 +134,30 @@ function Graphs.rem_edge!(tn::TensorNetwork, e) return true end -function GraphsExtensions.similar(type::Type{<:TensorNetwork}) +function GraphsExtensions.similar_graph(type::Type{<:TensorNetwork}) DT = fieldtype(type, :tensors) empty_dict = DT() return TensorNetwork(similar_graph(underlying_graph_type(type)), empty_dict) end +function GraphsExtensions.similar_graph(tn::TensorNetwork, underlying_graph::AbstractGraph) + DT = fieldtype(typeof(tn), :tensors) + empty_dict = DT() + return _TensorNetwork(underlying_graph, empty_dict) +end + +function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subvertices) + return tensornetwork_induced_subgraph(graph, subvertices) +end + +function tensornetwork_induced_subgraph(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + + subgraph = TensorNetwork(underlying_subgraph) do vertex + return graph[vertex] + end + + return subgraph, vlist +end ## PartitionedGraphs function PartitionedGraphs.quotient_graph(tn::TensorNetwork) @@ -154,7 +172,7 @@ end # DataGraphsPartitionedGraphsExt interface. function PartitionedGraphs.quotient_graph_type(type::Type{<:TensorNetwork}) UG = quotient_graph_type(underlying_graph_type(type)) - VD = Vector{vertex_data_eltype(type)} + VD = Vector{vertex_data_type(type)} V = vertextype(UG) return TensorNetwork{V, VD, UG, Dictionary{V, VD}} end @@ -172,14 +190,18 @@ function PartitionedGraphs.departition( return TensorNetwork(departition(underlying_graph(tn)), vertex_data(tn)) end +NamedGraphs.to_graph_index(::TensorNetwork, vertex::QuotientVertex) = vertex # When getting data according the quotient vertices, take a lazy contraction. -function DataGraphs.get_vertices_data(tn::TensorNetwork, vertex::QuotientVertexVertices) - data = collect(map(v -> tn[v], NamedGraphs.parent_graph_indices(vertex))) +function DataGraphs.get_index_data(tn::TensorNetwork, vertex::QuotientVertex) + data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end +function DataGraphs.is_graph_index_assigned(tn::TensorNetwork, vertex::QuotientVertex) + return isassigned(tn, Vertices(vertices(tn, vertex))) +end function PartitionedGraphs.quotientview(tn::TensorNetwork) qview = QuotientView(underlying_graph(tn)) - tensors = vertex_data(QuotientView(tn)) + tensors = map(qv -> vertex_data(tn)[Indices(qv)], Indices(quotientvertices(tn))) return TensorNetwork(qview, tensors) end diff --git a/test/Project.toml b/test/Project.toml index 564db3f..975c2c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,8 +29,8 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.14" -NamedGraphs = "0.6.8, 0.7, 0.8" +NamedDimsArrays = "0.13" +NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" Suppressor = "0.2.8" From 202724ca021139bf7fa5d5cd561406dd497cacd4 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 10 Feb 2026 11:57:32 -0500 Subject: [PATCH 31/64] [AlgorithmsInterfaceExtensions] Allowing mapping over a generic iterable when constructing nested algorithms --- .../AlgorithmsInterfaceExtensions.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index a8c814e..3c887b7 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -152,8 +152,8 @@ end abstract type NestedAlgorithm <: Algorithm end -function nested_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultNestedAlgorithm(f, nalgorithms; kwargs...) +function nested_algorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(f, iterable; kwargs...) end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) @@ -211,6 +211,9 @@ function DefaultNestedAlgorithm(f::Function, nalgorithms::Int; kwargs...) return DefaultNestedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) end +function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) + return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. From 69542e32ba7d5ad1a4b616a40822dffcd1de4c9c Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Wed, 11 Feb 2026 11:44:18 -0500 Subject: [PATCH 32/64] Upgrade serial BP to use own `<:Algorithm` structs. --- .../beliefpropagationproblem.jl | 136 +++++++++++------- 1 file changed, 87 insertions(+), 49 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0d997ee..75023b3 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,8 +1,9 @@ using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices +using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data import AlgorithmsInterface as AI @@ -41,55 +42,35 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. if state.iteration > 0 st.delta = state.iterate.maxdiff + @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -struct BeliefPropagationProblem{Network} <: AIE.Problem - network::Network -end +# struct BeliefPropagationProblem{Network} <: AIE.Problem +# network::Network +# end + +struct BeliefPropagationProblem <: AIE.Problem end -@kwdef mutable struct BeliefPropagationState{ - Iterate <: BeliefPropagationCache, - Diffs, - } <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 end -function AI.initialize_state( - problem::BeliefPropagationProblem, - algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... - ) - - diffs = iterate.diffs - maxdiff = iterate.maxdiff - - return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) -end - -# This gets called at the start of every sweep. -function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, - state::AIE.State, - ) - state.iterate.maxdiff = 0.0 - return state +@kwdef struct BeliefPropagation{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, - state::AIE.State, - substate::BeliefPropagationState - ) - - state.iterate = substate - - return state +function BeliefPropagation(f::Function, niterations::Int; kwargs...) + return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) end abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end @@ -101,7 +82,7 @@ end function SimpleMessageUpdate( edge; - normalize = false, + normalize = true, contraction_alg = "eager", compute_diff = false, kwargs... @@ -117,6 +98,53 @@ function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) end end +struct BeliefPropagationSweep{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::AI.StopAfterIteration + function BeliefPropagationSweep(; algorithms) + stopping_criterion = AI.StopAfterIteration(length(algorithms)) + return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) + end +end + +BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) + +function AI.initialize_state( + problem::BeliefPropagationProblem, + update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ) + + diffs = iterate.diffs + maxdiff = iterate.maxdiff + + return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) +end + +# This gets called at the start of every sweep. +function AI.initialize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagationSweep, + iteration_state::AIE.State, + ) + iteration_state.iterate.maxdiff = 0.0 + return iteration_state +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + sweep_algorithm::BeliefPropagationSweep, + sweep_state::AIE.DefaultState, + noniterative_substate::BeliefPropagationState, + ) + + sweep_state.iterate = noniterative_substate + + return sweep_state +end + struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem messages::Messages factors::Factors @@ -124,7 +152,7 @@ end function AI.solve!( problem::BeliefPropagationProblem, - algorithm::AbstractMessageUpdate, + algorithm::SimpleMessageUpdate, state::BeliefPropagationState; logging_context_prefix = default_logging_context_prefix(problem, algorithm), ) @@ -177,7 +205,7 @@ end function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, - state::AIE.NonIterativeAlgorithmState; + state::AIE.DefaultNonIterativeAlgorithmState; logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), kwargs... ) @@ -209,24 +237,29 @@ end beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) - problem = BeliefPropagationProblem(network(cache)) + # problem = BeliefPropagationProblem(network(cache)) + problem = BeliefPropagationProblem() algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. + base_state = BeliefPropagationState(; iterate = cache) - state = AI.solve(problem, algorithm; iterate = base_state) + state = AI.initialize_state(problem, algorithm; iterate = base_state) + + state = AI.solve!(problem, algorithm, state) return state.iterate.iterate end + function select_algorithm( ::typeof(beliefpropagation), - cache; + cache::AbstractBeliefPropagationCache; edges = forest_cover_edge_sequence(network(cache)), maxiter = is_tree(network(cache)) ? 1 : nothing, - tol = 0.0, + tol = -Inf, kwargs... ) @@ -237,7 +270,7 @@ function select_algorithm( stopping_criterion = AI.StopAfterIteration(maxiter) compute_diff = false - if tol > 0.0 + if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) compute_diff = true end @@ -245,9 +278,14 @@ function select_algorithm( extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) - return AIE.nested_algorithm(maxiter; stopping_criterion) do repnum - return AIE.nested_algorithm(length(edges)) do edgenum - return SimpleMessageUpdate(edges[edgenum]; edge_kwargs[repnum]...) - end + return BeliefPropagation(maxiter; stopping_criterion) do repnum + return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) + end +end + +# A single sweep across the given edges. +function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) + return BeliefPropagationSweep(edges) do edge + return SimpleMessageUpdate(edge; kwargs...) end end From 992506900fd225d106a57e03346fd62e6f74bc80 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:19:04 -0500 Subject: [PATCH 33/64] Simplify BP cache to only store factors --- src/abstracttensornetwork.jl | 26 ++-- .../beliefpropagationcache.jl | 131 +++++++++--------- .../beliefpropagationproblem.jl | 81 ++++++----- 3 files changed, 115 insertions(+), 123 deletions(-) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 671ba3a..c4b6fcb 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,25 +1,17 @@ -using Adapt: Adapt, adapt, adapt_structure +using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph, - underlying_graph_type, vertex_data, set_vertex_data! +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, + underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!, - bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices -using LinearAlgebra: LinearAlgebra, factorize +using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, + dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree +using NamedGraphs: NamedGraph, NamedGraphs, not_implemented using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: - ⊔, - directed_graph, - incident_edges, - rem_edges!, - rename_vertices, - vertextype, - similar_graph -using SplitApplyCombine: flatten -using NamedGraphs.SimilarType: similar_type +using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, + similar_graph, vertextype abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 10ab586..2c253e6 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,46 +1,29 @@ -using DataGraphs: - DataGraphs, - AbstractDataGraph, - DataGraph, - get_vertex_data, - get_edge_data, - set_vertex_data!, - set_edge_data!, - vertex_data_type, - edge_data_type, - underlying_graph, - underlying_graph_type, - is_vertex_assigned, - is_edge_assigned -using Dictionaries: Dictionary, set!, delete! -using Graphs: AbstractGraph, is_tree, connected_components, is_directed -using NamedGraphs: NamedDiGraph, convert_vertextype, parent_graph_indices, Vertices -using NamedGraphs.GraphsExtensions: default_root_vertex, - forest_cover, - post_order_dfs_edges, - vertextype, - is_path_graph, - undirected_graph -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, QuotientEdges, quotient_graph, quotientedges +using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, + vertex_data_type +using Dictionaries: Dictionary, delete!, set!, getindices +using Graphs: AbstractGraph, connected_components, is_tree, is_directed using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGraph{V}, E} <: - AbstractBeliefPropagationCache{V, VD, ED} +using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices + +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. - network::N + factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(network::AbstractGraph, messages::Dictionary) + function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + # Ensure the graph is directed, if not make it directed. + digraph = is_directed(graph) ? graph : directed_graph(graph) - V = vertextype(network) - VD = vertex_data_type(network) - N = typeof(network) - ET = keytype(messages) - ED = eltype(messages) + V = keytype(factors) + VD = eltype(factors) - # Construct a directed graph version of the underlying graph of the tensor network. - digraph = directed_graph(underlying_graph(network)) + E = keytype(messages) + ED = eltype(messages) - bpc = new{V, VD, ED, typeof(digraph), N, ET}(digraph, network, messages) + bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) for edge in edges(bpc) get!(() -> default_message(bpc, edge), messages, edge) @@ -49,30 +32,39 @@ struct BeliefPropagationCache{V, VD, ED, G <: AbstractGraph{V}, N <: AbstractGra end end -network(bp_cache) = getfield(bp_cache, :network) - -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = getfield(bpc, :underlying_graph) +DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = is_vertex_assigned(network(bpc), vertex) +DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) -DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = get_vertex_data(network(bpc), vertex) +DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set_vertex_data!(network(bpc), val, vertex) +DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +# These two methods assume `network` behaves llike a tensor network +# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. function BeliefPropagationCache(network::AbstractGraph) - MT = vertex_data_type(typeof(network)) - return BeliefPropagationCache(MT, network) + graph = underlying_graph(network) + return BeliefPropagationCache(graph, copy(vertex_data(network))) end function BeliefPropagationCache(MT::Type, network::AbstractGraph) - dict = Dictionary{edgetype(network), MT}() - return BeliefPropagationCache(network, dict) + graph = underlying_graph(network) + return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) +end + +function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) + MT = vertex_data_type(typeof(graph)) + return BeliefPropagationCache(MT, graph, factors) +end +function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) + messages = Dictionary{edgetype(graph), MT}() + return BeliefPropagationCache(graph, factors, messages) end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) + return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) end # TODO: This needs to go in GraphsExtensions @@ -92,41 +84,50 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo return rv end -function bpcache_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(network(graph), subvertices) +function induced_subgraph_bpcache(graph, subvertices) + underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) - edge_data = Dictionary{edgetype(underlying_subgraph), edge_data_type(typeof(graph))}() + assigned = v -> isassigned(graph, v) + + assigned_subvertices = Iterators.filter(assigned, subvertices) + assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) + + factors = getindices(vertex_data(graph), Indices(assigned_subvertices)) + messages = getindices(edge_data(graph), Indices(assigned_subedges)) + + subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages) - subgraph = BeliefPropagationCache(underlying_subgraph, edge_data) - for e in edges(subgraph) - if isassigned(graph, e) - subgraph[e] = graph[e] - end - end return subgraph, vlist end function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) - return bpcache_induced_subgraph(graph, subvertices) + return induced_subgraph_bpcache(graph, subvertices) end ## PartitionedGraphs +# Take a QuotientView of the underlying graph. function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - inds = Indices(parent_graph_indices(QuotientEdges(underlying_graph(bpc)))) - data = map(e -> bpc[QuotientEdge(e)], inds) - return BeliefPropagationCache(QuotientView(network(bpc)), data) + + graph = underlying_graph(bpc) + + quotient_view = QuotientView(graph) + + factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view))) + messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view))) + + return BeliefPropagationCache(quotient_view, factors, messages) end function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), network(bpc), edge) + return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)]) end -function default_message(T::Type, network, edge) - array = ones(Tuple(linkinds(network, edge))) +function default_message(T::Type, src, dst) + array = ones(Tuple(inds(src) ∩ inds(dst))) return convert(T, array) end -function default_message(T::Type{<:LazyNamedDimsArray}, network, edge) - message = default_message(parenttype(T), network, edge) +function default_message(T::Type{<:LazyNamedDimsArray}, src, dst) + message = default_message(parenttype(T), src, dst) return convert(T, lazy(message)) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 75023b3..89c28df 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,10 +1,9 @@ -using Graphs: SimpleGraph, vertices, edges, has_edge, AbstractEdge -using NamedGraphs: AbstractNamedGraph, position_graph -using NamedGraphs.GraphsExtensions: add_edges!, partition_vertices, subgraph, boundary_edges -using NamedGraphs.OrderedDictionaries: OrderedDictionary, OrderedIndices +using Graphs: AbstractEdge, edges, has_edge, vertices +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.PartitionedGraphs: quotientvertices using DataGraphs: edge_data +using LinearAlgebra: norm, normalize import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE @@ -42,17 +41,14 @@ function AI.is_finished!( # maxdiff = 0.0 initially, so skip this the first time. if state.iteration > 0 st.delta = state.iterate.maxdiff - @info "$(state.iteration): $(st.delta)" end return st.delta < c.tol end -# struct BeliefPropagationProblem{Network} <: AIE.Problem -# network::Network -# end - -struct BeliefPropagationProblem <: AIE.Problem end +struct BeliefPropagationProblem{Network} <: AIE.Problem + network::Network +end @kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState iterate::Iterate @@ -113,8 +109,7 @@ end BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) function AI.initialize_state( - problem::BeliefPropagationProblem, - update_algorithm::AIE.NonIterativeAlgorithm; iterate, kwargs... + ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... ) diffs = iterate.diffs @@ -135,7 +130,7 @@ end function AIE.set_substate!( ::BeliefPropagationProblem, - sweep_algorithm::BeliefPropagationSweep, + ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, noniterative_substate::BeliefPropagationState, ) @@ -145,16 +140,16 @@ function AIE.set_substate!( return sweep_state end -struct MessageUpdateProblem{Messages, Factors} <: AIE.Problem +struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem + factor::Factor messages::Messages - factors::Factors end function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), ) logger = AI.algorithm_logger() @@ -168,8 +163,8 @@ function AI.solve!( new_message = updated_message(algorithm, cache) - if algorithm.compute_diff - diff = message_diff(new_message, cache[edge]) + if !isnothing(algorithm.message_diff_function) + diff = algorithm.message_diff_function(new_message, cache[edge]) if diff > state.maxdiff state.maxdiff = diff @@ -187,7 +182,7 @@ function AI.solve!( return state end -message_diff(m1, m2) = LinearAlgebra.norm(m1 - m2) +default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) function updated_message(algorithm, cache) edge = algorithm.edge @@ -195,7 +190,7 @@ function updated_message(algorithm, cache) vertex = src(edge) messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) - update_problem = MessageUpdateProblem(messages, factors(cache, vertex)) + update_problem = MessageUpdateProblem(cache[vertex], messages) message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) @@ -206,13 +201,21 @@ function AI.solve!( problem::MessageUpdateProblem, algorithm::SimpleMessageUpdate, state::AIE.DefaultNonIterativeAlgorithmState; - logging_context_prefix = AI.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... ) - # TODO: logging... + logger = AI.algorithm_logger() + + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) + ) + + state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factors, problem.messages) + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + ) if algorithm.normalize # TODO: use `sum` not `norm` @@ -222,28 +225,26 @@ function AI.solve!( end end + AI.emit_message( + logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + ) + return state end -contract_messages(alg, factors, messages) = not_implemented() -function contract_messages( - alg, - factors::Vector{<:AbstractArray}, - messages::Vector{<:AbstractArray}, - ) +function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) + factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end -beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network); kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) +beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - # problem = BeliefPropagationProblem(network(cache)) - problem = BeliefPropagationProblem() + problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) # The nested algorithms will wrap and manipulate this object. - base_state = BeliefPropagationState(; iterate = cache) state = AI.initialize_state(problem, algorithm; iterate = base_state) @@ -253,13 +254,13 @@ function beliefpropagation(cache::AbstractBeliefPropagationCache; kwargs...) return state.iterate.iterate end - function select_algorithm( ::typeof(beliefpropagation), cache::AbstractBeliefPropagationCache; - edges = forest_cover_edge_sequence(network(cache)), - maxiter = is_tree(network(cache)) ? 1 : nothing, + edges = forest_cover_edge_sequence(cache), + maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, + message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, kwargs... ) @@ -268,14 +269,12 @@ function select_algorithm( end stopping_criterion = AI.StopAfterIteration(maxiter) - compute_diff = false if tol > -Inf stopping_criterion = stopping_criterion | StopWhenConverged(tol) - compute_diff = true end - extended_kwargs = extend_columns((; compute_diff, kwargs...), maxiter) + extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, len = maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum @@ -284,7 +283,7 @@ function select_algorithm( end # A single sweep across the given edges. -function beliefpropagation_sweep(cache::BeliefPropagationCache; edges, kwargs...) +function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) return BeliefPropagationSweep(edges) do edge return SimpleMessageUpdate(edge; kwargs...) end From 292f2fa10be8626746f87148c95ea0fb0ba17ae8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:28:23 -0500 Subject: [PATCH 34/64] Upgrade to DataGraphs v0.3.1 and NamedGraphs v0.10 --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index efd1d3c..c7133ff 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" -DataGraphs = "0.2.7" +DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" FunctionImplementations = "0.4" @@ -47,7 +47,7 @@ Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.14.2" -NamedGraphs = "0.6.9, 0.7, 0.8" +NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TensorOperations = "5.3.1" From 9d937aa366d7afb54ab3e918a7039606de148112 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:38:37 -0500 Subject: [PATCH 35/64] Fix compat --- Project.toml | 4 ++-- test/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c7133ff..1da8abe 100644 --- a/Project.toml +++ b/Project.toml @@ -42,11 +42,11 @@ Combinatorics = "1" DataGraphs = "0.3.1" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" -FunctionImplementations = "0.4" +FunctionImplementations = "0.4.1" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.14.2" +NamedDimsArrays = "0.14.3" NamedGraphs = "0.10" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/test/Project.toml b/test/Project.toml index 975c2c1..cf048b7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.5" ITensorNetworksNext = "0.3" -NamedDimsArrays = "0.13" +NamedDimsArrays = "0.14" NamedGraphs = "0.10" QuadGK = "2.11.2" SafeTestsets = "0.1" From 5432fe28bb172ff61bb8a191b5de4604da06ef53 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 18:08:12 -0500 Subject: [PATCH 36/64] Fix broken merge Fix broken merge --- .../beliefpropagationproblem.jl | 4 +- src/contract_network.jl | 54 +++++++++---------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 89c28df..c127655 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -79,7 +79,7 @@ end function SimpleMessageUpdate( edge; normalize = true, - contraction_alg = "eager", + contraction_alg = "exact", compute_diff = false, kwargs... ) @@ -275,7 +275,7 @@ function select_algorithm( end extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, len = maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) return BeliefPropagation(maxiter; stopping_criterion) do repnum return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) diff --git a/src/contract_network.jl b/src/contract_network.jl index 4fda3a7..a8c3fc7 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,11 +1,27 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using NamedDimsArrays: inds -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, Mul, lazy, optimize_evaluation_order, +using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, substitute, symnameddims -function contract_network(tn; alg = default_kwargs(contract_network, tn).alg) - return contract_network(alg, tn) +# This is related to `MatrixAlgebraKit.select_algorithm`. +# TODO: Define this in BackendSelection.jl. +backend_value(::Algorithm{alg}) where {alg} = alg +using BackendSelection: parameters +function merge_parameters(alg::Algorithm; kwargs...) + return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) +end +to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) +to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) + +# `contract_network` +function contract_network(alg::Algorithm, tn) + return throw(ArgumentError("`contract_network` algorithm `$(alg)` not implemented.")) +end +function default_kwargs(::typeof(contract_network), tn) + return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) +end +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) + return contract_network(to_algorithm(alg; kwargs...), tn) end # `contract_network(::Algorithm"exact", ...)` @@ -34,24 +50,12 @@ end # `contraction_order` function contraction_order end -default_kwargs(::typeof(contraction_order), tensors) = (; order = "eager") - -function contraction_expression(tensors; order = default_kwargs(contraction_order, tensors).order) - order = contraction_order(order, tensors) - - # Contraction order may or may not have indices attached, canonicalize the format - # by attaching indices. - subs = Dict(symnameddims(i) => symnameddims(i, tensors[i]) for i in keys(tensors)) - - return substitute(order, subs) -end - -contraction_order(order, tensors) = order -function contraction_order(tensors; order = default_kwargs(contraction_order, tensors).order) - return contraction_order(Algorithm(order), tensors) +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) + return contraction_order(to_algorithm(alg; kwargs...), tn) end # Convert the tensor network to a flat symbolic multiplication expression. -function contraction_order(::Algorithm"flat", tensors) +function contraction_order(alg::Algorithm"flat", tn) # Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`. syms = vec([symnameddims(i, Tuple(axes(tn[i]))) for i in keys(tn)]) return lazy(Mul(syms)) @@ -59,11 +63,7 @@ end function contraction_order(alg::Algorithm"left_associative", tn) return prod(i -> symnameddims(i, Tuple(axes(tn[i]))), keys(tn)) end - -function contraction_order( - order_algorithm::Algorithm, - tensors, - ) - order = contraction_order(tensors; order = "flat") - return optimize_evaluation_order(order; alg = order_algorithm) +function contraction_order(alg::Algorithm, tn) + s = contraction_order(Algorithm"flat"(), tn) + return optimize_evaluation_order(s; alg) end From c916c84c19502294b77aeca61165b778ddbd66c8 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 19 Feb 2026 17:44:59 -0500 Subject: [PATCH 37/64] Bug fix; upgrade tests --- .../beliefpropagationproblem.jl | 2 +- test/Project.toml | 1 + test/test_contract_network.jl | 16 +++++++++------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index c127655..0312843 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -232,7 +232,7 @@ function AI.solve!( return state end -function contract_messages(alg, factor::AbstractArray, messages::Vector{<:AbstractArray}) +function contract_messages(alg, factor::AbstractArray, messages) factors = typeof(factor)[factor] return contract_network(vcat(factors, messages); alg) end diff --git a/test/Project.toml b/test/Project.toml index cf048b7..8b1072a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AlgorithmsInterface = "d1e3940c-cd12-4505-8585-b0a4b322527d" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index fc863f6..35b2275 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -5,8 +5,11 @@ using ITensorBase: Index using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset +using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin + orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) + @testset "Contract Vectors of ITensors" begin i, j, k = Index(2), Index(2), Index(5) A = [1.0 1.0; 0.5 1.0][i, j] @@ -14,10 +17,9 @@ using Test: @test, @testset C = [5.0, 1.0][j] D = [-2.0, 3.0, 4.0, 5.0, 1.0][k] - ABCD_1 = contract_network([A, B, C, D]; alg = "left_associative") - ABCD_2 = contract_network([A, B, C, D]; alg = "eager") - ABCD_3 = contract_network([A, B, C, D]; alg = "optimal") - + ABCD_1 = contract_network([A, B, C, D]; alg = orderalg("left_associative")) + ABCD_2 = contract_network([A, B, C, D]; alg = orderalg("eager")) + ABCD_3 = contract_network([A, B, C, D]; alg = orderalg("optimal")) @test ABCD_1 == ABCD_2 == ABCD_3 end @@ -31,9 +33,9 @@ using Test: @test, @testset return randn(Tuple(is)) end - z1 = contract_network(tn; alg = "left_associative")[] - z2 = contract_network(tn; alg = "eager")[] - z3 = contract_network(tn; alg = "optimal")[] + z1 = contract_network(tn; alg = orderalg("left_associative"))[] + z2 = contract_network(tn; alg = orderalg("eager"))[] + z3 = contract_network(tn; alg = orderalg("optimal"))[] @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) From 4a511a159d298ef466108b7af250b754c6d0dc35 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:41:03 -0500 Subject: [PATCH 38/64] Add 2D TN test --- test/Project.toml | 1 + test/test_beliefpropagation.jl | 64 +++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 8b1072a..50a58c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8c7829b..8a817b2 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,22 +1,48 @@ -using Dictionaries: Dictionary -using ITensorBase: Index +using Dictionaries: Dictionary, set! +using ITensorBase: Index, ITensor, prime, noprime using ITensorNetworksNext: BeliefPropagationCache, ITensorNetworksNext, TensorNetwork, - adapt_messages, - default_message, - default_messages, - edge_scalars, - factors, - messages, - partitionfunction, - setmessages! -using Graphs: edges, vertices + partitionfunction +using DiagonalArrays: δ +using Graphs: src, dst, edges, vertices, AbstractGraph using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree -using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype using Test: @test, @testset +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: name, inds +function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) + links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) + + # symmetric sqrt of Boltzmann matrix W = exp(β σσ') + sqrt_Ws = Dictionary() + for e in edges(g) + W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + + F = LinearAlgebra.svd(W) + U, S, V = F.U, F.S, F.Vt + @assert U * LinearAlgebra.diagm(S) * V ≈ W + id = [1.0 0.0; 0.0 1.0] + set!(sqrt_Ws, e, id) + set!(sqrt_Ws, reverse(e), U * LinearAlgebra.diagm(S) * V) + end + ts = Dictionary{vertextype(g), ITensor}() + for v in vertices(g) + es = incident_edges(g, v; dir = :in) + #t = ITensor(1.0, physical_inds[v]...) * delta([links[e] for e in es]) + t = δ(Float64, Tuple([links[e] for e in es])) + for e in es + t_prime = ITensor(sqrt_Ws[e], (name(links[e]), name(prime(links[e])))) * t + newinds = noprime.(inds(t_prime)) + t = ITensor(parent(t_prime), name.(newinds)) + end + set!(ts, v, t) + end + return TensorNetwork(g, ts) +end @testset "BeliefPropagation" begin #Chain of tensors @@ -49,5 +75,17 @@ using Test: @test, @testset bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 1) z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] - @test z_bp ≈ z_exact atol = 1.0e-12 + @test z_bp ≈ z_exact atol = 1.0e-10 + + #Square lattice Ising model + dims = (3, 3) + g = named_grid(dims) + tn = ising_tensornetwork(g, 0.05, h = 0.5) + bpc = ITensorNetworksNext.BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.beliefpropagation(bpc; maxiter = 50, tol = 1.0e-10) + + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test z_bp ≈ z_exact rtol = 1.0e-4 + end From 5b97af3a6b5a219c09b6d7db9e40022ab398bb51 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 24 Feb 2026 09:47:03 -0500 Subject: [PATCH 39/64] Formatting --- docs/make.jl | 9 +-- docs/make_index.jl | 4 +- docs/make_readme.jl | 4 +- .../ITensorNetworksNextTensorOperationsExt.jl | 4 +- .../AlgorithmsInterfaceExtensions.jl | 41 ++++-------- src/LazyNamedDimsArrays/symbolicarray.jl | 8 ++- src/TensorNetworkGenerators/delta_network.jl | 2 +- src/TensorNetworkGenerators/ising_network.jl | 2 +- src/abstracttensornetwork.jl | 16 ++--- .../abstractbeliefpropagationcache.jl | 13 ++-- .../beliefpropagationcache.jl | 58 +++++++++++----- .../beliefpropagationproblem.jl | 66 +++++++++++-------- src/contract_network.jl | 4 +- src/sweeping/eigenproblem.jl | 2 +- src/tensornetwork.jl | 47 ++++++------- test/runtests.jl | 15 +++-- test/test_algorithmsinterfaceextensions.jl | 14 ++-- test/test_aqua.jl | 2 +- test/test_basics.jl | 2 +- test/test_beliefpropagation.jl | 25 ++++--- test/test_contract_network.jl | 6 +- test/test_dmrg.jl | 4 +- test/test_lazynameddimsarrays.jl | 8 +-- test/test_tensornetworkgenerators.jl | 2 +- 24 files changed, 195 insertions(+), 163 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 1b29518..c4f46f3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: ITensorNetworksNext using Documenter: Documenter, DocMeta, deploydocs, makedocs +using ITensorNetworksNext: ITensorNetworksNext DocMeta.setdocmeta!( ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive = true @@ -14,11 +14,12 @@ makedocs(; format = Documenter.HTML(; canonical = "https://itensor.github.io/ITensorNetworksNext.jl", edit_link = "main", - assets = ["assets/favicon.ico", "assets/extras.css"], + assets = ["assets/favicon.ico", "assets/extras.css"] ), - pages = ["Home" => "index.md", "Reference" => "reference.md"], + pages = ["Home" => "index.md", "Reference" => "reference.md"] ) deploydocs(; - repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", push_preview = true + repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", + push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 038bc87..af08861 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 @@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); flavor = Literate.DocumenterFlavor(), name = "index", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 088dc58..52d0dbb 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -1,5 +1,5 @@ -using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext +using Literate: Literate function ccq_logo(content) include_ccq_logo = """ @@ -17,5 +17,5 @@ Literate.markdown( joinpath(pkgdir(ITensorNetworksNext)); flavor = Literate.CommonMarkFlavor(), name = "README", - postprocess = ccq_logo, + postprocess = ccq_logo ) diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl index 4766ee6..972b11e 100644 --- a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -1,9 +1,9 @@ module ITensorNetworksNextTensorOperationsExt using BackendSelection: @Algorithm_str, Algorithm -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, ismul, symnameddims, - substitute using ITensorNetworksNext.LazyNamedDimsArrays.TermInterface: arguments +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArrays, ismul, substitute, symnameddims using NamedDimsArrays: inds using TensorOperations: TensorOperations, optimaltree diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 3c887b7..69a4a97 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -1,8 +1,6 @@ module AlgorithmsInterfaceExtensions -import AlgorithmsInterface as AI - -#========================== Patches for AlgorithmsInterface.jl ============================# +import AlgorithmsInterface as AI #========================== Patches for AlgorithmsInterface.jl ============================# abstract type Problem <: AI.Problem end abstract type Algorithm <: AI.Algorithm end @@ -28,9 +26,7 @@ function AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) return DefaultState(; stopping_criterion_state, kwargs...) -end - -#============================ DefaultState ================================================# +end #============================ DefaultState ================================================# @kwdef mutable struct DefaultState{ Iterate, StoppingCriterionState <: AI.StoppingCriterionState, @@ -38,16 +34,12 @@ end iterate::Iterate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ increment! ==================================================# +end #============================ increment! ==================================================# # Custom version of `increment!` that also takes the problem and algorithm as arguments. function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) -end - -#============================ solve! ======================================================# +end #============================ solve! ======================================================# # Custom version of `solve!` that allows specifying the logger and also overloads # `increment!` on the problem and algorithm. @@ -58,13 +50,13 @@ default_logging_context_prefix(x) = Symbol(basetypenameof(x), :_) function default_logging_context_prefix(problem::Problem, algorithm::Algorithm) return Symbol( default_logging_context_prefix(problem), - default_logging_context_prefix(algorithm), + default_logging_context_prefix(algorithm) ) end function AI.solve!( problem::Problem, algorithm::Algorithm, state::State; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) logger = AI.algorithm_logger() @@ -97,13 +89,11 @@ end function AI.solve( problem::Problem, algorithm::Algorithm; logging_context_prefix = default_logging_context_prefix(problem, algorithm), - kwargs..., + kwargs... ) state = AI.initialize_state(problem, algorithm; kwargs...) return AI.solve!(problem, algorithm, state; logging_context_prefix, kwargs...) -end - -#============================ AlgorithmIterator ===========================================# +end #============================ AlgorithmIterator ===========================================# abstract type AlgorithmIterator end @@ -136,9 +126,7 @@ struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator problem::Problem algorithm::Algorithm state::State -end - -#============================ with_algorithmlogger ========================================# +end #============================ with_algorithmlogger ========================================# # Allow passing functions, not just CallbackActions. @inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) @@ -146,9 +134,7 @@ end end @inline function with_algorithmlogger(f, args::Pair{Symbol}...) return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) -end - -#============================ NestedAlgorithm =============================================# +end #============================ NestedAlgorithm =============================================# abstract type NestedAlgorithm <: Algorithm end @@ -213,8 +199,7 @@ end function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end -#============================ FlattenedAlgorithm ==========================================# +end #============================ FlattenedAlgorithm ==========================================# # Flatten a nested algorithm. abstract type FlattenedAlgorithm <: Algorithm end @@ -284,9 +269,7 @@ end parent_iteration::Int = 1 child_iteration::Int = 0 stopping_criterion_state::StoppingCriterionState -end - -#============================ NonIterativeAlgorithm =======================================# +end #============================ NonIterativeAlgorithm =======================================# # Algorithm that only performs a single step. abstract type NonIterativeAlgorithm <: Algorithm end diff --git a/src/LazyNamedDimsArrays/symbolicarray.jl b/src/LazyNamedDimsArrays/symbolicarray.jl index a0922fd..e3ff4d4 100644 --- a/src/LazyNamedDimsArrays/symbolicarray.jl +++ b/src/LazyNamedDimsArrays/symbolicarray.jl @@ -1,8 +1,12 @@ # TODO: Allow dynamic/unknown number of dimensions by supporting vector axes. -struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: + AbstractArray{T, N} name::Name axes::Axes - function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + function SymbolicArray{T}( + name, + ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} + ) where {T} N = length(ax) return new{T, N, typeof(name), typeof(ax)}(name, ax) end diff --git a/src/TensorNetworkGenerators/delta_network.jl b/src/TensorNetworkGenerators/delta_network.jl index 8b28def..e6a453c 100644 --- a/src/TensorNetworkGenerators/delta_network.jl +++ b/src/TensorNetworkGenerators/delta_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: TensorNetwork using DiagonalArrays: δ using Graphs: AbstractGraph -using ..ITensorNetworksNext: TensorNetwork using NamedGraphs.GraphsExtensions: incident_edges """ diff --git a/src/TensorNetworkGenerators/ising_network.jl b/src/TensorNetworkGenerators/ising_network.jl index 1f2fa31..e37551c 100644 --- a/src/TensorNetworkGenerators/ising_network.jl +++ b/src/TensorNetworkGenerators/ising_network.jl @@ -1,6 +1,6 @@ +using ..ITensorNetworksNext: @preserve_graph using DiagonalArrays: DiagonalArray using Graphs: degree, dst, edges, src -using ..ITensorNetworksNext: @preserve_graph using LinearAlgebra: Diagonal, eigen using NamedDimsArrays: apply, denamed, name, operator, randname using NamedGraphs.GraphsExtensions: vertextype diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index c4b6fcb..7fca799 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,17 +1,17 @@ using Adapt: Adapt, adapt using BackendSelection: @Algorithm_str, Algorithm -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, set_vertex_data!, +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data using Dictionaries: Dictionary -using Graphs: AbstractEdge, AbstractGraph, Graphs, add_edge!, add_vertex!, - dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices +using Graphs: Graphs, AbstractEdge, AbstractGraph, add_edge!, add_vertex!, dst, edges, + edgetype, ne, neighbors, nv, rem_edge!, src, vertices using LinearAlgebra: LinearAlgebra using MacroTools: @capture using NamedDimsArrays: dimnames, inds -using NamedGraphs: NamedGraph, NamedGraphs, not_implemented +using NamedGraphs.GraphsExtensions: + directed_graph, incident_edges, rem_edges!, similar_graph, vertextype using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger -using NamedGraphs.GraphsExtensions: directed_graph, incident_edges, rem_edges!, - similar_graph, vertextype +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end @@ -125,7 +125,7 @@ is_assignment_expr(expr) = false macro preserve_graph(expr) if !is_setindex!_expr(expr) error( - "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)" ) end @capture(expr, array_[indices__] = value_) @@ -207,7 +207,7 @@ Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented() function Base.setindex!( tn::AbstractTensorNetwork, value, - edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, + edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger} ) return not_implemented() end diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index b77fb4e..33f185b 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,7 +1,7 @@ -using Graphs: AbstractGraph, AbstractEdge -using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_type +using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data +using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] @@ -63,7 +63,6 @@ function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) end function region_scalar(bp_cache::AbstractGraph, vertex) - messages = incoming_messages(bp_cache, vertex) state = factors(bp_cache, vertex) @@ -78,7 +77,10 @@ function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) return map(v -> region_scalar(bp_cache, v), vertices) end -function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache)))) +function edge_scalars( + bp_cache::AbstractGraph, + edges = edges(undirected_graph(underlying_graph(bp_cache))) + ) return map(e -> region_scalar(bp_cache, e), edges) end @@ -123,7 +125,6 @@ message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED function free_energy(bp_cache::AbstractBeliefPropagationCache) - numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) if any(t -> real(t) < 0, numerator_terms) diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 2c253e6..5d1a31c 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,19 +1,23 @@ -using DataGraphs: AbstractDataGraph, DataGraphs, edge_data, edge_data_type, - set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, - vertex_data_type -using Dictionaries: Dictionary, delete!, set!, getindices -using Graphs: AbstractGraph, connected_components, is_tree, is_directed +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, + set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type +using Dictionaries: Dictionary, delete!, getindices, set! +using Graphs: AbstractGraph, connected_components, is_directed, is_tree using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype +using NamedGraphs.GraphsExtensions: + default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph - using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices -struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: AbstractBeliefPropagationCache{V, VD, ED} +struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: + AbstractBeliefPropagationCache{V, VD, ED} underlying_graph::G # we only use this for the edges. factors::Dictionary{V, VD} messages::Dictionary{E, ED} - function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary, messages::Dictionary) + function BeliefPropagationCache( + graph::AbstractGraph, + factors::Dictionary, + messages::Dictionary + ) # Ensure the graph is directed, if not make it directed. digraph = is_directed(graph) ? graph : directed_graph(graph) @@ -34,14 +38,22 @@ end DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph -DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) = haskey(bpc.factors, vertex) +function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) + return haskey(bpc.factors, vertex) +end DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] -DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) = bpc.messages[edge] +function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) + return bpc.messages[edge] +end -DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) = set!(bpc.factors, vertex, val) -DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) = set!(bpc.messages, edge, val) +function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) + return set!(bpc.factors, vertex, val) +end +function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) + return set!(bpc.messages, edge, val) +end # These two methods assume `network` behaves llike a tensor network # (could be e.g. a QuotientView) otherwise how would one know what the factors should be. @@ -64,7 +76,11 @@ function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Diction end function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache(copy(bp_cache.underlying_graph), copy(bp_cache.factors), copy(bp_cache.messages)) + return BeliefPropagationCache( + copy(bp_cache.underlying_graph), + copy(bp_cache.factors), + copy(bp_cache.messages) + ) end # TODO: This needs to go in GraphsExtensions @@ -85,7 +101,8 @@ function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_roo end function induced_subgraph_bpcache(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) assigned = v -> isassigned(graph, v) @@ -100,7 +117,10 @@ function induced_subgraph_bpcache(graph, subvertices) return subgraph, vlist end -function NamedGraphs.induced_subgraph_from_vertices(graph::BeliefPropagationCache, subvertices) +function NamedGraphs.induced_subgraph_from_vertices( + graph::BeliefPropagationCache, + subvertices + ) return induced_subgraph_bpcache(graph, subvertices) end @@ -108,7 +128,6 @@ end # Take a QuotientView of the underlying graph. function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - graph = underlying_graph(bpc) quotient_view = QuotientView(graph) @@ -137,6 +156,9 @@ function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientV data = collect(map(v -> tn[v], vertices(tn, vertex))) return mapreduce(lazy, *, data) end -function DataGraphs.is_graph_index_assigned(tn::BeliefPropagationCache, vertex::QuotientVertex) +function DataGraphs.is_graph_index_assigned( + tn::BeliefPropagationCache, + vertex::QuotientVertex + ) return isassigned(tn, Vertices(vertices(tn, vertex))) end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 0312843..1a62792 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,12 +1,11 @@ +import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI +using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices -using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph +using LinearAlgebra: norm, normalize using NamedDimsArrays: AbstractNamedDimsArray +using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -using DataGraphs: edge_data -using LinearAlgebra: norm, normalize - -import AlgorithmsInterface as AI -import .AlgorithmsInterfaceExtensions as AIE @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 = 0.0 @@ -24,7 +23,7 @@ function AI.initialize_state!( ::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) st.delta = Inf return st @@ -35,7 +34,7 @@ function AI.is_finished!( ::AIE.Algorithm, state::AIE.State, c::StopWhenConverged, - st::StopWhenConvergedState, + st::StopWhenConvergedState ) # maxdiff = 0.0 initially, so skip this the first time. @@ -50,7 +49,8 @@ struct BeliefPropagationProblem{Network} <: AIE.Problem network::Network end -@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: AIE.NonIterativeAlgorithmState +@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: + AIE.NonIterativeAlgorithmState iterate::Iterate diffs::Diffs = similar(edge_data(iterate), Float64) maxdiff::Float64 = 0.0 @@ -83,7 +83,10 @@ function SimpleMessageUpdate( compute_diff = false, kwargs... ) - return SimpleMessageUpdate(edge, (; normalize, contraction_alg, compute_diff, kwargs...)) + return SimpleMessageUpdate( + edge, + (; normalize, contraction_alg, compute_diff, kwargs...) + ) end function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) @@ -106,12 +109,13 @@ struct BeliefPropagationSweep{ end end -BeliefPropagationSweep(f::Function, edges) = BeliefPropagationSweep(; algorithms = f.(edges)) +function BeliefPropagationSweep(f::Function, edges) + return BeliefPropagationSweep(; algorithms = f.(edges)) +end function AI.initialize_state( ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... ) - diffs = iterate.diffs maxdiff = iterate.maxdiff @@ -122,7 +126,7 @@ end function AI.initialize_state!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - iteration_state::AIE.State, + iteration_state::AIE.State ) iteration_state.iterate.maxdiff = 0.0 return iteration_state @@ -132,9 +136,8 @@ function AIE.set_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, sweep_state::AIE.DefaultState, - noniterative_substate::BeliefPropagationState, + noniterative_substate::BeliefPropagationState ) - sweep_state.iterate = noniterative_substate return sweep_state @@ -149,9 +152,8 @@ function AI.solve!( problem::BeliefPropagationProblem, algorithm::SimpleMessageUpdate, state::BeliefPropagationState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), + logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) ) - logger = AI.algorithm_logger() cache = state.iterate @@ -204,17 +206,20 @@ function AI.solve!( logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), kwargs... ) - logger = AI.algorithm_logger() AI.emit_message( logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) ) - state.iterate = contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) + state.iterate = + contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreNormalization) + logger, problem, algorithm, state, Symbol( + logging_context_prefix, + :PreNormalization + ) ) if algorithm.normalize @@ -226,7 +231,8 @@ function AI.solve!( end AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostNormalization) + logger, problem, algorithm, state, + Symbol(logging_context_prefix, :PostNormalization) ) return state @@ -237,9 +243,14 @@ function contract_messages(alg, factor::AbstractArray, messages) return contract_network(vcat(factors, messages); alg) end -beliefpropagation(network; kwargs...) = beliefpropagation(BeliefPropagationCache(network), network; kwargs...) -function beliefpropagation(cache::AbstractBeliefPropagationCache, network = nothing; kwargs...) - +function beliefpropagation(network; kwargs...) + return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) +end +function beliefpropagation( + cache::AbstractBeliefPropagationCache, + network = nothing; + kwargs... + ) problem = BeliefPropagationProblem(network) algorithm = select_algorithm(beliefpropagation, cache; kwargs...) @@ -260,10 +271,13 @@ function select_algorithm( edges = forest_cover_edge_sequence(cache), maxiter = is_tree(cache) ? 1 : nothing, tol = -Inf, - message_diff_function = tol > -Inf ? (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) : nothing, + message_diff_function = if tol > -Inf + (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) + else + nothing + end, kwargs... ) - if isnothing(maxiter) throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) end diff --git a/src/contract_network.jl b/src/contract_network.jl index a8c3fc7..9db4c32 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,7 +1,7 @@ using BackendSelection: @Algorithm_str, Algorithm using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: Mul, lazy, optimize_evaluation_order, - substitute, symnameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + Mul, lazy, optimize_evaluation_order, substitute, symnameddims # This is related to `MatrixAlgebraKit.select_algorithm`. # TODO: Define this in BackendSelection.jl. diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl index 36978b2..8fefbd0 100644 --- a/src/sweeping/eigenproblem.jl +++ b/src/sweeping/eigenproblem.jl @@ -1,5 +1,5 @@ -import AlgorithmsInterface as AI import .AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI function dmrg(operator, algorithm, state) problem = EigenProblem(operator) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 0d30970..a371373 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,25 +1,19 @@ +using .LazyNamedDimsArrays: Mul, lazy using Combinatorics: combinations -using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using DataGraphs.DataGraphsPartitionedGraphsExt +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph, edge_data, get_vertices_data, + vertex_data, vertex_data_type using Dictionaries: AbstractDictionary, Indices, dictionary, set!, unset! -using Graphs: AbstractSimpleGraph, rem_vertex!, rem_edge! +using Graphs: AbstractSimpleGraph, rem_edge!, rem_vertex! using NamedDimsArrays: AbstractNamedDimsArray, dimnames -using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype, Vertices, parent_graph_indices -using NamedGraphs.GraphsExtensions: GraphsExtensions, arranged_edges, arrange_edge, vertextype -using NamedGraphs.PartitionedGraphs: - AbstractPartitionedGraph, - PartitionedGraphs, - departition, - partitioned_vertices, - partitionedgraph, - quotient_graph, - quotient_graph_type, - QuotientVertex, - QuotientVertices, - QuotientVertexVertices, +using NamedGraphs.GraphsExtensions: + GraphsExtensions, arrange_edge, arranged_edges, vertextype +using NamedGraphs.PartitionedGraphs: AbstractPartitionedGraph, PartitionedGraphs, + QuotientVertex, QuotientVertexVertices, QuotientVertices, departition, + partitioned_vertices, partitionedgraph, quotient_graph, quotient_graph_type, quotientvertices -using .LazyNamedDimsArrays: lazy, Mul -using DataGraphs: vertex_data_type, vertex_data, edge_data, get_vertices_data -using DataGraphs.DataGraphsPartitionedGraphsExt +using NamedGraphs: + NamedGraphs, NamedEdge, NamedGraph, Vertices, parent_graph_indices, vertextype function _TensorNetwork end @@ -44,7 +38,9 @@ function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) return tn end -function TensorNetwork{V, VD, UG, Tensors}(graph::UG) where {V, VD, UG <: AbstractGraph{V}, Tensors} +function TensorNetwork{V, VD, UG, Tensors}( + graph::UG + ) where {V, VD, UG <: AbstractGraph{V}, Tensors} return _TensorNetwork(graph, Tensors()) end @@ -121,14 +117,20 @@ end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) -Graphs.connected_components(tn::TensorNetwork) = Graphs.connected_components(underlying_graph(tn)) +function Graphs.connected_components(tn::TensorNetwork) + return Graphs.connected_components(underlying_graph(tn)) +end function Graphs.rem_edge!(tn::TensorNetwork, e) if !has_edge(underlying_graph(tn), e) return false end if !isempty(linkinds(tn, e)) - throw(ArgumentError("cannot remove edge $e due to tensor indices existing on this edge.")) + throw( + ArgumentError( + "cannot remove edge $e due to tensor indices existing on this edge." + ) + ) end rem_edge!(underlying_graph(tn), e) return true @@ -150,7 +152,8 @@ function NamedGraphs.induced_subgraph_from_vertices(graph::TensorNetwork, subver end function tensornetwork_induced_subgraph(graph, subvertices) - underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices) + underlying_subgraph, vlist = + Graphs.induced_subgraph(underlying_graph(graph), subvertices) subgraph = TensorNetwork(underlying_subgraph) do vertex return graph[vertex] diff --git a/test/runtests.jl b/test/runtests.jl index 0008050..16689fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,14 +10,19 @@ const GROUP = uppercase( get(ENV, "GROUP", "ALL") else only(match(pat, ARGS[arg_id]).captures) - end, + end ) -"match files of the form `test_*.jl`, but exclude `*setup*.jl`" +""" +match files of the form `test_*.jl`, but exclude `*setup*.jl` +""" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && + !contains(fn, "setup") end -"match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" +""" +match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl` +""" function isexamplefile(fn) return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @@ -57,7 +62,7 @@ end :macrocall, GlobalRef(Suppressor, Symbol("@suppress")), LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), + :(include($filename)) ) ) end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 8e0665c..44e6a09 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -164,7 +164,7 @@ end # Test with CallbackAction (wrapped functions) state = AIE.with_algorithmlogger( :TestProblem_TestAlgorithm_PreStep => callback1, - :TestProblem_TestAlgorithm_PostStep => callback2, + :TestProblem_TestAlgorithm_PostStep => callback2 ) do return AI.solve(problem, algorithm; iterate = [0.0]) end @@ -227,7 +227,7 @@ end ) state = AIE.DefaultState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test progression through iterations @@ -253,7 +253,7 @@ end state = AIE.DefaultState(; iterate = [5.0, 10.0], iteration = 1, - stopping_criterion_state, + stopping_criterion_state ) subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) @@ -264,7 +264,7 @@ end # Test set_substate! new_substate = AIE.DefaultState(; iterate = [100.0, 200.0], - substate.stopping_criterion_state, + substate.stopping_criterion_state ) AIE.set_substate!(problem, nested_alg, state, new_substate) @test state.iterate ≈ [100.0, 200.0] @@ -321,7 +321,7 @@ end flattened_alg = AIE.DefaultFlattenedAlgorithm(; algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4), + stopping_criterion = AI.StopAfterIteration(4) ) problem = TestProblem([1.0]) @@ -330,7 +330,7 @@ end ) state = AIE.DefaultFlattenedAlgorithmState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state, + stopping_criterion_state = stopping_criterion_state ) # Test initial state @@ -388,7 +388,7 @@ end # Using the helper function flattened_alg = AIE.flattened_algorithm(2) do i AIE.nested_algorithm(1) do j - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) + return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index a38563a..8eb4612 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,5 +1,5 @@ -using ITensorNetworksNext: ITensorNetworksNext using Aqua: Aqua +using ITensorNetworksNext: ITensorNetworksNext using Test: @testset @testset "Code quality (Aqua.jl)" begin diff --git a/test/test_basics.jl b/test/test_basics.jl index 0c9d803..9f80b25 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,7 @@ using Dictionaries: Indices using Graphs: dst, edges, has_edge, ne, nv, src, vertices -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, linkinds, siteinds using NamedDimsArrays: dimnames using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 8a817b2..d1cca76 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -1,26 +1,26 @@ +using DiagonalArrays: δ using Dictionaries: Dictionary, set! -using ITensorBase: Index, ITensor, prime, noprime +using Graphs: AbstractGraph, dst, edges, src, vertices +using ITensorBase: ITensor, Index, noprime, prime using ITensorNetworksNext: - BeliefPropagationCache, - ITensorNetworksNext, - TensorNetwork, - partitionfunction -using DiagonalArrays: δ -using Graphs: src, dst, edges, vertices, AbstractGraph -using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree + ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, partitionfunction +using LinearAlgebra: LinearAlgebra +using NamedDimsArrays: inds, name using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges, vertextype +using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid using Test: @test, @testset -using LinearAlgebra: LinearAlgebra -using NamedDimsArrays: name, inds function ising_tensornetwork(g::AbstractGraph, β::Real; h = 0.0) - links = Dictionary(edges(g), [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)]) + links = Dictionary( + edges(g), + [Index(2; tags = "edge" => "e$(src(e))_$(dst(e))") for e in edges(g)] + ) links = merge(links, Dictionary(reverse.(edges(g)), [links[e] for e in edges(g)])) # symmetric sqrt of Boltzmann matrix W = exp(β σσ') sqrt_Ws = Dictionary() for e in edges(g) - W = [ exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h)) ] + W = [exp(-(β + 2 * h)) exp(β); exp(β) exp(-(β - 2 * h))] F = LinearAlgebra.svd(W) U, S, V = F.U, F.S, F.Vt @@ -87,5 +87,4 @@ end z_bp = partitionfunction(bpc) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact rtol = 1.0e-4 - end diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index 35b2275..b453e76 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -1,11 +1,11 @@ +using BackendSelection: @Algorithm_str, Algorithm using Graphs: edges +using ITensorBase: Index +using ITensorNetworksNext: TensorNetwork, contract_network, linkinds, siteinds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid -using ITensorBase: Index -using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset -using BackendSelection: @Algorithm_str, Algorithm @testset "contract_network" begin orderalg = alg -> Algorithm"exact"(; order_alg = Algorithm(alg)) diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl index 01f04ac..dba2570 100644 --- a/test/test_dmrg.jl +++ b/test/test_dmrg.jl @@ -1,6 +1,6 @@ import AlgorithmsInterface as AI -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm using Test: @test, @testset @testset "select_algorithm(dmrg, ...)" begin @@ -21,7 +21,7 @@ using Test: @test, @testset return EigsolveRegion( regions[j]; maxdim = maxdims[i], - cutoff = cutoffs[i], + cutoff = cutoffs[i] ) end end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index d067c24..751b469 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,9 +1,9 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray, - Mul, SymbolicArray, ismul, lazy, substitute, symnameddims -using NamedDimsArrays: NamedDimsArray, @names, denamed, dimnames, inds, nameddims, - namedoneto +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, LazyNamedDimsArrays, Mul, + SymbolicArray, ismul, lazy, substitute, symnameddims +using NamedDimsArrays: + @names, NamedDimsArray, denamed, dimnames, inds, nameddims, namedoneto using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments, sorted_children using Test: @test, @test_throws, @testset diff --git a/test/test_tensornetworkgenerators.jl b/test/test_tensornetworkgenerators.jl index 2d092c3..f29a900 100644 --- a/test/test_tensornetworkgenerators.jl +++ b/test/test_tensornetworkgenerators.jl @@ -1,8 +1,8 @@ using DiagonalArrays: δ using Graphs: edges, ne, nv, vertices using ITensorBase: Index -using ITensorNetworksNext: contract_network using ITensorNetworksNext.TensorNetworkGenerators: delta_network, ising_network +using ITensorNetworksNext: contract_network using NamedDimsArrays: inds using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid From f2f9011a51aa37b02803046baddf7924e39b077b Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 15:47:35 -0400 Subject: [PATCH 40/64] Refactor `NestedAlgorithm` hooks: `initialize_subsolve` + `finalize_substate!` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename the `NestedAlgorithm` step hooks from `get_subproblem` / `set_substate!` to `initialize_subsolve` / `finalize_substate!`, and move the indexed-list dispatch (`algorithm.algorithms[state.iteration]`) out of the abstract type's default into `DefaultNestedAlgorithm`'s own `initialize_subsolve` method. The abstract `NestedAlgorithm` now has no default `initialize_subsolve` impl (throws `MethodError`), so subtypes must provide their own — the previous default silently assumed every subtype carried an `algorithms` vector, which is too narrow. `BeliefPropagation` and `BeliefPropagationSweep` get an explicit `AIE.initialize_subsolve` override mirroring the indexed-list shape; the existing `AIE.set_substate!` override on `BeliefPropagationSweep` is renamed to `AIE.finalize_substate!`. Co-Authored-By: Claude Opus 4.7 --- Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 37 +++++++++++-------- .../beliefpropagationproblem.jl | 17 ++++++++- test/test_algorithmsinterfaceextensions.jl | 12 +++--- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 301fb17..c8358d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.4.2" +version = "0.4.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index d9edb0d..9eb930f 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -103,32 +103,28 @@ end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) -function get_subproblem( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State +# Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it +# returns the `(subproblem, subalgorithm, substate)` tuple that the next +# inner `AI.solve!` call consumes. The default `finalize_substate!` copies +# the substate's iterate back into the parent state; subtypes can override +# when more is required. +function initialize_subsolve( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + return throw(MethodError(initialize_subsolve, (problem, algorithm, state))) end -function set_substate!( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State +function finalize_substate!( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State ) state.iterate = substate.iterate return state end function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) - # Get the subproblem, subalgorithm, and substate. - subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) - - # Solve the subproblem with the subalgorithm. + subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) AI.solve!(subproblem, subalgorithm, substate) - - # Update the state with the substate. - set_substate!(problem, algorithm, state, substate) - + finalize_substate!(problem, algorithm, state, substate) return state end @@ -150,6 +146,15 @@ function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) end +function initialize_subsolve( + problem::AI.Problem, algorithm::DefaultNestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + # ============================ FlattenedAlgorithm ========================================== # Flatten a nested algorithm. diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 004e449..e0c65b1 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -139,7 +139,22 @@ function BeliefPropagationSweep(f::Function, edges) return BeliefPropagationSweep(; algorithms = f.(edges)) end -function AIE.set_substate!( +# `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of +# child algorithms, mirroring `AIE.DefaultNestedAlgorithm`. Each step picks +# the child algorithm by the current iteration index and reuses the parent +# problem. +function AIE.initialize_subsolve( + problem::BeliefPropagationProblem, + algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, + state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function AIE.finalize_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, state::AIE.DefaultState, diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 6f80527..60118da 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -213,8 +213,8 @@ end @test state.iteration == 2 end - @testset "get_subproblem and set_substate!" begin - # Test get_subproblem + @testset "initialize_subsolve and finalize_substate!" begin + # Test initialize_subsolve problem = TestProblem([1.0, 2.0]) nested_alg = AIE.nested_algorithm(2) do i return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) @@ -229,17 +229,19 @@ end stopping_criterion_state ) - subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + subproblem, subalgorithm, substate = AIE.initialize_subsolve( + problem, nested_alg, state + ) @test subproblem === problem @test subalgorithm === nested_alg.algorithms[1] @test substate.iterate ≈ [5.0, 10.0] - # Test set_substate! + # Test finalize_substate! new_substate = AIE.DefaultState(; iterate = [100.0, 200.0], substate.stopping_criterion_state ) - AIE.set_substate!(problem, nested_alg, state, new_substate) + AIE.finalize_substate!(problem, nested_alg, state, new_substate) @test state.iterate ≈ [100.0, 200.0] end From 314c29485578a4bb075d08910ad4add77cd3f923 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:11:22 -0400 Subject: [PATCH 41/64] Strip AIE down to minimal NestedAlgorithm + abstract scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete from AIE: - `DefaultNestedAlgorithm` (struct, constructor, and its `initialize_subsolve` indexed-list impl). - `nested_algorithm` factory function + `max_iterations`. - All `FlattenedAlgorithm` machinery (struct, state, helper). - `AlgorithmIterator` / `DefaultAlgorithmIterator` / `algorithm_iterator`. - `with_algorithmlogger`. - `NonIterativeAlgorithm` / `DefaultNonIterativeAlgorithmState`. Keep the minimum BP / nested-solve still needs: `Problem`, `Algorithm`, `State`, `DefaultState`, `AI.initialize_state` / `initialize_state!` / `increment!`, and the minimal `NestedAlgorithm` (abstract type + `initialize_subsolve` + `finalize_substate!` + `AI.step!`). Future features can be added back as concrete subtypes when actually needed. Delete the placeholder sweep / DMRG scaffolding (`src/sweeping/`, `test/test_dmrg.jl`, `test/test_sweeping.jl`): `EigsolveRegion` / `EigenProblem` / `dmrg` / `select_algorithm` were not actually wired up (the `solve!` body threw `not implemented yet`), and the tests covered only construction shape. Inline the kwarg-expansion helpers (`extend_columns`, `rows`, ...) that lived in `sweeping/utils.jl` into `beliefpropagation/beliefpropagationproblem.jl` — its only remaining consumer. Prune `test/test_algorithmsinterfaceextensions.jl` to the surface that still exists; add a `TestNestedAlgorithm` concrete subtype mirroring how `BeliefPropagation` shapes itself on top of the new minimal interface, and a test that the bare `initialize_subsolve` default throws. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 177 --------- src/ITensorNetworksNext.jl | 2 - .../beliefpropagationproblem.jl | 18 +- src/sweeping/eigenproblem.jl | 44 --- src/sweeping/utils.jl | 12 - test/test_algorithmsinterfaceextensions.jl | 365 +++--------------- test/test_dmrg.jl | 34 -- test/test_sweeping.jl | 65 ---- 8 files changed, 58 insertions(+), 659 deletions(-) delete mode 100644 src/sweeping/eigenproblem.jl delete mode 100644 src/sweeping/utils.jl delete mode 100644 test/test_dmrg.jl delete mode 100644 test/test_sweeping.jl diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 9eb930f..c91e24e 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -47,62 +47,10 @@ function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) end -# ============================ AlgorithmIterator =========================================== - -abstract type AlgorithmIterator end - -function algorithm_iterator( - problem::Problem, algorithm::Algorithm, state::State - ) - return DefaultAlgorithmIterator(problem, algorithm, state) -end - -function AI.is_finished!(iterator::AlgorithmIterator) - return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.is_finished(iterator::AlgorithmIterator) - return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.increment!(iterator::AlgorithmIterator) - return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.step!(iterator::AlgorithmIterator) - return AI.step!(iterator.problem, iterator.algorithm, iterator.state) -end -function Base.iterate(iterator::AlgorithmIterator, init = nothing) - AI.is_finished!(iterator) && return nothing - AI.increment!(iterator) - AI.step!(iterator) - return iterator.state, nothing -end - -struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator - problem::Problem - algorithm::Algorithm - state::State -end - -# ============================ with_algorithmlogger ======================================== - -# Allow passing functions, not just CallbackActions. -@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) - return AI.with_algorithmlogger(f, args...) -end -@inline function with_algorithmlogger(f, args::Pair{Symbol}...) - return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) -end - # ============================ NestedAlgorithm ============================================= abstract type NestedAlgorithm <: Algorithm end -nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) -function nested_algorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(f, iterable; kwargs...) -end - -max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) - # Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it # returns the `(subproblem, subalgorithm, substate)` tuple that the next # inner `AI.solve!` call consumes. The default `finalize_substate!` copies @@ -128,129 +76,4 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end -#= - DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) - -An algorithm that consists of running an algorithm at each iteration -from a list of stored algorithms. -=# -@kwdef struct DefaultNestedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end -function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) -end - -function initialize_subsolve( - problem::AI.Problem, algorithm::DefaultNestedAlgorithm, state::AI.State - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate -end - -# ============================ FlattenedAlgorithm ========================================== - -# Flatten a nested algorithm. -abstract type FlattenedAlgorithm <: Algorithm end -abstract type FlattenedAlgorithmState <: State end - -function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) -end - -function AI.initialize_state( - problem::Problem, algorithm::FlattenedAlgorithm; kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) -end -function AI.increment!( - problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState - ) - # Increment the total iteration count. - state.iteration += 1 - # TODO: Use `is_finished!` instead? - if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) - # We're on the last iteration of the child algorithm, so move to the next - # child algorithm. - state.parent_iteration += 1 - state.child_iteration = 1 - else - # Iterate the child algorithm. - state.child_iteration += 1 - end - return state -end -function AI.step!( - problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState - ) - algorithm_sweep = algorithm.algorithms[state.parent_iteration] - state_sweep = AI.initialize_state( - problem, algorithm_sweep; - state.iterate, iteration = state.child_iteration - ) - AI.step!(problem, algorithm_sweep, state_sweep) - state.iterate = state_sweep.iterate - return state -end - -@kwdef struct DefaultFlattenedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: FlattenedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = - AI.StopAfterIteration(sum(max_iterations, algorithms)) -end -function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) -end - -@kwdef mutable struct DefaultFlattenedAlgorithmState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: FlattenedAlgorithmState - iterate::Iterate - iteration::Int = 0 - parent_iteration::Int = 1 - child_iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -# ============================ NonIterativeAlgorithm ======================================= - -# Algorithm that only performs a single step. -abstract type NonIterativeAlgorithm <: Algorithm end -abstract type NonIterativeAlgorithmState <: State end - -function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) - return DefaultNonIterativeAlgorithmState(; kwargs...) -end - -function AI.initialize_state!( - problem::Problem, - algorithm::NonIterativeAlgorithm, - state::NonIterativeAlgorithmState - ) - return state -end - -function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State) - return throw(MethodError(AI.solve_loop!, (problem, algorithm, state))) -end - -@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: - NonIterativeAlgorithmState - iterate::Iterate -end - end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 0b2b898..394546a 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -12,8 +12,6 @@ include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("sweeping/utils.jl") -include("sweeping/eigenproblem.jl") include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagationproblem.jl") diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index e0c65b1..8f88ff4 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -8,6 +8,19 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices +# Utility functions for processing keyword arguments. +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] +end + @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 end @@ -140,9 +153,8 @@ function BeliefPropagationSweep(f::Function, edges) end # `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of -# child algorithms, mirroring `AIE.DefaultNestedAlgorithm`. Each step picks -# the child algorithm by the current iteration index and reuses the parent -# problem. +# child algorithms. Each step picks the child algorithm by the current +# iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl deleted file mode 100644 index 8fefbd0..0000000 --- a/src/sweeping/eigenproblem.jl +++ /dev/null @@ -1,44 +0,0 @@ -import .AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI - -function dmrg(operator, algorithm, state) - problem = EigenProblem(operator) - return AI.solve(problem, algorithm; iterate = state).iterate -end -function dmrg(operator, state; kwargs...) - problem = EigenProblem(operator) - algorithm = select_algorithm(dmrg, operator, state; kwargs...) - return AI.solve(problem, algorithm; iterate = state).iterate -end - -# TODO: Allow specifying the region algorithm type? -function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) - extended_kwargs = extend_columns((; kwargs...), nsweeps) - region_kwargs = rows(extended_kwargs) - return AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion(regions[j]; region_kwargs[i]...) - end - end -end -#= - EigenProblem(operator) - -Represents the problem we are trying to solve and minimal algorithm-independent -information, so for an eigenproblem it is the operator we want the eigenvector of. -=# -struct EigenProblem{Operator} <: AIE.Problem - operator::Operator -end - -struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) - -function AI.solve!( - problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... - ) - return error("EigsolveRegion step for EigenProblem not implemented yet.") -end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl deleted file mode 100644 index 39e09e4..0000000 --- a/src/sweeping/utils.jl +++ /dev/null @@ -1,12 +0,0 @@ -# Utility functions for processing keyword arguments. -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 60118da..0ec2d16 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -1,6 +1,6 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset +using Test: @test, @test_throws, @testset # Define test problems, algorithms, and states for testing struct TestProblem <: AIE.Problem @@ -11,10 +11,6 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) end -@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) -end - function AI.step!( problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState ) @@ -22,16 +18,29 @@ function AI.step!( return state end -function AI.step!( - problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState +# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms +# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes +# itself on top of the new minimal `AIE.NestedAlgorithm`. +@kwdef struct TestNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end + +function AIE.initialize_subsolve( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::AI.State ) - state.iterate .+= 2 # Different increment step - return state + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate end @testset "AlgorithmsInterfaceExtensions" begin @testset "DefaultState" begin - # Test DefaultState construction iterate = [1.0, 2.0, 3.0] stopping_criterion_state = AI.initialize_state( TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion @@ -41,13 +50,11 @@ end @test state.iteration == 0 @test state.stopping_criterion_state isa AI.StoppingCriterionState - # Test DefaultState with custom iteration state.iteration = 5 @test state.iteration == 5 end @testset "initialize_state!" begin - # Test initialize_state! with iterate kwarg problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() stopping_criterion_state = AI.initialize_state( @@ -63,17 +70,14 @@ end end @testset "initialize_state" begin - # Test initialize_state without exclamation problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() - state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) @test state isa AIE.DefaultState @test state.iteration == 0 end @testset "increment!" begin - # Test increment! with problem and algorithm problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() stopping_criterion_state = AI.initialize_state( @@ -81,352 +85,69 @@ end ) state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - # Increment and verify iteration counter increases AI.increment!(problem, algorithm, state) @test state.iteration == 1 - AI.increment!(problem, algorithm, state) @test state.iteration == 2 end @testset "solve! and solve" begin - # Test solve! with simple problem problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + state = AI.initialize_state(problem, algorithm; iterate = [10.0, 20.0]) - initial_iterate = [10.0, 20.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - - # Solve with custom initial iterate initial_iterate = [5.0, 10.0] final_iterate = AI.solve!( problem, algorithm, state; iterate = copy(initial_iterate) ) - @test state.iteration == 3 @test final_iterate == state.iterate # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] @test state.iterate ≈ [8.0, 13.0] - # Test solve without exclamation problem2 = TestProblem([1.0, 2.0]) algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate2 = [5.0, 10.0] - - final_iterate2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + final_iterate2 = AI.solve(problem2, algorithm2; iterate = [5.0, 10.0]) @test final_iterate2 ≈ [7.0, 12.0] end - @testset "DefaultAlgorithmIterator" begin - # Test algorithm iterator creation - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - @test iterator isa AIE.DefaultAlgorithmIterator - @test iterator.problem === problem - @test iterator.algorithm === algorithm - @test iterator.state === state - - # Test iteration interface - @test !AI.is_finished!(iterator) - - # Step through iterator - state_out, _ = iterate(iterator) - @test state_out.iteration == 1 - @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! - - state_out, _ = iterate(iterator) - @test state_out.iteration == 2 - - @test AI.is_finished!(iterator) - end - - @testset "DefaultNestedAlgorithm" begin - # Test creating nested algorithm with function - nested_alg = AIE.nested_algorithm(3) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - @test nested_alg isa AIE.DefaultNestedAlgorithm - @test length(nested_alg.algorithms) == 3 - @test AIE.max_iterations(nested_alg) == 3 - - # Test stepping through nested algorithm - problem = TestProblem([1.0, 2.0]) - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - initial_iterate = [0.0, 0.0] - AI.solve!( - problem, nested_alg, state; iterate = copy(initial_iterate) - ) - - @test state.iteration == 3 - # Each nested algorithm runs once with 2 steps, incrementing by 2 - # Total: 3 algorithms × 2 iterations × 2 increment = 12 - @test state.iterate ≈ [12.0, 12.0] - end - - @testset "NestedAlgorithm basic tests" begin - # Test basic nested algorithm functionality - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - problem = TestProblem([1.0, 2.0]) - - # Test state initialization - state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - - @test state_nested isa AIE.DefaultState - @test state_nested.iteration == 0 - @test AIE.max_iterations(nested_alg) == 2 - end - - @testset "increment! for nested algorithms" begin - # Test increment! logic for nested algorithm state + @testset "NestedAlgorithm defaults" begin + # The bare `initialize_subsolve` default throws a `MethodError`, + # forcing concrete subtypes to provide their own override. problem = TestProblem([1.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; - iterate = [0.0], - stopping_criterion_state = stopping_criterion_state - ) - - # Test progression through iterations - @test state.iteration == 0 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 1 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 2 - end - - @testset "initialize_subsolve and finalize_substate!" begin - # Test initialize_subsolve - problem = TestProblem([1.0, 2.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) + algorithm = TestAlgorithm() state = AIE.DefaultState(; - iterate = [5.0, 10.0], - iteration = 1, - stopping_criterion_state - ) - - subproblem, subalgorithm, substate = AIE.initialize_subsolve( - problem, nested_alg, state - ) - @test subproblem === problem - @test subalgorithm === nested_alg.algorithms[1] - @test substate.iterate ≈ [5.0, 10.0] - - # Test finalize_substate! - new_substate = AIE.DefaultState(; - iterate = [100.0, 200.0], - substate.stopping_criterion_state - ) - AIE.finalize_substate!(problem, nested_alg, state, new_substate) - @test state.iterate ≈ [100.0, 200.0] - end - - @testset "DefaultFlattenedAlgorithm" begin - # Create nested algorithms that support max_iterations - nested_algs = map(1:3) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each - ) - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 3 - - # Test state initialization - problem = TestProblem([1.0, 2.0]) - state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - @test state_flat isa AIE.DefaultFlattenedAlgorithmState - @test state_flat.iteration == 0 - @test state_flat.parent_iteration == 1 - @test state_flat.child_iteration == 0 - end - - @testset "DefaultFlattenedAlgorithmState increment!" begin - # Create nested algorithms for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) - ) - - problem = TestProblem([1.0]) - stopping_criterion_state = AI.initialize_state( - problem, flattened_alg, flattened_alg.stopping_criterion - ) - state = AIE.DefaultFlattenedAlgorithmState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) ) + @test_throws MethodError AIE.initialize_subsolve(problem, algorithm, state) - # Test initial state - @test state.iteration == 0 - @test state.parent_iteration == 1 - @test state.child_iteration == 0 - - # First increment - should increment child_iteration - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 1 - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - # Second increment - should increment child_iteration again - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 2 - @test state.parent_iteration == 2 # Should move to next parent - @test state.child_iteration == 1 - end - - @testset "FlattenedAlgorithm step!" begin - # Test individual step! calls for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) + # `finalize_substate!` copies the substate's iterate back into the + # parent state. + substate = AIE.DefaultState(; + iterate = [42.0], + stopping_criterion_state = state.stopping_criterion_state ) - - problem = TestProblem([1.0, 2.0]) - state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - # Manually step through to test step! functionality - AI.increment!(problem, flattened_alg, state) - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - AI.step!(problem, flattened_alg, state) - # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 - @test state.iterate ≈ [4.0, 4.0] - end - - @testset "flattened_algorithm helper" begin - # Test the flattened_algorithm helper function - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - # Using the helper function - flattened_alg = AIE.flattened_algorithm(2) do i - AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 2 + AIE.finalize_substate!(problem, algorithm, state, substate) + @test state.iterate == [42.0] end - @testset "AlgorithmIterator is_finished (without !)" begin - # Test is_finished without mutation + @testset "TestNestedAlgorithm" begin problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Before any iterations - @test !AI.is_finished(iterator) - - # Run the algorithm - AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) - - # After completion - @test AI.is_finished(iterator) - end - - @testset "AlgorithmIterator step!" begin - # Test step! method for iterator - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Step the iterator - AI.step!(iterator) - @test iterator.state.iterate ≈ [1.0, 1.0] - - AI.step!(iterator) - @test iterator.state.iterate ≈ [2.0, 2.0] - end - - @testset "NestedAlgorithm with different sub-algorithms" begin - # Test nested algorithm with varying sub-algorithms - nested_alg = AIE.DefaultNestedAlgorithm(; + nested_alg = TestNestedAlgorithm(; algorithms = [ TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), ] ) + @test nested_alg isa AIE.NestedAlgorithm - @test AIE.max_iterations(nested_alg) == 3 - @test length(nested_alg.algorithms) == 3 - - problem = TestProblem([1.0, 2.0]) state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) - - # First algorithm: 1 iteration × 1 increment = 1 - # Second algorithm: 2 iterations × 2 increment = 4 - # Third algorithm: 1 iteration × 1 increment = 1 - # Total: 1 + 4 + 1 = 6 - @test state.iterate ≈ [6.0, 6.0] - @test state.iteration == 3 - end - - @testset "Edge cases" begin - # Test with single nested algorithm - nested_alg = AIE.nested_algorithm(1) do i - return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - end - - problem = TestProblem([1.0]) - state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0]) - - @test state.iterate ≈ [1.0] - @test state.iteration == 1 + # Two child algorithms: 1 iter + 2 iter = 3 inner increments total. + @test state.iteration == 2 + @test state.iterate ≈ [3.0, 3.0] end end diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl deleted file mode 100644 index dba2570..0000000 --- a/test/test_dmrg.jl +++ /dev/null @@ -1,34 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm -using Test: @test, @testset - -@testset "select_algorithm(dmrg, ...)" begin - operator = "operator" - init = "init" - nsweeps = 3 - regions = ["region1", "region2"] - maxdim = [10, 20] - cutoff = 1.0e-7 - algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) - @test algorithm isa AIE.NestedAlgorithm - @test length(algorithm.algorithms) == nsweeps - - maxdims = [10, 20, 20] - cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] - algorithm′ = AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion( - regions[j]; - maxdim = maxdims[i], - cutoff = cutoffs[i] - ) - end - end - for i in 1:nsweeps - for j in 1:length(regions) - @test algorithm.algorithms[i].algorithms[j] == - algorithm′.algorithms[i].algorithms[j] - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl deleted file mode 100644 index 01881d9..0000000 --- a/test/test_sweeping.jl +++ /dev/null @@ -1,65 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset - -struct TestProblem <: AIE.Problem -end - -struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) - -function AI.solve_loop!(problem::TestProblem, algorithm::TestRegion, state::AIE.State) - new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) - state.iterate = [state.iterate; [new_iterate]] - return state -end - -@testset "Sweeping" begin - @testset "TestRegion" begin - algorithm = TestRegion("region"; foo = 1, bar = 2) - @test algorithm isa AIE.NonIterativeAlgorithm - @test algorithm isa AIE.Algorithm - @test algorithm isa AI.Algorithm - @test algorithm.region == "region" - @test algorithm.kwargs == (; foo = 1, bar = 2) - - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [(; region = "region", foo = 1, bar = 2)] - end - @testset "Sweep" begin - algorithm = AIE.nested_algorithm(3) do i - return TestRegion("region$i"; foo = i, bar = 2i) - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "region1", foo = 1, bar = 2), - (; region = "region2", foo = 2, bar = 4), - (; region = "region3", foo = 3, bar = 6), - ] - end - @testset "Sweeping" begin - algorithm = AIE.nested_algorithm(2) do i - AIE.nested_algorithm(3) do j - return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) - end - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), - (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), - (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), - (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), - (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), - (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), - ] - end -end From e35b1a20edc60424d68b9a962ca1e66108a59c8e Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:19:43 -0400 Subject: [PATCH 42/64] Drop AIE `Problem` / `Algorithm` / `State` / `DefaultState` scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `AlgorithmsInterfaceExtensions` is now just the `NestedAlgorithm` abstract type plus `initialize_subsolve` / `finalize_substate!` / `AI.step!`. The `Problem` / `Algorithm` / `State` abstract types and `DefaultState` are gone, along with the `AI.initialize_state` / `initialize_state!` / `increment!` overloads that hung off them. Belief propagation grows its own state machinery instead of leaning on AIE: a `BeliefPropagationState <: AI.State` (mutable, `iterate` / `iteration` / `stopping_criterion_state`), plus per-algorithm `AI.initialize_state` / `initialize_state!` / `increment!` overloads on `Union{BeliefPropagation, BeliefPropagationSweep}`. This mirrors the pattern `BPApplyGate` already uses in the apply-operator path — each algorithm owns its state, no shared default. `BeliefPropagationProblem`, `BeliefPropagation`, and `BeliefPropagationSweep` now subtype the bare `AI.Problem` / `AI.Algorithm` types; the `StopWhenConverged` dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State` since `StopWhenConverged` itself is the unique type doing the disambiguation. `test_algorithmsinterfaceextensions.jl` is rewritten to define its own `TestProblem` / `TestChildAlgorithm` / `TestChildState` / `TestNestedAlgorithm` directly on `AI.*`, so the test exercises the same "each algorithm owns its state" pattern. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 47 +---- .../beliefpropagationproblem.jl | 80 ++++++-- test/test_algorithmsinterfaceextensions.jl | 182 ++++++++---------- 3 files changed, 151 insertions(+), 158 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index c91e24e..ddf1083 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -2,54 +2,9 @@ module AlgorithmsInterfaceExtensions import AlgorithmsInterface as AI -# ========================== Patches for AlgorithmsInterface.jl ============================ - -abstract type Problem <: AI.Problem end -abstract type Algorithm <: AI.Algorithm end -abstract type State <: AI.State end - -function AI.initialize_state!( - problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... - ) - for (k, v) in pairs(kwargs) - setproperty!(state, k, v) - end - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state -end - -function AI.initialize_state( - problem::Problem, algorithm::Algorithm; iterate, kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return DefaultState(; iterate, stopping_criterion_state, kwargs...) -end - -# ============================ DefaultState ================================================ - -@kwdef mutable struct DefaultState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -# ============================ increment! ================================================== - -# Custom version of `increment!` that also takes the problem and algorithm as arguments. -function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) - return AI.increment!(state) -end - # ============================ NestedAlgorithm ============================================= -abstract type NestedAlgorithm <: Algorithm end +abstract type NestedAlgorithm <: AI.Algorithm end # Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it # returns the `(subproblem, subalgorithm, substate)` tuple that the next diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 8f88ff4..91402ff 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -31,13 +31,13 @@ end previous_iterate::Iterate end -function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged; iterate) +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) return StopWhenConvergedState(; previous_iterate = copy(iterate)) end function AI.initialize_state!( - ::AIE.Problem, - ::AIE.Algorithm, + ::AI.Problem, + ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) @@ -46,9 +46,9 @@ function AI.initialize_state!( end function AI.is_finished!( - problem::AIE.Problem, - algorithm::AIE.Algorithm, - state::AIE.State, + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, c::StopWhenConverged, st::StopWhenConvergedState ) @@ -73,19 +73,30 @@ function AI.is_finished!( end function AI.is_finished( - ::AIE.Problem, - ::AIE.Algorithm, - ::AIE.State, + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, c::StopWhenConverged, st::StopWhenConvergedState ) return st.delta < c.tol end -struct BeliefPropagationProblem{Factors} <: AIE.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors end +# Shared state type for the BP NestedAlgorithm subtypes +# (`BeliefPropagation` / `BeliefPropagationSweep`). Mirrors the small +# state shape `BPApplyGate` uses in the apply-operator path. +@kwdef mutable struct BeliefPropagationState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + function iterate_diff( cache1::MessageCache, cache2::MessageCache @@ -98,7 +109,7 @@ function iterate_diff( end @kwdef struct BeliefPropagation{ - ChildAlgorithm <: AIE.Algorithm, + ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm @@ -152,12 +163,53 @@ function BeliefPropagationSweep(f::Function, edges) return BeliefPropagationSweep(; algorithms = f.(edges)) end +# State construction / reset / increment for the BP NestedAlgorithm +# pair. Both `BeliefPropagation` and `BeliefPropagationSweep` carry a +# `stopping_criterion`, so the standard AlgorithmsInterface state chain +# resolves the same way. +const BeliefPropagationLikeAlgorithm = Union{BeliefPropagation, BeliefPropagationSweep} + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm; + iterate, kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationState(; iterate, stopping_criterion_state, kwargs...) +end + +function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm, + state::BeliefPropagationState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.increment!( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm, + state::BeliefPropagationState + ) + return AI.increment!(state) +end + # `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of # child algorithms. Each step picks the child algorithm by the current # iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, - algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, + algorithm::BeliefPropagationLikeAlgorithm, state::AI.State ) subproblem = problem @@ -169,7 +221,7 @@ end function AIE.finalize_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - state::AIE.DefaultState, + state::BeliefPropagationState, cache::MessageCache ) state.iterate = cache @@ -211,7 +263,7 @@ function beliefpropagation( if isnothing(maxiter) throw( ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when + "`maxiter` must be specified for non-tree graphs, even when `stopping_criterion` is provided." ) ) diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 0ec2d16..1345bfd 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -2,27 +2,60 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE using Test: @test, @test_throws, @testset -# Define test problems, algorithms, and states for testing -struct TestProblem <: AIE.Problem - data::Vector{Float64} +# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms +# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes +# itself on top of the minimal `AIE.NestedAlgorithm`. +struct TestProblem <: AI.Problem end + +@kwdef struct TestChildAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AI.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(2) +end + +@kwdef mutable struct TestChildState{SCState <: AI.StoppingCriterionState} <: AI.State + iterate::Vector{Float64} + iteration::Int = 0 + stopping_criterion_state::SCState +end + +function AI.initialize_state( + problem::TestProblem, algorithm::TestChildAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) +end + +function AI.initialize_state!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state end -@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +function AI.increment!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState + ) + return AI.increment!(state) end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState ) - state.iterate .+= 1 # Simple increment step + state.iterate .+= 1 return state end -# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms -# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes -# itself on top of the new minimal `AIE.NestedAlgorithm`. @kwdef struct TestNestedAlgorithm{ - ChildAlgorithm <: AIE.Algorithm, + ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm @@ -30,6 +63,37 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end +# Reuse the child-state shape for the parent algorithm too. +function AI.initialize_state( + problem::TestProblem, algorithm::TestNestedAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) +end + +function AI.initialize_state!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.increment!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState + ) + return AI.increment!(state) +end + function AIE.initialize_subsolve( problem::TestProblem, algorithm::TestNestedAlgorithm, state::AI.State ) @@ -40,113 +104,35 @@ function AIE.initialize_subsolve( end @testset "AlgorithmsInterfaceExtensions" begin - @testset "DefaultState" begin - iterate = [1.0, 2.0, 3.0] - stopping_criterion_state = AI.initialize_state( - TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion - ) - state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) - @test state.iterate == iterate - @test state.iteration == 0 - @test state.stopping_criterion_state isa AI.StoppingCriterionState - - state.iteration = 5 - @test state.iteration == 5 - end - - @testset "initialize_state!" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; - iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state - ) - AI.initialize_state!(problem, algorithm, state) - @test state.iterate == [0.0, 0.0] - @test state.iteration == 0 - @test state.stopping_criterion_state == stopping_criterion_state - end - - @testset "initialize_state" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) - @test state isa AIE.DefaultState - @test state.iteration == 0 - end - - @testset "increment!" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - AI.increment!(problem, algorithm, state) - @test state.iteration == 1 - AI.increment!(problem, algorithm, state) - @test state.iteration == 2 - end - - @testset "solve! and solve" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) - state = AI.initialize_state(problem, algorithm; iterate = [10.0, 20.0]) - - initial_iterate = [5.0, 10.0] - final_iterate = AI.solve!( - problem, algorithm, state; iterate = copy(initial_iterate) - ) - @test state.iteration == 3 - @test final_iterate == state.iterate - # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] - @test state.iterate ≈ [8.0, 13.0] - - problem2 = TestProblem([1.0, 2.0]) - algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - final_iterate2 = AI.solve(problem2, algorithm2; iterate = [5.0, 10.0]) - @test final_iterate2 ≈ [7.0, 12.0] - end - @testset "NestedAlgorithm defaults" begin # The bare `initialize_subsolve` default throws a `MethodError`, # forcing concrete subtypes to provide their own override. - problem = TestProblem([1.0]) - algorithm = TestAlgorithm() - state = AIE.DefaultState(; - iterate = [0.0], - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - ) + problem = TestProblem() + algorithm = TestChildAlgorithm() + state = AI.initialize_state(problem, algorithm; iterate = [0.0]) @test_throws MethodError AIE.initialize_subsolve(problem, algorithm, state) # `finalize_substate!` copies the substate's iterate back into the # parent state. - substate = AIE.DefaultState(; - iterate = [42.0], - stopping_criterion_state = state.stopping_criterion_state - ) + substate = AI.initialize_state(problem, algorithm; iterate = [42.0]) AIE.finalize_substate!(problem, algorithm, state, substate) @test state.iterate == [42.0] end @testset "TestNestedAlgorithm" begin - problem = TestProblem([1.0, 2.0]) + problem = TestProblem() nested_alg = TestNestedAlgorithm(; algorithms = [ - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), ] ) @test nested_alg isa AIE.NestedAlgorithm state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) - # Two child algorithms: 1 iter + 2 iter = 3 inner increments total. + # Two child algorithms: 1 inner step + 2 inner steps = 3 total + # `state.iterate .+= 1` calls. @test state.iteration == 2 @test state.iterate ≈ [3.0, 3.0] end From b1821ce82ba6ac3fa6dd0e2cf8c0fc7677fc9d52 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:41:41 -0400 Subject: [PATCH 43/64] Refactor BP into three problem/algorithm/state triples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Layer 1 — outer BP loop: `BeliefPropagationProblem` / `BeliefPropagationAlgorithm` / `BeliefPropagationState` (iterative, `<: AIE.NestedAlgorithm`) Layer 2 — one sweep over edges: `BeliefPropagationSweepProblem` / `BeliefPropagationSweepAlgorithm` / `BeliefPropagationSweepState` (iterative, `<: AIE.NestedAlgorithm`) Layer 3 — single-edge message update: `MessageUpdateProblem` / `SimpleMessageUpdateAlgorithm` / `MessageUpdateState` (non-iterative, overrides `AI.solve_loop!` — same shape `BPApplyGate` uses in the gate-apply code) Every layer subtypes `AI.Problem` / `AI.Algorithm` / `AI.State` directly. No `Union{...}` shortcuts — each layer's `initialize_state` / `initialize_state!` / `initialize_subsolve` is its own method. The `StopWhenConverged` dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State` since `StopWhenConverged` itself is the unique disambiguating type. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 253 ++++++++++-------- test/test_algorithmsinterfaceextensions.jl | 4 +- 2 files changed, 140 insertions(+), 117 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 91402ff..1202e2b 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -21,6 +21,8 @@ function rows(nt::NamedTuple, len::Int = rowlength(nt)) return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end +# === `StopWhenConverged` stopping criterion === + @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 end @@ -36,10 +38,7 @@ function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; end function AI.initialize_state!( - ::AI.Problem, - ::AI.Algorithm, - ::StopWhenConverged, - st::StopWhenConvergedState + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) st.delta = Inf return st @@ -59,7 +58,7 @@ function AI.is_finished!( st.previous_iterate = copy(iterate) - # maxdiff = 0.0 initially, so skip this the first time. + # delta = 0 initially, so skip this the first time. state.iteration == 0 && return false st.delta = delta @@ -82,33 +81,78 @@ function AI.is_finished( return st.delta < c.tol end -struct BeliefPropagationProblem{Factors} <: AI.Problem +function iterate_diff(cache1::MessageCache, cache2::MessageCache) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + end +end + +# === Layer 3: single-edge message update (non-iterative) === + +struct MessageUpdateProblem{Factors} <: AI.Problem factors::Factors end -# Shared state type for the BP NestedAlgorithm subtypes -# (`BeliefPropagation` / `BeliefPropagationSweep`). Mirrors the small -# state shape `BPApplyGate` uses in the apply-operator path. -@kwdef mutable struct BeliefPropagationState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State +@kwdef struct SimpleMessageUpdateAlgorithm{ + E <: AbstractEdge, ContractionAlg, + } <: AI.Algorithm + edge::E + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" +end + +@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState end -function iterate_diff( - cache1::MessageCache, - cache2::MessageCache +function AI.initialize_state( + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate ) - return maximum(edges(cache1)) do edge - m1 = cache1[edge] - m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + return MessageUpdateState(; iterate) +end + +# Non-iterative algorithm: no per-call state to reset. +function AI.initialize_state!( + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState + ) + return state +end + +# Non-iterative algorithm: bypass the step!/stopping-criterion loop. +function AI.solve_loop!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdateAlgorithm, + state::MessageUpdateState + ) + cache = state.iterate + edge = algorithm.edge + + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] + + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end end + + cache[edge] = new_message + + return state +end + +# === Layer 2: one sweep over edges (iterative) === + +struct BeliefPropagationSweepProblem{Factors} <: AI.Problem + factors::Factors end -@kwdef struct BeliefPropagation{ +@kwdef struct BeliefPropagationSweepAlgorithm{ ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, @@ -117,78 +161,101 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagation(f::Function, niterations::Int; kwargs...) - return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) +function BeliefPropagationSweepAlgorithm(f::Function, edges) + return BeliefPropagationSweepAlgorithm(; algorithms = f.(edges)) end -struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} - edge::E - kwargs::Kwargs +@kwdef mutable struct BeliefPropagationSweepState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState end -function SimpleMessageUpdate( - edge; - normalize = true, - contraction_alg = Algorithm"exact", - kwargs... +function AI.initialize_state( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm; + iterate, iteration::Int = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate ) - return SimpleMessageUpdate( - edge, - (; normalize, contraction_alg, kwargs...) + return BeliefPropagationSweepState(; + iterate, iteration, stopping_criterion_state ) end -function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) - if name in (:edge, :kwargs) - return getfield(alg, name) - else - return getproperty(getfield(alg, :kwargs), name) - end +function AI.initialize_state!( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state end -AI.initialize_state(::BeliefPropagationProblem, ::SimpleMessageUpdate; iterate) = iterate +function AIE.initialize_subsolve( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState + ) + subproblem = MessageUpdateProblem(problem.factors) + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end -struct BeliefPropagationSweep{ - ChildAlgorithm, Algorithms <: AbstractVector{ChildAlgorithm}, +# === Layer 1: BP outer loop (iterative) === + +struct BeliefPropagationProblem{Factors} <: AI.Problem + factors::Factors +end + +@kwdef struct BeliefPropagationAlgorithm{ + ChildAlgorithm <: AI.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm algorithms::Algorithms - stopping_criterion::AI.StopAfterIteration - function BeliefPropagationSweep(; algorithms) - stopping_criterion = AI.StopAfterIteration(length(algorithms)) - return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) - end + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagationSweep(f::Function, edges) - return BeliefPropagationSweep(; algorithms = f.(edges)) +function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) + return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) end -# State construction / reset / increment for the BP NestedAlgorithm -# pair. Both `BeliefPropagation` and `BeliefPropagationSweep` carry a -# `stopping_criterion`, so the standard AlgorithmsInterface state chain -# resolves the same way. -const BeliefPropagationLikeAlgorithm = Union{BeliefPropagation, BeliefPropagationSweep} +@kwdef mutable struct BeliefPropagationState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState +end function AI.initialize_state( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm; - iterate, kwargs... + algorithm::BeliefPropagationAlgorithm; + iterate, iteration::Int = 0 ) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iterate, stopping_criterion_state, kwargs...) + return BeliefPropagationState(; + iterate, iteration, stopping_criterion_state + ) end function AI.initialize_state!( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, + algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState; - iteration = 0, kwargs... + iteration::Int = 0 ) - for (k, v) in pairs(kwargs) - setproperty!(state, k, v) - end state.iteration = iteration AI.initialize_state!( problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state @@ -196,62 +263,18 @@ function AI.initialize_state!( return state end -function AI.increment!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, - state::BeliefPropagationState - ) - return AI.increment!(state) -end - -# `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of -# child algorithms. Each step picks the child algorithm by the current -# iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, - state::AI.State + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState ) - subproblem = problem + subproblem = BeliefPropagationSweepProblem(problem.factors) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end -function AIE.finalize_substate!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - state::BeliefPropagationState, - cache::MessageCache - ) - state.iterate = cache - - return state -end - -function AI.solve!( - problem::BeliefPropagationProblem, - algorithm::SimpleMessageUpdate, - cache::MessageCache - ) - edge = algorithm.edge - - messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] - - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) - - if algorithm.normalize - message_norm = sum(new_message) - if !iszero(message_norm) - new_message /= message_norm - end - end - - cache[edge] = new_message - - return cache -end +# === Top-level user entry point === function beliefpropagation( factors, messages; @@ -287,9 +310,9 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - algorithm = BeliefPropagation(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweep(edges) do edge - return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) + algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + return BeliefPropagationSweepAlgorithm(edges) do edge + return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) end end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 1345bfd..f580826 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -3,8 +3,8 @@ import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE using Test: @test, @test_throws, @testset # Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms -# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes -# itself on top of the minimal `AIE.NestedAlgorithm`. +# and picks them by iteration index. Mirrors how `BeliefPropagationAlgorithm` +# shapes itself on top of the minimal `AIE.NestedAlgorithm`. struct TestProblem <: AI.Problem end @kwdef struct TestChildAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AI.Algorithm From fdb3c04ab4f188c6494de432044e67eb75694085 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:46:48 -0400 Subject: [PATCH 44/64] Reorder BP source top-to-bottom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit API first, then the three layers from outer (BP) to inner (message update), then the supporting `StopWhenConverged` stopping criterion + `iterate_diff` helper, then the low-level kwarg utilities at the bottom. Reads as "what's the API → how it's composed → supporting pieces" the way the file is most likely to be skimmed. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 355 +++++++++--------- 1 file changed, 178 insertions(+), 177 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1202e2b..44b42ed 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -8,142 +8,115 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -# Utility functions for processing keyword arguments. -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end - -# === `StopWhenConverged` stopping criterion === - -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 -end - -@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState - delta::Float64 = Inf - at_iteration::Int = -1 - previous_iterate::Iterate -end - -function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) - return StopWhenConvergedState(; previous_iterate = copy(iterate)) -end - -function AI.initialize_state!( - ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState - ) - st.delta = Inf - return st -end +# === Top-level user entry point === -function AI.is_finished!( - problem::AI.Problem, - algorithm::AI.Algorithm, - state::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState +function beliefpropagation( + factors, messages; + edges = nothing, + maxiter = is_tree(factors) ? 1 : nothing, + stopping_criterion = nothing, + kwargs... ) - iterate = state.iterate - previous_iterate = st.previous_iterate + if isnothing(maxiter) + throw( + ArgumentError( + "`maxiter` must be specified for non-tree graphs, even when + `stopping_criterion` is provided." + ) + ) + end - delta = iterate_diff(iterate, previous_iterate) + cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) - st.previous_iterate = copy(iterate) + ## Algorithm construction: - # delta = 0 initially, so skip this the first time. - state.iteration == 0 && return false + edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - st.delta = delta + base_stopping_criterion = AI.StopAfterIteration(maxiter) - if AI.is_finished(problem, algorithm, state, c, st) - st.at_iteration = state.iteration - return true + if !isnothing(stopping_criterion) + base_stopping_criterion |= stopping_criterion end - return false -end + stopping_criterion = base_stopping_criterion -function AI.is_finished( - ::AI.Problem, - ::AI.Algorithm, - ::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - return st.delta < c.tol -end + extended_kwargs = extend_columns((; kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) -function iterate_diff(cache1::MessageCache, cache2::MessageCache) - return maximum(edges(cache1)) do edge - m1 = cache1[edge] - m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + return BeliefPropagationSweepAlgorithm(edges) do edge + return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) + end end + + ## + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end -# === Layer 3: single-edge message update (non-iterative) === +# === Layer 1: BP outer loop (iterative) === -struct MessageUpdateProblem{Factors} <: AI.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors end -@kwdef struct SimpleMessageUpdateAlgorithm{ - E <: AbstractEdge, ContractionAlg, - } <: AI.Algorithm - edge::E - normalize::Bool = true - contraction_alg::ContractionAlg = Algorithm"exact" +@kwdef struct BeliefPropagationAlgorithm{ + ChildAlgorithm <: AI.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State +function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) + return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +end + +@kwdef mutable struct BeliefPropagationState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState end function AI.initialize_state( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm; + iterate, iteration::Int = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationState(; + iterate, iteration, stopping_criterion_state ) - return MessageUpdateState(; iterate) end -# Non-iterative algorithm: no per-call state to reset. function AI.initialize_state!( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state ) return state end -# Non-iterative algorithm: bypass the step!/stopping-criterion loop. -function AI.solve_loop!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdateAlgorithm, - state::MessageUpdateState +function AIE.initialize_subsolve( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState ) - cache = state.iterate - edge = algorithm.edge - - messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] - - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) - - if algorithm.normalize - message_norm = sum(new_message) - if !iszero(message_norm) - new_message /= message_norm - end - end - - cache[edge] = new_message - - return state + subproblem = BeliefPropagationSweepProblem(problem.factors) + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate end # === Layer 2: one sweep over edges (iterative) === @@ -210,113 +183,141 @@ function AIE.initialize_subsolve( return subproblem, subalgorithm, substate end -# === Layer 1: BP outer loop (iterative) === +# === Layer 3: single-edge message update (non-iterative) === -struct BeliefPropagationProblem{Factors} <: AI.Problem +struct MessageUpdateProblem{Factors} <: AI.Problem factors::Factors end -@kwdef struct BeliefPropagationAlgorithm{ - ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) - return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +@kwdef struct SimpleMessageUpdateAlgorithm{ + E <: AbstractEdge, ContractionAlg, + } <: AI.Algorithm + edge::E + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" end -@kwdef mutable struct BeliefPropagationState{ - Iterate, SCState <: AI.StoppingCriterionState, - } <: AI.State +@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::SCState end function AI.initialize_state( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm; - iterate, iteration::Int = 0 - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return BeliefPropagationState(; - iterate, iteration, stopping_criterion_state + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate ) + return MessageUpdateState(; iterate) end +# Non-iterative algorithm: no per-call state to reset. function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState ) return state end -function AIE.initialize_subsolve( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState +# Non-iterative algorithm: bypass the step!/stopping-criterion loop. +function AI.solve_loop!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdateAlgorithm, + state::MessageUpdateState ) - subproblem = BeliefPropagationSweepProblem(problem.factors) - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + cache = state.iterate + edge = algorithm.edge + + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] + + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end + end + + cache[edge] = new_message + + return state end -# === Top-level user entry point === +# === `StopWhenConverged` stopping criterion === -function beliefpropagation( - factors, messages; - edges = nothing, - maxiter = is_tree(factors) ? 1 : nothing, - stopping_criterion = nothing, - kwargs... +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 +end + +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate +end + +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) +end + +function AI.initialize_state!( + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) - if isnothing(maxiter) - throw( - ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when - `stopping_criterion` is provided." - ) - ) - end + st.delta = Inf + return st +end - cache = MessageCache(messages) - problem = BeliefPropagationProblem(factors) +function AI.is_finished!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate - ## Algorithm construction: + delta = iterate_diff(iterate, previous_iterate) - edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges + st.previous_iterate = copy(iterate) - base_stopping_criterion = AI.StopAfterIteration(maxiter) + # delta = 0 initially, so skip this the first time. + state.iteration == 0 && return false - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true end - stopping_criterion = base_stopping_criterion + return false +end - extended_kwargs = extend_columns((; kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) +function AI.is_finished( + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end - algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do edge - return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) - end +function iterate_diff(cache1::MessageCache, cache2::MessageCache) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) end +end - ## +# === Utility functions for processing keyword arguments === - return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end From 56b2d82570c6cc02551be85e7811d66557b2173d Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:10:38 -0400 Subject: [PATCH 45/64] Move StopWhenConverged + iterate_diff verb into AIE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `StopWhenConverged` is fully generic — its dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State`, with `StopWhenConverged` itself as the unique disambiguating type. The only piece that needed to live in BP was the metric: the BP `iterate_diff(::MessageCache, ::MessageCache)` override. AIE now owns: `StopWhenConverged`, `StopWhenConvergedState`, the four `AI.*` overloads, and a bare `iterate_diff(a, b)` verb that throws by default. BP provides its concrete `AIE.iterate_diff(::MessageCache, ::MessageCache)`. The top-level `ITensorNetworksNext` namespace re-exports `StopWhenConverged` + `iterate_diff` for callers that use the package-qualified names. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 66 +++++++++++++++++++ .../beliefpropagationproblem.jl | 63 +----------------- 2 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index ddf1083..c7ac727 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -31,4 +31,70 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end +# ============================ StopWhenConverged =========================================== + +# Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`. +# Concrete iterate types must supply an `iterate_diff` method. +function iterate_diff(a, b) + return throw(MethodError(iterate_diff, (a, b))) +end + +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 +end + +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate +end + +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) +end + +function AI.initialize_state!( + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState + ) + st.delta = Inf + return st +end + +function AI.is_finished!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate + + delta = iterate_diff(iterate, previous_iterate) + + st.previous_iterate = copy(iterate) + + # delta = 0 initially, so skip this the first time. + state.iteration == 0 && return false + + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true + end + + return false +end + +function AI.is_finished( + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end + end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 44b42ed..32e00ad 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,5 +1,6 @@ import .AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI +using .AlgorithmsInterfaceExtensions: StopWhenConverged, iterate_diff using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices @@ -240,67 +241,9 @@ function AI.solve_loop!( return state end -# === `StopWhenConverged` stopping criterion === +# === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 -end - -@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState - delta::Float64 = Inf - at_iteration::Int = -1 - previous_iterate::Iterate -end - -function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) - return StopWhenConvergedState(; previous_iterate = copy(iterate)) -end - -function AI.initialize_state!( - ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState - ) - st.delta = Inf - return st -end - -function AI.is_finished!( - problem::AI.Problem, - algorithm::AI.Algorithm, - state::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - iterate = state.iterate - previous_iterate = st.previous_iterate - - delta = iterate_diff(iterate, previous_iterate) - - st.previous_iterate = copy(iterate) - - # delta = 0 initially, so skip this the first time. - state.iteration == 0 && return false - - st.delta = delta - - if AI.is_finished(problem, algorithm, state, c, st) - st.at_iteration = state.iteration - return true - end - - return false -end - -function AI.is_finished( - ::AI.Problem, - ::AI.Algorithm, - ::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - return st.delta < c.tol -end - -function iterate_diff(cache1::MessageCache, cache2::MessageCache) +function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) return maximum(edges(cache1)) do edge m1 = cache1[edge] m2 = cache2[edge] From 4fd8a47689e23598e8968de3b726b424ec95755e Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:36:34 -0400 Subject: [PATCH 46/64] Move `edge` field from `SimpleMessageUpdateAlgorithm` to `MessageUpdateProblem` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The edge is per-step data; it belongs on the problem side, not the algorithm side. After the move: - `MessageUpdateProblem{Factors, Edge <: AbstractEdge}` carries `factors` + `edge` (type parameter renamed from `E` to `Edge`). - `SimpleMessageUpdateAlgorithm{ContractionAlg}` is now just `normalize` + `contraction_alg` — no per-edge data. - `BeliefPropagationSweepProblem` gains an `edges` field; its `initialize_subsolve` picks `edge = problem.edges[state.iteration]` and threads it into the `MessageUpdateProblem`. - `BeliefPropagationProblem` gains an `edges` field too, and the outer `initialize_subsolve` threads it down to the sweep subproblem. - The entry-point `do edge` closure becomes `do _` since the algorithm no longer needs the edge. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 32e00ad..640f8e0 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -28,12 +28,13 @@ function beliefpropagation( end cache = MessageCache(messages) - problem = BeliefPropagationProblem(factors) ## Algorithm construction: edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges + problem = BeliefPropagationProblem(factors, edges) + base_stopping_criterion = AI.StopAfterIteration(maxiter) if !isnothing(stopping_criterion) @@ -46,8 +47,8 @@ function beliefpropagation( edge_kwargs = rows(extended_kwargs, maxiter) algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do edge - return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) + return BeliefPropagationSweepAlgorithm(edges) do _ + return SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]...) end end @@ -58,8 +59,9 @@ end # === Layer 1: BP outer loop (iterative) === -struct BeliefPropagationProblem{Factors} <: AI.Problem +struct BeliefPropagationProblem{Factors, Edges} <: AI.Problem factors::Factors + edges::Edges end @kwdef struct BeliefPropagationAlgorithm{ @@ -114,7 +116,7 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subproblem = BeliefPropagationSweepProblem(problem.factors) + subproblem = BeliefPropagationSweepProblem(problem.factors, problem.edges) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -122,8 +124,9 @@ end # === Layer 2: one sweep over edges (iterative) === -struct BeliefPropagationSweepProblem{Factors} <: AI.Problem +struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem factors::Factors + edges::Edges end @kwdef struct BeliefPropagationSweepAlgorithm{ @@ -178,7 +181,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationSweepAlgorithm, state::BeliefPropagationSweepState ) - subproblem = MessageUpdateProblem(problem.factors) + edge = problem.edges[state.iteration] + subproblem = MessageUpdateProblem(problem.factors, edge) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -186,14 +190,12 @@ end # === Layer 3: single-edge message update (non-iterative) === -struct MessageUpdateProblem{Factors} <: AI.Problem +struct MessageUpdateProblem{Factors, Edge <: AbstractEdge} <: AI.Problem factors::Factors + edge::Edge end -@kwdef struct SimpleMessageUpdateAlgorithm{ - E <: AbstractEdge, ContractionAlg, - } <: AI.Algorithm - edge::E +@kwdef struct SimpleMessageUpdateAlgorithm{ContractionAlg} <: AI.Algorithm normalize::Bool = true contraction_alg::ContractionAlg = Algorithm"exact" end @@ -222,7 +224,7 @@ function AI.solve_loop!( state::MessageUpdateState ) cache = state.iterate - edge = algorithm.edge + edge = problem.edge messages = collect(incoming_messages(cache, edge)) factor = problem.factors[src(edge)] From 81f7b4f8863d660f28527ff2757e3d48f5fa5102 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:46:49 -0400 Subject: [PATCH 47/64] Move BP edge ordering from `BeliefPropagationProblem` to `BeliefPropagationSweepAlgorithm` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit At the outer BP layer the edge update order is an algorithmic choice (which edges to sweep and in what order, potentially varying per rep), not problem data. So edges now live on `BeliefPropagationSweepAlgorithm` — one sweep algorithm is "do a sweep with these edges". The outer `initialize_subsolve` transfers them into a `BeliefPropagationSweepProblem` when stepping down a layer. Side effect: `BeliefPropagationSweepAlgorithm` no longer needs an `algorithms::Vector` field. The per-edge `SimpleMessageUpdateAlgorithm`s were always identical copies (they no longer carry an edge), so it collapses to a single `message_update_algorithm` template — matching the shape `ApplyOperators` uses in the gate-apply code. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 640f8e0..887b354 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -28,13 +28,12 @@ function beliefpropagation( end cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) ## Algorithm construction: edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - problem = BeliefPropagationProblem(factors, edges) - base_stopping_criterion = AI.StopAfterIteration(maxiter) if !isnothing(stopping_criterion) @@ -47,9 +46,12 @@ function beliefpropagation( edge_kwargs = rows(extended_kwargs, maxiter) algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do _ - return SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]...) - end + return BeliefPropagationSweepAlgorithm(; + edges, + message_update_algorithm = SimpleMessageUpdateAlgorithm(; + edge_kwargs[repnum]... + ) + ) end ## @@ -59,9 +61,8 @@ end # === Layer 1: BP outer loop (iterative) === -struct BeliefPropagationProblem{Factors, Edges} <: AI.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors - edges::Edges end @kwdef struct BeliefPropagationAlgorithm{ @@ -116,8 +117,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subproblem = BeliefPropagationSweepProblem(problem.factors, problem.edges) subalgorithm = algorithm.algorithms[state.iteration] + subproblem = BeliefPropagationSweepProblem(problem.factors, subalgorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -130,16 +131,13 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ + Edges, ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationSweepAlgorithm(f::Function, edges) - return BeliefPropagationSweepAlgorithm(; algorithms = f.(edges)) + edges::Edges + message_update_algorithm::ChildAlgorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(edges)) end @kwdef mutable struct BeliefPropagationSweepState{ @@ -183,7 +181,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.algorithms[state.iteration] + subalgorithm = algorithm.message_update_algorithm substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end From 6dbf365b1b4205816664e545ef3340595478beb9 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:39:29 -0400 Subject: [PATCH 48/64] Store edges on `BeliefPropagationAlgorithm`, not the sweep algorithm Edges shouldn't be stored at both the outer BP algorithm level (as the algorithmic edge-ordering choice) and the sweep algorithm level (as runtime data that gets copied into the sweep problem). Keep them only on `BeliefPropagationAlgorithm`; the outer `initialize_subsolve` then transfers them into the `BeliefPropagationSweepProblem` when stepping down. `BeliefPropagationSweepAlgorithm` is now just `message_update_algorithm` + `stopping_criterion` (the stopping criterion is sized to the edges at construction time in `beliefpropagation()`). Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 887b354..95a4e08 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -45,12 +45,13 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + sweep_stopping_criterion = AI.StopAfterIteration(length(edges)) + algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum return BeliefPropagationSweepAlgorithm(; - edges, message_update_algorithm = SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]... - ) + ), + stopping_criterion = sweep_stopping_criterion ) end @@ -66,16 +67,18 @@ struct BeliefPropagationProblem{Factors} <: AI.Problem end @kwdef struct BeliefPropagationAlgorithm{ + Edges, ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm + edges::Edges algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) - return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +function BeliefPropagationAlgorithm(f::Function, niterations::Int; edges, kwargs...) + return BeliefPropagationAlgorithm(; edges, algorithms = f.(1:niterations), kwargs...) end @kwdef mutable struct BeliefPropagationState{ @@ -118,7 +121,7 @@ function AIE.initialize_subsolve( state::BeliefPropagationState ) subalgorithm = algorithm.algorithms[state.iteration] - subproblem = BeliefPropagationSweepProblem(problem.factors, subalgorithm.edges) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -131,13 +134,11 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - Edges, ChildAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - edges::Edges message_update_algorithm::ChildAlgorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(edges)) + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationSweepState{ From 0d871c20f5400dc0269f07baaec3c89b98df15ba Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:54:35 -0400 Subject: [PATCH 49/64] Index per-edge BP algorithms by edge; drop `AbstractVector` constraints Make `BeliefPropagationSweepAlgorithm.algorithms` indexable by edge (default: a `Dict{Edge, MessageUpdateAlgorithm}` populated with the same template at construction time), and drop the `Algorithms <: AbstractVector{ChildAlgorithm}` constraint from both `BeliefPropagationAlgorithm` and `BeliefPropagationSweepAlgorithm` so any indexable container works. Edges remain sourced from `BeliefPropagationSweepProblem` (the source of truth); the sweep `initialize_subsolve` looks up `algorithm.algorithms[edge]` to pick the per-edge update algorithm. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 95a4e08..2dd7b32 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -45,13 +45,12 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - sweep_stopping_criterion = AI.StopAfterIteration(length(edges)) algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum + message_update_algorithm = SimpleMessageUpdateAlgorithm(; + edge_kwargs[repnum]... + ) return BeliefPropagationSweepAlgorithm(; - message_update_algorithm = SimpleMessageUpdateAlgorithm(; - edge_kwargs[repnum]... - ), - stopping_criterion = sweep_stopping_criterion + algorithms = Dict(edge => message_update_algorithm for edge in edges) ) end @@ -68,11 +67,11 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, + Algorithms, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges + # Indexable by iteration count (e.g. `Vector` or `Dict{Int, ...}`). algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end @@ -134,11 +133,14 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - ChildAlgorithm <: AI.Algorithm, + Algorithms, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - message_update_algorithm::ChildAlgorithm - stopping_criterion::StoppingCriterion + # Indexable by edge (e.g. `Dict{Edge, MessageUpdateAlgorithm}`); the + # default constructor in `beliefpropagation()` builds one with the same + # template copied across every edge. + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end @kwdef mutable struct BeliefPropagationSweepState{ @@ -182,7 +184,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.message_update_algorithm + subalgorithm = algorithm.algorithms[edge] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end From e5a58d05dc177a9dba07a4fd7ac17d3929c17d31 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:57:53 -0400 Subject: [PATCH 50/64] Rename `beliefpropagationproblem.jl` to `beliefpropagation.jl` The file now defines three problem/algorithm/state triples plus the top-level `beliefpropagation` entry point, so the old name is misleading. Co-Authored-By: Claude Opus 4.7 --- src/ITensorNetworksNext.jl | 2 +- .../{beliefpropagationproblem.jl => beliefpropagation.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/beliefpropagation/{beliefpropagationproblem.jl => beliefpropagation.jl} (100%) diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 394546a..41ce78e 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -14,6 +14,6 @@ include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") include("beliefpropagation/messagecache.jl") -include("beliefpropagation/beliefpropagationproblem.jl") +include("beliefpropagation/beliefpropagation.jl") end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagation.jl similarity index 100% rename from src/beliefpropagation/beliefpropagationproblem.jl rename to src/beliefpropagation/beliefpropagation.jl From 2e8ee3b26a070fc1ce33d89ceafcbbac056f562f Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 21:03:05 -0400 Subject: [PATCH 51/64] Simplify BP API: single child algorithms + select_* selectors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Public-API changes: * `beliefpropagation` no longer takes a `maxiter` kwarg; pass either `stopping_criterion = (; maxiter = 10)` or a full criterion such as `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`. * The `message_update_algorithm` kwarg accepts either a NamedTuple of keyword arguments forwarded to `SimpleMessageUpdateAlgorithm`, or a full `AI.Algorithm`. * The `edges` kwarg defaults to `default_beliefpropagation_edges(factors)` rather than being computed inside the body. * Per-iteration kwarg broadcasting (`extend_columns`, `rows`, `rowlength`, `repeat_last`) is removed. Internals: * `BeliefPropagationSweepAlgorithm` now holds a single `message_update_algorithm` (was a `Dict{Edge, ...}` of per-edge algorithms). Edge-dependent updates are expressed by defining a new `AI.Algorithm` subtype and dispatching on `problem.edge` in `solve_loop!`. * `BeliefPropagationAlgorithm` holds a single `sweep_algorithm` (was a `Vector` of per-iteration sweep algorithms). Iteration-dependent sweep behavior is expressed by defining a new sweep algorithm subtype and varying the inner algorithm in `AIE.initialize_subsolve` using `state.iteration`. * New `select_*` selectors `select_beliefpropagation_stopping_criterion` and `select_message_update_algorithm`, modeled on `MatrixAlgebraKit.select_truncation`, normalize NamedTuple-or-object inputs into algorithm/criterion objects. * `default_beliefpropagation_edges` and `default_message_update_algorithm` expose the defaults as named functions. * Stopping-criterion default removal: the entry point no longer constructs a `StopAfterIteration(maxiter)` itself — choice of criterion is entirely the caller's responsibility, except for the inner sweep, which still defaults to `StopAfterIteration(length(edges))`. * Type parameter `SCState` renamed to `StoppingCriterionState` on both BP state structs. * Helper definitions reordered so default/select helpers sit above `beliefpropagation` for readability. Co-Authored-By: Claude Opus 4.7 --- src/beliefpropagation/beliefpropagation.jl | 131 ++++++++++----------- test/test_beliefpropagation.jl | 16 +-- 2 files changed, 73 insertions(+), 74 deletions(-) diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 2dd7b32..667debd 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -11,50 +11,69 @@ using NamedGraphs.PartitionedGraphs: quotientvertices # === Top-level user entry point === -function beliefpropagation( - factors, messages; - edges = nothing, - maxiter = is_tree(factors) ? 1 : nothing, - stopping_criterion = nothing, - kwargs... +default_beliefpropagation_edges(graph) = forest_cover_edge_sequence(graph) + +default_message_update_algorithm(; kwargs...) = SimpleMessageUpdateAlgorithm(; kwargs...) + +select_message_update_algorithm(algorithm::AI.Algorithm) = algorithm +function select_message_update_algorithm(kwargs::NamedTuple) + return default_message_update_algorithm(; kwargs...) +end + +function select_beliefpropagation_stopping_criterion( + stopping_criterion::AI.StoppingCriterion + ) + return stopping_criterion +end +function select_beliefpropagation_stopping_criterion(::Nothing) + return throw( + ArgumentError( + "`stopping_criterion` must be specified, e.g.\n" * + " `stopping_criterion = (; maxiter = 10)` or\n" * + " `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`." + ) ) +end +function select_beliefpropagation_stopping_criterion(kwargs::NamedTuple) + return select_beliefpropagation_stopping_criterion(; kwargs...) +end +function select_beliefpropagation_stopping_criterion(; maxiter = nothing, kwargs...) if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified in `stopping_criterion`.")) + end + if !isempty(kwargs) throw( ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when - `stopping_criterion` is provided." + "Unrecognized `stopping_criterion` kwargs: $(keys(kwargs)). " * + "Only `maxiter` is currently supported." ) ) end + return AI.StopAfterIteration(maxiter) +end - cache = MessageCache(messages) +function beliefpropagation( + factors, messages; + edges = default_beliefpropagation_edges(factors), + stopping_criterion = nothing, + message_update_algorithm = default_message_update_algorithm() + ) problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) - ## Algorithm construction: - - edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - - base_stopping_criterion = AI.StopAfterIteration(maxiter) - - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion - end - - stopping_criterion = base_stopping_criterion - - extended_kwargs = extend_columns((; kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum - message_update_algorithm = SimpleMessageUpdateAlgorithm(; - edge_kwargs[repnum]... - ) - return BeliefPropagationSweepAlgorithm(; - algorithms = Dict(edge => message_update_algorithm for edge in edges) - ) - end + sweep_algorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(edges)) + ) - ## + algorithm = BeliefPropagationAlgorithm(; + edges, + sweep_algorithm, + stopping_criterion + ) return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end @@ -67,25 +86,20 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - Algorithms, + SweepAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges - # Indexable by iteration count (e.g. `Vector` or `Dict{Int, ...}`). - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationAlgorithm(f::Function, niterations::Int; edges, kwargs...) - return BeliefPropagationAlgorithm(; edges, algorithms = f.(1:niterations), kwargs...) + sweep_algorithm::SweepAlgorithm + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationState{ - Iterate, SCState <: AI.StoppingCriterionState, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, } <: AI.State iterate::Iterate iteration::Int = 0 - stopping_criterion_state::SCState + stopping_criterion_state::StoppingCriterionState end function AI.initialize_state( @@ -119,7 +133,7 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subalgorithm = algorithm.algorithms[state.iteration] + subalgorithm = algorithm.sweep_algorithm subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -133,22 +147,19 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - Algorithms, + MessageUpdateAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - # Indexable by edge (e.g. `Dict{Edge, MessageUpdateAlgorithm}`); the - # default constructor in `beliefpropagation()` builds one with the same - # template copied across every edge. - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + message_update_algorithm::MessageUpdateAlgorithm + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationSweepState{ - Iterate, SCState <: AI.StoppingCriterionState, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, } <: AI.State iterate::Iterate iteration::Int = 0 - stopping_criterion_state::SCState + stopping_criterion_state::StoppingCriterionState end function AI.initialize_state( @@ -184,7 +195,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.algorithms[edge] + subalgorithm = algorithm.message_update_algorithm substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -230,7 +241,7 @@ function AI.solve_loop!( messages = collect(incoming_messages(cache, edge)) factor = problem.factors[src(edge)] - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + new_message = contract_network([messages; [factor]]; algorithm.contraction_alg) if algorithm.normalize message_norm = sum(new_message) @@ -253,17 +264,3 @@ function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) end end - -# === Utility functions for processing keyword arguments === - -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 01ca6e7..4362f7e 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -167,7 +167,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -184,7 +186,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -198,13 +202,11 @@ end messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) - stopping_criterion = StopWhenConverged(tol = 1.0e-10) + stopping_criterion = + AI.StopAfterIteration(10) | StopWhenConverged(tol = 1.0e-10) cache = ITensorNetworksNext.beliefpropagation( - tn, - messages; - maxiter = 10, - stopping_criterion + tn, messages; stopping_criterion ) z_bp = exp(bethe_free_energy(tn, cache)) From c8cafb9d78de3620a31876d609484a88534e9c31 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 21:11:28 -0400 Subject: [PATCH 52/64] Minor reorg and line-collapse cleanup in `beliefpropagation` Reorder the body of `beliefpropagation` so the algorithm construction flows top-to-bottom and `cache` is constructed alongside the `AI.solve` call, and collapse a few short multi-line forms onto single lines where the formatter is happy with it. Co-Authored-By: Claude Opus 4.7 --- src/beliefpropagation/beliefpropagation.jl | 25 ++++++---------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 667debd..b9104ff 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -20,11 +20,7 @@ function select_message_update_algorithm(kwargs::NamedTuple) return default_message_update_algorithm(; kwargs...) end -function select_beliefpropagation_stopping_criterion( - stopping_criterion::AI.StoppingCriterion - ) - return stopping_criterion -end +select_beliefpropagation_stopping_criterion(c::AI.StoppingCriterion) = c function select_beliefpropagation_stopping_criterion(::Nothing) return throw( ArgumentError( @@ -59,21 +55,16 @@ function beliefpropagation( message_update_algorithm = default_message_update_algorithm() ) problem = BeliefPropagationProblem(factors) - cache = MessageCache(messages) - stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - sweep_algorithm = BeliefPropagationSweepAlgorithm(; message_update_algorithm, stopping_criterion = AI.StopAfterIteration(length(edges)) ) + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; edges, sweep_algorithm, stopping_criterion) - algorithm = BeliefPropagationAlgorithm(; - edges, - sweep_algorithm, - stopping_criterion - ) + cache = MessageCache(messages) return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end @@ -110,9 +101,7 @@ function AI.initialize_state( stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; - iterate, iteration, stopping_criterion_state - ) + return BeliefPropagationState(; iterate, iteration, stopping_criterion_state) end function AI.initialize_state!( @@ -170,9 +159,7 @@ function AI.initialize_state( stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationSweepState(; - iterate, iteration, stopping_criterion_state - ) + return BeliefPropagationSweepState(; iterate, iteration, stopping_criterion_state) end function AI.initialize_state!( From b0a1e90c3bb9119e180490575612ee6f036c5fe1 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 20 May 2026 13:04:54 -0400 Subject: [PATCH 53/64] Collapse BP message-update AI layer; introduce `MessageUpdateAlgorithm` strategy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The third `AlgorithmsInterface` layer (`MessageUpdateProblem` / `SimpleMessageUpdateAlgorithm` / `MessageUpdateState`) was a non-iterative function call dressed in problem/algorithm/state ceremony; this commit drops that framing. The per-edge update body lives in `AI.step!(::BeliefPropagationSweepProblem, ::BeliefPropagationSweepAlgorithm, ::BeliefPropagationSweepState)` and delegates to a new strategy interface: - `abstract type MessageUpdateAlgorithm end` - `message_update!(algorithm, cache, factors, edge)` - `SimpleMessageUpdate <: MessageUpdateAlgorithm` (the default; holds `normalize` + `contraction_alg`) `BeliefPropagationSweepAlgorithm` now holds a `message_update_algorithm` field (default `SimpleMessageUpdate()`) and is no longer a `NestedAlgorithm`. Algorithm selection follows `MatrixAlgebraKit.select_algorithm`. Generic `default_algorithm(f; kwargs...)` and `select_algorithm(f, alg; kwargs...)` helpers live in `AlgorithmsInterfaceExtensions`; operations register a default by overloading `default_algorithm(::typeof(f); kwargs...)`. BP overloads both for `message_update!`. Callers can pass either an explicit `MessageUpdateAlgorithm` instance, a `NamedTuple` of keyword arguments forwarded to the default algorithm, or `nothing` plus flat kwargs (also forwarded to the default). A convenience signature `message_update!(cache, factors, edge; alg = nothing, kwargs...)` routes through `AIE.select_algorithm` so the same dispatch pattern is reachable from a free-standing call. The `message_update_algorithm` kwarg on `beliefpropagation` uses the same selector. To keep the message-store iterate persistent across outer-loop iterations without duplicating it on the outer state, introduce an `AIE.NestedState` abstract type with generic `getproperty` / `setproperty!` / `propertynames` forwarders for `:iterate` through a `:substate` field. `BeliefPropagationState <: AIE.NestedState` now wraps a `BeliefPropagationSweepState` as `substate`; the outer state holds no `iterate` field of its own. The default `AIE.finalize_substate!` (which copies `substate.iterate` back to `state.iterate`) becomes a self-write through the forwarder — harmless and removes the need for a BP-specific override. Rename: `BeliefPropagationAlgorithm.sweep_algorithm` field → `subalgorithm` (matches the generic name already used in `AIE.initialize_subsolve`'s return tuple). Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 45 +++++++ src/beliefpropagation/beliefpropagation.jl | 112 +++++++++--------- 2 files changed, 99 insertions(+), 58 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index c7ac727..eac073f 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -31,6 +31,51 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end +# ============================ NestedState ================================================= + +# State that wraps an inner `substate` and forwards `:iterate` accesses to it, +# so the inner-loop iterate is shared without duplicating storage on the outer +# state. Subtypes must store the inner state as a field named `substate`. +abstract type NestedState <: AI.State end + +# Use `getfield` on the right-hand side so future edits to this forwarder +# can't accidentally recurse through the overload. +function Base.getproperty(state::NestedState, name::Symbol) + name === :iterate && return getfield(state, :substate).iterate + return getfield(state, name) +end +function Base.setproperty!(state::NestedState, name::Symbol, value) + name === :iterate && return (getfield(state, :substate).iterate = value) + return setfield!(state, name, value) +end +function Base.propertynames(state::NestedState) + return (fieldnames(typeof(state))..., :iterate) +end + +# ============================ select_algorithm / default_algorithm ======================== + +# Modeled on `MatrixAlgebraKit.select_algorithm` / `default_algorithm`. An +# operation `f` (a function) declares its default algorithm strategy by +# overloading `default_algorithm(::typeof(f); kwargs...)`. The strategy can +# then be selected at a call site via `select_algorithm(f, alg; kwargs...)`, +# accepting either `nothing` (use defaults), a `NamedTuple` of keyword +# arguments forwarded to the default constructor, or an explicit algorithm +# instance — the last dispatched per-operation (e.g. on a strategy +# supertype). +function default_algorithm(f; kwargs...) + return throw(MethodError(default_algorithm, (f,))) +end + +select_algorithm(f, ::Nothing; kwargs...) = default_algorithm(f; kwargs...) +function select_algorithm(f, alg::NamedTuple; kwargs...) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is a `NamedTuple`." + ) + ) + return default_algorithm(f; alg...) +end + # ============================ StopWhenConverged =========================================== # Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`. diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index b9104ff..876e5cc 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -13,13 +13,6 @@ using NamedGraphs.PartitionedGraphs: quotientvertices default_beliefpropagation_edges(graph) = forest_cover_edge_sequence(graph) -default_message_update_algorithm(; kwargs...) = SimpleMessageUpdateAlgorithm(; kwargs...) - -select_message_update_algorithm(algorithm::AI.Algorithm) = algorithm -function select_message_update_algorithm(kwargs::NamedTuple) - return default_message_update_algorithm(; kwargs...) -end - select_beliefpropagation_stopping_criterion(c::AI.StoppingCriterion) = c function select_beliefpropagation_stopping_criterion(::Nothing) return throw( @@ -52,17 +45,19 @@ function beliefpropagation( factors, messages; edges = default_beliefpropagation_edges(factors), stopping_criterion = nothing, - message_update_algorithm = default_message_update_algorithm() + message_update_algorithm = nothing ) problem = BeliefPropagationProblem(factors) - message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - sweep_algorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm = AIE.select_algorithm( + message_update!, message_update_algorithm + ) + subalgorithm = BeliefPropagationSweepAlgorithm(; message_update_algorithm, stopping_criterion = AI.StopAfterIteration(length(edges)) ) stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) - algorithm = BeliefPropagationAlgorithm(; edges, sweep_algorithm, stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; edges, subalgorithm, stopping_criterion) cache = MessageCache(messages) @@ -77,18 +72,18 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - SweepAlgorithm <: AI.Algorithm, + Subalgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges - sweep_algorithm::SweepAlgorithm + subalgorithm::Subalgorithm stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State - iterate::Iterate + Substate <: AI.State, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AIE.NestedState + substate::Substate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState end @@ -98,10 +93,12 @@ function AI.initialize_state( algorithm::BeliefPropagationAlgorithm; iterate, iteration::Int = 0 ) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) + substate = AI.initialize_state(subproblem, algorithm.subalgorithm; iterate) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iterate, iteration, stopping_criterion_state) + return BeliefPropagationState(; iteration, stopping_criterion_state, substate) end function AI.initialize_state!( @@ -122,10 +119,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subalgorithm = algorithm.sweep_algorithm subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + return subproblem, algorithm.subalgorithm, state.substate end # === Layer 2: one sweep over edges (iterative) === @@ -136,10 +131,10 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - MessageUpdateAlgorithm <: AI.Algorithm, + MessageUpdateAlgorithm, StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm - message_update_algorithm::MessageUpdateAlgorithm + } <: AI.Algorithm + message_update_algorithm::MessageUpdateAlgorithm = SimpleMessageUpdate() stopping_criterion::StoppingCriterion end @@ -175,58 +170,60 @@ function AI.initialize_state!( return state end -function AIE.initialize_subsolve( +function AI.step!( problem::BeliefPropagationSweepProblem, algorithm::BeliefPropagationSweepAlgorithm, state::BeliefPropagationSweepState ) edge = problem.edges[state.iteration] - subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.message_update_algorithm - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + message_update!( + algorithm.message_update_algorithm, state.iterate, problem.factors, edge + ) + return state end -# === Layer 3: single-edge message update (non-iterative) === +# === Layer 3: single-edge message update strategy === -struct MessageUpdateProblem{Factors, Edge <: AbstractEdge} <: AI.Problem - factors::Factors - edge::Edge -end +# Strategy interface: a `MessageUpdateAlgorithm` defines how a single +# message is computed and written back into the message store. Plug in a +# new strategy by subtyping `MessageUpdateAlgorithm` and overloading +# `message_update!(strategy, cache, factors, edge)`. +abstract type MessageUpdateAlgorithm end -@kwdef struct SimpleMessageUpdateAlgorithm{ContractionAlg} <: AI.Algorithm - normalize::Bool = true - contraction_alg::ContractionAlg = Algorithm"exact" -end +function message_update! end -@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State - iterate::Iterate +# Algorithm selection (MAK-style; see `AIE.select_algorithm`). +function AIE.default_algorithm(::typeof(message_update!); kwargs...) + return SimpleMessageUpdate(; kwargs...) end - -function AI.initialize_state( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate +function AIE.select_algorithm( + ::typeof(message_update!), alg::MessageUpdateAlgorithm; kwargs... ) - return MessageUpdateState(; iterate) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is a `MessageUpdateAlgorithm` instance." + ) + ) + return alg end -# Non-iterative algorithm: no per-call state to reset. -function AI.initialize_state!( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState +# Convenience entry: pick the strategy via `AIE.select_algorithm` +# (accepts either `alg = ::MessageUpdateAlgorithm` / `::NamedTuple`, or flat +# kwargs forwarded to the default algorithm), then dispatch. +function message_update!(cache, factors, edge; alg = nothing, kwargs...) + return message_update!( + AIE.select_algorithm(message_update!, alg; kwargs...), cache, factors, edge ) - return state end -# Non-iterative algorithm: bypass the step!/stopping-criterion loop. -function AI.solve_loop!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdateAlgorithm, - state::MessageUpdateState - ) - cache = state.iterate - edge = problem.edge +@kwdef struct SimpleMessageUpdate{ContractionAlg} <: MessageUpdateAlgorithm + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" +end +function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] + factor = factors[src(edge)] new_message = contract_network([messages; [factor]]; algorithm.contraction_alg) @@ -238,8 +235,7 @@ function AI.solve_loop!( end cache[edge] = new_message - - return state + return nothing end # === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === From e2220fe54b69e0aad00a8e8e3a1701e9e9aadcb5 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Fri, 13 Feb 2026 17:20:29 -0500 Subject: [PATCH 54/64] Working Parallel BP --- Project.toml | 4 + .../ITensorNetworksNextDaggerExt.jl | 86 ++++++++++ .../daggerbeliefpropagation.jl | 150 ++++++++++++++++++ .../ITensorNetworksNextDistributedExt.jl | 84 ++++++++++ .../distributedbeliefpropagation.jl | 116 ++++++++++++++ src/ITensorNetworksNext.jl | 2 + .../ITensorNetworksNextParallel.jl | 27 ++++ src/ITensorNetworksNextParallel/dagger.jl | 38 +++++ .../distributed.jl | 38 +++++ 9 files changed, 545 insertions(+) create mode 100644 ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl create mode 100644 ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl create mode 100644 ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl create mode 100644 ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl create mode 100644 src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl create mode 100644 src/ITensorNetworksNextParallel/dagger.jl create mode 100644 src/ITensorNetworksNextParallel/distributed.jl diff --git a/Project.toml b/Project.toml index c8358d0..435e87a 100644 --- a/Project.toml +++ b/Project.toml @@ -30,9 +30,13 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" +ITensorNetworksNextDistributedExt = "Distributed" +ITensorNetworksNextDaggerExt = "Dagger" [compat] AbstractTrees = "0.4.5" diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl new file mode 100644 index 0000000..b5e2c80 --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -0,0 +1,86 @@ +module ITensorNetworksNextDaggerExt + +using Dagger +using Dagger.Distributed +using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, + ITensorNetworksNextParallel + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + + +function ITensorNetworksNextParallel.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) + return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) +end + +function initialize_dagger_state( + problem::AIE.Problem, + algorithm::AIE.Algorithm; + iterate, + remote_subiterates = Dict{Int, Dagger.Chunk}(), + ) + + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + + remote_results = Dict{Int, Dagger.DTask}() + + return DaggerState(; iterate, remote_subiterates, stopping_criterion_state, remote_results) +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::DaggerNestedAlgorithm; + kwargs... + ) + return initialize_dagger_state(problem, algorithm; kwargs...) +end + +function AIE.get_subproblem( + problem::AIE.Problem, + algorithm::AIE.NestedAlgorithm, + state::DaggerState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + iterate = state.iterate + remote_subiterates = state.remote_subiterates + + substate = AI.initialize_state(subproblem, subalgorithm; iterate, remote_subiterates) + + return subproblem, subalgorithm, substate +end + + +function AI.step!( + problem::AI.Problem, + algorithm::DaggerNestedAlgorithm, + state::DaggerState; + kwargs... + ) + + subproblem, subalgorithm, subiterate_chunk = AIE.get_subproblem(problem, algorithm, state) + + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate = subiterate_chunk) + + AIE.set_substate!(problem, algorithm, state, dtask) + + return state +end + +function AIE.set_substate!( + ::AIE.Problem, + ::DaggerNestedAlgorithm, + state::DaggerState, + dtask::Dagger.DTask, + ) + state.remote_results[state.iteration] = dtask + + return state +end + +include("daggerbeliefpropagation.jl") + +end # ITensorNetworksNextDaggerExt diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl new file mode 100644 index 0000000..a4acdce --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -0,0 +1,150 @@ +using Dagger +using Dagger.Distributed + +using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, + is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph +using Dictionaries: Indices +using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices +using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerBeliefPropagationCache, + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, dagger_algorithm, + subcache +using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, + beliefpropagation, forest_cover_edge_sequence, select_algorithm +using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices +using NamedGraphs: NamedGraphs +using NamedGraphs.GraphsExtensions: boundary_edges + +function ITensorNetworksNextParallel.subcache(cache::DaggerBeliefPropagationCache, inds) + return subcache(cache.underlying_cache, inds) +end + +function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::AbstractGraph) + underlying_cache = BeliefPropagationCache(network) + + keys = Indices(quotientvertices(underlying_cache)) + + workers = Iterators.cycle(Distributed.workers()) + worker_dict = similar(keys, Int) + + for quotient_vertex in keys + worker, workers = Iterators.peel(workers) + worker_dict[quotient_vertex] = worker + end + + quotient_chunks = map(keys) do quotient_vertex + worker = worker_dict[quotient_vertex] + iterate = subcache(underlying_cache, quotient_vertex) + chunk = Dagger.@mutable worker = worker BeliefPropagationState(; iterate) + return chunk + end + + return DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) +end + +DataGraphs.underlying_graph(cache::DaggerBeliefPropagationCache) = underlying_graph(cache.underlying_cache) + +DataGraphs.is_vertex_assigned(bpc::DaggerBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) +DataGraphs.is_edge_assigned(bpc::DaggerBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) + +DataGraphs.get_vertex_data(bpc::DaggerBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) +DataGraphs.get_edge_data(bpc::DaggerBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.undelying_caches, edge) + +DataGraphs.set_vertex_data!(bpc::DaggerBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) +DataGraphs.set_edge_data!(bpc::DaggerBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) + +NamedGraphs.to_graph_index(::DaggerBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data(cache::DaggerBeliefPropagationCache, qv::QuotientVertex) + return cache.quotient_chunks[qv] +end + +function ITensorNetworksNext.beliefpropagation_sweep(cache::DaggerBeliefPropagationCache; edges, workers = workers(), kwargs...) + + keys = collect(quotientvertices(cache)) + + return dagger_algorithm(keys; keys, workers) do quotient_vertex + + subcache = fetch(cache[quotient_vertex]).iterate + + subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges + incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) + + alg = select_algorithm( + beliefpropagation, + subcache; + # Don't update the incoming messages + edges = setdiff(subcache_edges, incoming_edges), + maxiter = 1, + kwargs... + ) + + return alg + end +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::BeliefPropagation{<:DaggerNestedAlgorithm}; + kwargs... + ) + return initialize_dagger_state(problem, algorithm; kwargs...) +end + +function AIE.get_subproblem( + problem::BeliefPropagationProblem, + algorithm::DaggerNestedAlgorithm, + state::DaggerState, + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + quotient_vertex = algorithm.keys[state.iteration] + + cache = state.iterate.iterate + + subiterate = cache[quotient_vertex] + + return subproblem, subalgorithm, subiterate +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, + state::AIE.State, + substate::DaggerState, + ) + + dst_cache = state.iterate.iterate + + state.iterate.maxdiff = 0.0 + + current_algorithm = algorithm.algorithms[state.iteration] + + for (i, quotient_vertex) in enumerate(current_algorithm.keys) + get_maxdiff = dtask -> dtask.iterate.maxdiff + src_maxdiff = fetch(Dagger.@spawn get_maxdiff(substate.remote_results[i])) + + if src_maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = src_maxdiff + end + end + + + transfer_edges! = (dst_chunk, src_chunk, edges) -> begin + src_subcache = src_chunk.iterate + dst_subcache = dst_chunk.iterate + for edge in edges + dst_subcache[edge] = src_subcache[edge] + end + end + + transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge + src_subcache = dst_cache[src(quotient_edge)] + dst_subcache = dst_cache[dst(quotient_edge)] + return Dagger.@spawn transfer_edges!(dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge)) + end + + wait.(transfer_dtasks) + + return state +end diff --git a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl new file mode 100644 index 0000000..c19db03 --- /dev/null +++ b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl @@ -0,0 +1,84 @@ +module ITensorNetworksNextDistributedExt + +using Distributed + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +import ITensorNetworksNext.ITensorNetworksNextParallel as Parallel + +function initialize_distributed_state( + problem::AIE.Problem, + algorithm::AIE.Algorithm; + keys, + iterate, + kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) + remote_results = Dict{eltype(keys), Distributed.Future}() + + return Parallel.DistributedState(; iterate, stopping_criterion_state, remote_results) +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::Parallel.DistributedNestedAlgorithm; + kwargs... + ) + return initialize_distributed_state(problem, algorithm; keys = algorithm.keys, kwargs...) +end + +function Parallel.DistributedNestedAlgorithm(f::Function, iterable; kwargs...) + return Parallel.DistributedNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) +end + +function AIE.get_subproblem( + problem::AI.Problem, algorithm::Parallel.DistributedNestedAlgorithm, state::Parallel.DistributedState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + return subproblem, subalgorithm, state.iterate +end + +function AI.step!( + problem::AI.Problem, + algorithm::Parallel.DistributedNestedAlgorithm, + state::Parallel.DistributedState; + kwargs... + ) + + subproblem, subalgorithm, subiterate = AIE.get_subproblem(problem, algorithm, state) + + # Do whatever should have happened at `step!`, but store the result as a future. + + function solve(subproblem, subalgorithm, iterate) + rv = AI.solve(subproblem, subalgorithm; iterate) + return rv + end + + future = remotecall(solve, algorithm.workers, subproblem, subalgorithm, subiterate) + + AIE.set_substate!(problem, algorithm, state, future) + + return state +end + +function AIE.set_substate!( + ::AIE.Problem, + algorithm::Parallel.DistributedNestedAlgorithm, + state::Parallel.DistributedState, + future::Distributed.Future, + ) + key = algorithm.keys[state.iteration] + + state.remote_results[key] = future + + return state +end + +include("distributedbeliefpropagation.jl") + +end # ITensorNetworksNextDistributedExt diff --git a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl new file mode 100644 index 0000000..3848e65 --- /dev/null +++ b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl @@ -0,0 +1,116 @@ +using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, + is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph +using Graphs: AbstractEdge, AbstractGraph, edges, vertices +using ITensorNetworksNext.ITensorNetworksNextParallel: DistributedBeliefPropagationCache, + DistributedNestedAlgorithm, DistributedState, ITensorNetworksNextParallel, + distributed_algorithm +using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, + beliefpropagation, forest_cover_edge_sequence, select_algorithm, setmessages! +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientvertices +using NamedGraphs: NamedGraphs + +function ITensorNetworksNextParallel.DistributedBeliefPropagationCache(network::AbstractGraph) + underlying_cache = BeliefPropagationCache(network) + return DistributedBeliefPropagationCache(underlying_cache) +end + +DataGraphs.underlying_graph(cache::DistributedBeliefPropagationCache) = underlying_graph(cache.underlying_cache) + +DataGraphs.is_vertex_assigned(bpc::DistributedBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) +DataGraphs.is_edge_assigned(bpc::DistributedBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) + +DataGraphs.get_vertex_data(bpc::DistributedBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) +DataGraphs.get_edge_data(bpc::DistributedBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.underlying_cache, edge) + +DataGraphs.set_vertex_data!(bpc::DistributedBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) +DataGraphs.set_edge_data!(bpc::DistributedBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) + +NamedGraphs.to_graph_index(::DistributedBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data(cache::DistributedBeliefPropagationCache, qv::QuotientVertex) + return ITensorNetworksNextParallel.subcache(cache.underlying_cache, qv) +end +function ITensorNetworksNext.beliefpropagation_sweep( + cache::DistributedBeliefPropagationCache; edges, kwargs... + ) + + keys = collect(quotientvertices(cache)) + + return distributed_algorithm(keys; keys, workers = WorkerPool(workers())) do quotient_vertex + + subcache = cache[quotient_vertex] + subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges + incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) + + alg = select_algorithm( + beliefpropagation, + subcache; + edges = setdiff(subcache_edges, incoming_edges), + maxiter = 1, + kwargs... + ) + + return alg + end +end + +function AI.initialize_state( + problem::AIE.Problem, + algorithm::BeliefPropagation{<:DistributedNestedAlgorithm}; + kwargs... + ) + + keys = first(algorithm.algorithms).keys + + return initialize_distributed_state(problem, algorithm; keys = keys, kwargs...) +end + +function AIE.get_subproblem( + problem::BeliefPropagationProblem, + algorithm::DistributedNestedAlgorithm, + state::DistributedState + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + + cache = state.iterate.iterate + + quotient_vertex = algorithm.keys[state.iteration] + subiterate = BeliefPropagationState(; iterate = cache[quotient_vertex]) + + return subproblem, subalgorithm, subiterate +end + +function AIE.set_substate!( + ::BeliefPropagationProblem, + ::AIE.NestedAlgorithm, + state::AIE.State, + substate::DistributedState, + ) + + dst_cache = state.iterate.iterate + + state.iterate.maxdiff = 0.0 + + for quotient_vertex in quotientvertices(dst_cache) + + src_state = fetch(substate.remote_results[quotient_vertex]).iterate + + src_cache = src_state.iterate + src_maxdiff = src_state.maxdiff + + incoming_edges = boundary_edges(dst_cache, vertices(dst_cache, quotient_vertex); dir = :in) + + updated_messages = setdiff(edges(src_cache), incoming_edges) + + setmessages!(dst_cache, src_cache, updated_messages) + + if src_maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = src_maxdiff + end + + end + + return state +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index e4f4a09..c5bf49e 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -20,4 +20,6 @@ include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") include("beliefpropagation/beliefpropagationproblem.jl") +include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") + end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl new file mode 100644 index 0000000..1131ca3 --- /dev/null +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -0,0 +1,27 @@ +module ITensorNetworksNextParallel + +using Graphs: neighbors, add_vertex!, vertices +using NamedGraphs.GraphsExtensions: subgraph +using NamedGraphs.PartitionedGraphs: QuotientVertex +using ..ITensorNetworksNext: BeliefPropagationCache + +subcache(cache::BeliefPropagationCache, vertex::QuotientVertex) = subcache(cache, vertices(cache, vertex)) +function subcache(cache::BeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end + +include("distributed.jl") +include("dagger.jl") + +end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl new file mode 100644 index 0000000..eef37bc --- /dev/null +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -0,0 +1,38 @@ +import AlgorithmsInterface as AI +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +using ITensorNetworksNext: AbstractBeliefPropagationCache + +@kwdef mutable struct DaggerState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, + } <: AIE.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState + remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() + remote_results::Dict{Int, DTask} = Dict{Int, Any}() +end + +@kwdef struct DaggerNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + KeyType, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + workers::Vector{Int} + keys::Vector{KeyType} = collect(1:length(algorithms)) +end + +function dagger_algorithm(f::Function, iterable; kwargs...) + return DaggerNestedAlgorithm(f, iterable; kwargs...) +end + +# ================================== belief propagation ================================== # + +struct DaggerBeliefPropagationCache{ + V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, Chunks, + } <: AbstractBeliefPropagationCache{V, VD, ED} + underlying_cache::UC + quotient_chunks::Chunks +end diff --git a/src/ITensorNetworksNextParallel/distributed.jl b/src/ITensorNetworksNextParallel/distributed.jl new file mode 100644 index 0000000..01c1344 --- /dev/null +++ b/src/ITensorNetworksNextParallel/distributed.jl @@ -0,0 +1,38 @@ +import AlgorithmsInterface as AI +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +using ..ITensorNetworksNext: AbstractBeliefPropagationCache + +@kwdef mutable struct DistributedState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Future, KeyType, + } <: AIE.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState + remote_results::Dict{KeyType, Future} = Dict{Int, Any}() +end + +@kwdef struct DistributedNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + WorkerPool, + KeyType, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + workers::WorkerPool + keys::Vector{KeyType} = collect(1:length(algorithms)) +end + +function distributed_algorithm(f::Function, iterable; kwargs...) + return DistributedNestedAlgorithm(f, iterable; kwargs...) +end + +# ================================== belief propagation ================================== # + +struct DistributedBeliefPropagationCache{ + V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, + } <: AbstractBeliefPropagationCache{V, VD, ED} + underlying_cache::UC +end From 13f717cbf33f19c8ee902eb1c8dd08dfa16d6b8a Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:01:50 -0500 Subject: [PATCH 55/64] Remove basic `Distributed.jl` implementation. --- .../ITensorNetworksNextDaggerExt.jl | 60 ++++----- .../daggerbeliefpropagation.jl | 115 +++++++++-------- .../ITensorNetworksNextDistributedExt.jl | 84 ------------- .../distributedbeliefpropagation.jl | 116 ------------------ .../ITensorNetworksNextParallel.jl | 26 ++-- src/ITensorNetworksNextParallel/dagger.jl | 6 +- .../distributed.jl | 38 ------ src/abstracttensornetwork.jl | 1 + .../abstractbeliefpropagationcache.jl | 20 ++- 9 files changed, 118 insertions(+), 348 deletions(-) delete mode 100644 ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl delete mode 100644 ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl delete mode 100644 src/ITensorNetworksNextParallel/distributed.jl diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index b5e2c80..95c4105 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -1,37 +1,33 @@ module ITensorNetworksNextDaggerExt -using Dagger -using Dagger.Distributed -using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, - ITensorNetworksNextParallel - import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP +using Dagger +using ITensorNetworksNext.ITensorNetworksNextParallel: + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel - -function ITensorNetworksNextParallel.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) +function ITNNP.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) end -function initialize_dagger_state( - problem::AIE.Problem, - algorithm::AIE.Algorithm; - iterate, - remote_subiterates = Dict{Int, Dagger.Chunk}(), - ) - +function initialize_dagger_state(problem::AIE.Problem, algorithm::AIE.Algorithm; iterate) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) remote_results = Dict{Int, Dagger.DTask}() - return DaggerState(; iterate, remote_subiterates, stopping_criterion_state, remote_results) + return ITNNP.DaggerState(; + iterate, + remote_results, + stopping_criterion_state + ) end function AI.initialize_state( problem::AIE.Problem, - algorithm::DaggerNestedAlgorithm; + algorithm::ITNNP.DaggerNestedAlgorithm; kwargs... ) return initialize_dagger_state(problem, algorithm; kwargs...) @@ -39,43 +35,31 @@ end function AIE.get_subproblem( problem::AIE.Problem, - algorithm::AIE.NestedAlgorithm, - state::DaggerState + algorithm::ITNNP.DaggerNestedAlgorithm, + state::ITNNP.DaggerState ) subproblem = problem subalgorithm = algorithm.algorithms[state.iteration] - iterate = state.iterate - remote_subiterates = state.remote_subiterates + # This might be a Dagger.chun object. + iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) - substate = AI.initialize_state(subproblem, subalgorithm; iterate, remote_subiterates) + substate = Dagger.@spawn AI.initialize_state(subproblem, subalgorithm; iterate) return subproblem, subalgorithm, substate end - function AI.step!( problem::AI.Problem, - algorithm::DaggerNestedAlgorithm, - state::DaggerState; + algorithm::ITNNP.DaggerNestedAlgorithm, + state::ITNNP.DaggerState; kwargs... ) + subproblem, subalgorithm, substate_future = + AIE.get_subproblem(problem, algorithm, state) - subproblem, subalgorithm, subiterate_chunk = AIE.get_subproblem(problem, algorithm, state) - - dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate = subiterate_chunk) - - AIE.set_substate!(problem, algorithm, state, dtask) + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm, substate_future) - return state -end - -function AIE.set_substate!( - ::AIE.Problem, - ::DaggerNestedAlgorithm, - state::DaggerState, - dtask::Dagger.DTask, - ) state.remote_results[state.iteration] = dtask return state diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl index a4acdce..8029fdd 100644 --- a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -1,30 +1,22 @@ +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using Dagger.Distributed - using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph using Dictionaries: Indices using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices -using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerBeliefPropagationCache, - DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, dagger_algorithm, - subcache -using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, - BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, - beliefpropagation, forest_cover_edge_sequence, select_algorithm +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, + BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, + forest_cover_edge_sequence, select_algorithm +using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices using NamedGraphs: NamedGraphs -using NamedGraphs.GraphsExtensions: boundary_edges -function ITensorNetworksNextParallel.subcache(cache::DaggerBeliefPropagationCache, inds) - return subcache(cache.underlying_cache, inds) -end - -function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::AbstractGraph) +function ITNNP.DaggerBeliefPropagationCache(network::AbstractGraph) underlying_cache = BeliefPropagationCache(network) keys = Indices(quotientvertices(underlying_cache)) - workers = Iterators.cycle(Distributed.workers()) + workers = Iterators.cycle(Dagger.Distributed.workers()) worker_dict = similar(keys, Int) for quotient_vertex in keys @@ -39,31 +31,54 @@ function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::Abstr return chunk end - return DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) + return ITNNP.DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) end -DataGraphs.underlying_graph(cache::DaggerBeliefPropagationCache) = underlying_graph(cache.underlying_cache) +function DataGraphs.underlying_graph(cache::ITNNP.DaggerBeliefPropagationCache) + return underlying_graph(cache.underlying_cache) +end -DataGraphs.is_vertex_assigned(bpc::DaggerBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) -DataGraphs.is_edge_assigned(bpc::DaggerBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) +function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return is_vertex_assigned(bpc.underlying_cache, vertex) +end +function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) + return is_edge_assigned(bpc.undelying_cache, edge) +end -DataGraphs.get_vertex_data(bpc::DaggerBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) -DataGraphs.get_edge_data(bpc::DaggerBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.undelying_caches, edge) +function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) + return get_vertex_data(bpc.underlying_cache, vertex) +end +function DataGraphs.get_edge_data( + bpc::ITNNP.DaggerBeliefPropagationCache, + edge::AbstractEdge + ) + return get_edge_data(bpc.undelying_caches, edge) +end -DataGraphs.set_vertex_data!(bpc::DaggerBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) -DataGraphs.set_edge_data!(bpc::DaggerBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) +function DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) + return set_vertex_data!(bpc.underlying_cache, val, vertex) +end +function DataGraphs.set_edge_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, edge) + return set_edge_data!(bpc.underlying_cache, val, edge) +end -NamedGraphs.to_graph_index(::DaggerBeliefPropagationCache, qv::QuotientVertex) = qv -function DataGraphs.get_index_data(cache::DaggerBeliefPropagationCache, qv::QuotientVertex) +NamedGraphs.to_graph_index(::ITNNP.DaggerBeliefPropagationCache, qv::QuotientVertex) = qv +function DataGraphs.get_index_data( + cache::ITNNP.DaggerBeliefPropagationCache, + qv::QuotientVertex + ) return cache.quotient_chunks[qv] end -function ITensorNetworksNext.beliefpropagation_sweep(cache::DaggerBeliefPropagationCache; edges, workers = workers(), kwargs...) - +function ITensorNetworksNext.beliefpropagation_sweep( + cache::ITNNP.DaggerBeliefPropagationCache; + edges, + workers = workers(), + kwargs... + ) keys = collect(quotientvertices(cache)) - return dagger_algorithm(keys; keys, workers) do quotient_vertex - + return ITNNP.dagger_algorithm(keys; keys, workers) do quotient_vertex subcache = fetch(cache[quotient_vertex]).iterate subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges @@ -84,64 +99,62 @@ end function AI.initialize_state( problem::AIE.Problem, - algorithm::BeliefPropagation{<:DaggerNestedAlgorithm}; + algorithm::BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; kwargs... ) - return initialize_dagger_state(problem, algorithm; kwargs...) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) end -function AIE.get_subproblem( - problem::BeliefPropagationProblem, - algorithm::DaggerNestedAlgorithm, - state::DaggerState, +function ITNNP.get_subiterate( + ::BeliefPropagationProblem, + ::BeliefPropagation, # Our parallel region runs a small BP + state::ITNNP.DaggerState ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - quotient_vertex = algorithm.keys[state.iteration] - cache = state.iterate.iterate + quotient_vertex = collect(quotientvertices(cache))[state.iteration] + subiterate = cache[quotient_vertex] - return subproblem, subalgorithm, subiterate + return subiterate end function AIE.set_substate!( ::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, + ::AIE.NestedAlgorithm, state::AIE.State, - substate::DaggerState, + substate::ITNNP.DaggerState ) - dst_cache = state.iterate.iterate state.iterate.maxdiff = 0.0 - current_algorithm = algorithm.algorithms[state.iteration] - - for (i, quotient_vertex) in enumerate(current_algorithm.keys) + for remote_result in substate.remote_results get_maxdiff = dtask -> dtask.iterate.maxdiff - src_maxdiff = fetch(Dagger.@spawn get_maxdiff(substate.remote_results[i])) + src_maxdiff = fetch(Dagger.@spawn get_maxdiff(remote_result)) if src_maxdiff > state.iterate.maxdiff state.iterate.maxdiff = src_maxdiff end end - - transfer_edges! = (dst_chunk, src_chunk, edges) -> begin + function transfer_edges!(dst_chunk, src_chunk, edges) src_subcache = src_chunk.iterate dst_subcache = dst_chunk.iterate for edge in edges dst_subcache[edge] = src_subcache[edge] end + return end transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge src_subcache = dst_cache[src(quotient_edge)] dst_subcache = dst_cache[dst(quotient_edge)] - return Dagger.@spawn transfer_edges!(dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge)) + return Dagger.@spawn transfer_edges!( + dst_subcache, + fetch(src_subcache), + edges(dst_cache, quotient_edge) + ) end wait.(transfer_dtasks) diff --git a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl b/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl deleted file mode 100644 index c19db03..0000000 --- a/ext/ITensorNetworksNextDistributedExt/ITensorNetworksNextDistributedExt.jl +++ /dev/null @@ -1,84 +0,0 @@ -module ITensorNetworksNextDistributedExt - -using Distributed - -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE - -import ITensorNetworksNext.ITensorNetworksNextParallel as Parallel - -function initialize_distributed_state( - problem::AIE.Problem, - algorithm::AIE.Algorithm; - keys, - iterate, - kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - remote_results = Dict{eltype(keys), Distributed.Future}() - - return Parallel.DistributedState(; iterate, stopping_criterion_state, remote_results) -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::Parallel.DistributedNestedAlgorithm; - kwargs... - ) - return initialize_distributed_state(problem, algorithm; keys = algorithm.keys, kwargs...) -end - -function Parallel.DistributedNestedAlgorithm(f::Function, iterable; kwargs...) - return Parallel.DistributedNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end - -function AIE.get_subproblem( - problem::AI.Problem, algorithm::Parallel.DistributedNestedAlgorithm, state::Parallel.DistributedState - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - return subproblem, subalgorithm, state.iterate -end - -function AI.step!( - problem::AI.Problem, - algorithm::Parallel.DistributedNestedAlgorithm, - state::Parallel.DistributedState; - kwargs... - ) - - subproblem, subalgorithm, subiterate = AIE.get_subproblem(problem, algorithm, state) - - # Do whatever should have happened at `step!`, but store the result as a future. - - function solve(subproblem, subalgorithm, iterate) - rv = AI.solve(subproblem, subalgorithm; iterate) - return rv - end - - future = remotecall(solve, algorithm.workers, subproblem, subalgorithm, subiterate) - - AIE.set_substate!(problem, algorithm, state, future) - - return state -end - -function AIE.set_substate!( - ::AIE.Problem, - algorithm::Parallel.DistributedNestedAlgorithm, - state::Parallel.DistributedState, - future::Distributed.Future, - ) - key = algorithm.keys[state.iteration] - - state.remote_results[key] = future - - return state -end - -include("distributedbeliefpropagation.jl") - -end # ITensorNetworksNextDistributedExt diff --git a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl b/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl deleted file mode 100644 index 3848e65..0000000 --- a/ext/ITensorNetworksNextDistributedExt/distributedbeliefpropagation.jl +++ /dev/null @@ -1,116 +0,0 @@ -using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, - is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph -using Graphs: AbstractEdge, AbstractGraph, edges, vertices -using ITensorNetworksNext.ITensorNetworksNextParallel: DistributedBeliefPropagationCache, - DistributedNestedAlgorithm, DistributedState, ITensorNetworksNextParallel, - distributed_algorithm -using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache, - BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext, - beliefpropagation, forest_cover_edge_sequence, select_algorithm, setmessages! -using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientvertices -using NamedGraphs: NamedGraphs - -function ITensorNetworksNextParallel.DistributedBeliefPropagationCache(network::AbstractGraph) - underlying_cache = BeliefPropagationCache(network) - return DistributedBeliefPropagationCache(underlying_cache) -end - -DataGraphs.underlying_graph(cache::DistributedBeliefPropagationCache) = underlying_graph(cache.underlying_cache) - -DataGraphs.is_vertex_assigned(bpc::DistributedBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex) -DataGraphs.is_edge_assigned(bpc::DistributedBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge) - -DataGraphs.get_vertex_data(bpc::DistributedBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex) -DataGraphs.get_edge_data(bpc::DistributedBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.underlying_cache, edge) - -DataGraphs.set_vertex_data!(bpc::DistributedBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex) -DataGraphs.set_edge_data!(bpc::DistributedBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge) - -NamedGraphs.to_graph_index(::DistributedBeliefPropagationCache, qv::QuotientVertex) = qv -function DataGraphs.get_index_data(cache::DistributedBeliefPropagationCache, qv::QuotientVertex) - return ITensorNetworksNextParallel.subcache(cache.underlying_cache, qv) -end -function ITensorNetworksNext.beliefpropagation_sweep( - cache::DistributedBeliefPropagationCache; edges, kwargs... - ) - - keys = collect(quotientvertices(cache)) - - return distributed_algorithm(keys; keys, workers = WorkerPool(workers())) do quotient_vertex - - subcache = cache[quotient_vertex] - subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges - incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) - - alg = select_algorithm( - beliefpropagation, - subcache; - edges = setdiff(subcache_edges, incoming_edges), - maxiter = 1, - kwargs... - ) - - return alg - end -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::BeliefPropagation{<:DistributedNestedAlgorithm}; - kwargs... - ) - - keys = first(algorithm.algorithms).keys - - return initialize_distributed_state(problem, algorithm; keys = keys, kwargs...) -end - -function AIE.get_subproblem( - problem::BeliefPropagationProblem, - algorithm::DistributedNestedAlgorithm, - state::DistributedState - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - - cache = state.iterate.iterate - - quotient_vertex = algorithm.keys[state.iteration] - subiterate = BeliefPropagationState(; iterate = cache[quotient_vertex]) - - return subproblem, subalgorithm, subiterate -end - -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, - state::AIE.State, - substate::DistributedState, - ) - - dst_cache = state.iterate.iterate - - state.iterate.maxdiff = 0.0 - - for quotient_vertex in quotientvertices(dst_cache) - - src_state = fetch(substate.remote_results[quotient_vertex]).iterate - - src_cache = src_state.iterate - src_maxdiff = src_state.maxdiff - - incoming_edges = boundary_edges(dst_cache, vertices(dst_cache, quotient_vertex); dir = :in) - - updated_messages = setdiff(edges(src_cache), incoming_edges) - - setmessages!(dst_cache, src_cache, updated_messages) - - if src_maxdiff > state.iterate.maxdiff - state.iterate.maxdiff = src_maxdiff - end - - end - - return state -end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 1131ca3..8e90670 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,27 +1,19 @@ module ITensorNetworksNextParallel -using Graphs: neighbors, add_vertex!, vertices +using ..ITensorNetworksNext: BeliefPropagationCache +using Graphs: add_vertex!, neighbors, vertices using NamedGraphs.GraphsExtensions: subgraph using NamedGraphs.PartitionedGraphs: QuotientVertex -using ..ITensorNetworksNext: BeliefPropagationCache - -subcache(cache::BeliefPropagationCache, vertex::QuotientVertex) = subcache(cache, vertices(cache, vertex)) -function subcache(cache::BeliefPropagationCache, vertices) - subcache = subgraph(cache, vertices) - for vertex in vertices - for neighbor_vertex in neighbors(cache, vertex) - add_vertex!(subcache, neighbor_vertex) - # Add in necessary messages. - subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] - subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] - end - end +""" + get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) - return subcache -end +For a given `subproblem` and `subalgorithm` of a parent nested algorithm, +derive (from the parent state `state`) the iterate to be used in the associated sub state. +The returned value of this function is then pass to a remote call of `initialize_state`. +""" +get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate -include("distributed.jl") include("dagger.jl") end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl index eef37bc..8f79de5 100644 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -1,14 +1,14 @@ -import AlgorithmsInterface as AI import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI using ITensorNetworksNext: AbstractBeliefPropagationCache @kwdef mutable struct DaggerState{ Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, } <: AIE.State - iterate::Iterate + iterate::Iterate # DaggerBeliefPropagationCache iteration::Int = 0 stopping_criterion_state::StoppingCriterionState - remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() + # remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() remote_results::Dict{Int, DTask} = Dict{Int, Any}() end diff --git a/src/ITensorNetworksNextParallel/distributed.jl b/src/ITensorNetworksNextParallel/distributed.jl deleted file mode 100644 index 01c1344..0000000 --- a/src/ITensorNetworksNextParallel/distributed.jl +++ /dev/null @@ -1,38 +0,0 @@ -import AlgorithmsInterface as AI -import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE - -using ..ITensorNetworksNext: AbstractBeliefPropagationCache - -@kwdef mutable struct DistributedState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Future, KeyType, - } <: AIE.State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState - remote_results::Dict{KeyType, Future} = Dict{Int, Any}() -end - -@kwdef struct DistributedNestedAlgorithm{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - WorkerPool, - KeyType, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) - workers::WorkerPool - keys::Vector{KeyType} = collect(1:length(algorithms)) -end - -function distributed_algorithm(f::Function, iterable; kwargs...) - return DistributedNestedAlgorithm(f, iterable; kwargs...) -end - -# ================================== belief propagation ================================== # - -struct DistributedBeliefPropagationCache{ - V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, - } <: AbstractBeliefPropagationCache{V, VD, ED} - underlying_cache::UC -end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 121073d..802e8b8 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -34,6 +34,7 @@ Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) + Base.keys(tn::AbstractTensorNetwork) = vertices(tn) # TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl index 33f185b..a810c9a 100644 --- a/src/beliefpropagation/abstractbeliefpropagationcache.jl +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -1,7 +1,7 @@ using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data using Graphs: AbstractEdge, AbstractGraph using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent, QuotientVertex messages(bp_cache::AbstractGraph) = edge_data(bp_cache) messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] @@ -141,3 +141,21 @@ function free_energy(bp_cache::AbstractBeliefPropagationCache) return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) end partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) + +function subcache(cache::AbstractBeliefPropagationCache, vertex::QuotientVertex) + return subcache(cache, vertices(cache, vertex)) +end +function subcache(cache::AbstractBeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end From be1bfb1b095f4d0c168a4c71ee469a0bc15e32ee Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:06:03 -0500 Subject: [PATCH 56/64] Fix imports. --- .../ITensorNetworksNextParallel.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 8e90670..7117641 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,9 +1,6 @@ module ITensorNetworksNextParallel -using ..ITensorNetworksNext: BeliefPropagationCache -using Graphs: add_vertex!, neighbors, vertices -using NamedGraphs.GraphsExtensions: subgraph -using NamedGraphs.PartitionedGraphs: QuotientVertex +import AlgorithmsInterface as AI """ get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) From e715e15fa1b7fdd64808060cf6d7a49fa2f04222 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 26 Feb 2026 15:06:41 -0500 Subject: [PATCH 57/64] Fix Project.toml --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 435e87a..33a0a3f 100644 --- a/Project.toml +++ b/Project.toml @@ -30,12 +30,10 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" -ITensorNetworksNextDistributedExt = "Distributed" ITensorNetworksNextDaggerExt = "Dagger" [compat] From c1c73a3f2f96a46c9f6f14fdc547889d14bdeaab Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 2 Mar 2026 14:07:59 -0500 Subject: [PATCH 58/64] Simplify parallel code. --- .../ITensorNetworksNextDaggerExt.jl | 41 +++---- .../daggerbeliefpropagation.jl | 105 ++++++++++++------ .../ITensorNetworksNextParallel.jl | 23 ++++ src/ITensorNetworksNextParallel/dagger.jl | 27 +++-- .../beliefpropagationproblem.jl | 2 +- 5 files changed, 127 insertions(+), 71 deletions(-) diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index 95c4105..33375c7 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -6,17 +6,24 @@ import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel +using Dictionaries: set! -function ITNNP.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...) - return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...) +function ITNNP.DaggerNestedAlgorithm(f, iterable; kwargs...) + return DaggerNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) end -function initialize_dagger_state(problem::AIE.Problem, algorithm::AIE.Algorithm; iterate) +function ITNNP.dagger_algorithm(f::Base.Callable, iterable; kwargs...) + return DaggerNestedAlgorithm(f, iterable; kwargs...) +end + +function ITNNP.initialize_dagger_state( + problem::AIE.Problem, algorithm::AIE.Algorithm; iterate + ) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion ) - remote_results = Dict{Int, Dagger.DTask}() + remote_results = Dictionary{Int, Dagger.DTask}() return ITNNP.DaggerState(; iterate, @@ -30,37 +37,23 @@ function AI.initialize_state( algorithm::ITNNP.DaggerNestedAlgorithm; kwargs... ) - return initialize_dagger_state(problem, algorithm; kwargs...) + return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) end -function AIE.get_subproblem( +function AI.step!( problem::AIE.Problem, algorithm::ITNNP.DaggerNestedAlgorithm, - state::ITNNP.DaggerState + state::ITNNP.DaggerState; + kwargs... ) subproblem = problem subalgorithm = algorithm.algorithms[state.iteration] - # This might be a Dagger.chun object. iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) - substate = Dagger.@spawn AI.initialize_state(subproblem, subalgorithm; iterate) - - return subproblem, subalgorithm, substate -end - -function AI.step!( - problem::AI.Problem, - algorithm::ITNNP.DaggerNestedAlgorithm, - state::ITNNP.DaggerState; - kwargs... - ) - subproblem, subalgorithm, substate_future = - AIE.get_subproblem(problem, algorithm, state) - - dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm, substate_future) + dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate) - state.remote_results[state.iteration] = dtask + set!(state.remote_results, state.iteration, dtask) return state end diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl index 8029fdd..547b6fb 100644 --- a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl +++ b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl @@ -1,33 +1,49 @@ import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned, +using DataGraphs: DataGraphs, edge_data, get_edge_data, get_vertex_data, is_edge_assigned, is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph -using Dictionaries: Indices +using Dictionaries: Dictionary, Indices, getindices using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, - forest_cover_edge_sequence, select_algorithm + forest_cover_edge_sequence, select_algorithm, subcache using NamedGraphs.GraphsExtensions: boundary_edges using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices using NamedGraphs: NamedGraphs -function ITNNP.DaggerBeliefPropagationCache(network::AbstractGraph) +const DaggerBeliefPropagation = BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; + +function ITNNP.DaggerBeliefPropagationCache( + network::AbstractGraph; + workers = nothing, + scopes = nothing + ) underlying_cache = BeliefPropagationCache(network) keys = Indices(quotientvertices(underlying_cache)) - workers = Iterators.cycle(Dagger.Distributed.workers()) - worker_dict = similar(keys, Int) + if isnothing(scopes) + workers = isnothing(workers) ? Dagger.Distributed.workers() : workers - for quotient_vertex in keys - worker, workers = Iterators.peel(workers) - worker_dict[quotient_vertex] = worker + sorted_workers = Iterators.take(Iterators.cycle(workers), length(keys)) + + scopes = map(Dagger.ProcessScope, collect(sorted_workers)) + else + if length(keys) != length(scopes) + throw( + ArgumentError( + "Number of provided scopes must match the number of vertex partitions of underlying graph" + ) + ) + end end + scope_dict = Dictionary(keys, scopes) + quotient_chunks = map(keys) do quotient_vertex - worker = worker_dict[quotient_vertex] + scope = scope_dict[quotient_vertex] iterate = subcache(underlying_cache, quotient_vertex) - chunk = Dagger.@mutable worker = worker BeliefPropagationState(; iterate) + chunk = Dagger.@mutable scope = scope BeliefPropagationState(; iterate) return chunk end @@ -42,7 +58,7 @@ function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, return is_vertex_assigned(bpc.underlying_cache, vertex) end function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) - return is_edge_assigned(bpc.undelying_cache, edge) + return is_edge_assigned(bpc.underlying_cache, edge) end function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) @@ -52,7 +68,7 @@ function DataGraphs.get_edge_data( bpc::ITNNP.DaggerBeliefPropagationCache, edge::AbstractEdge ) - return get_edge_data(bpc.undelying_caches, edge) + return get_edge_data(bpc.underlying_cache, edge) end function DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) @@ -73,13 +89,11 @@ end function ITensorNetworksNext.beliefpropagation_sweep( cache::ITNNP.DaggerBeliefPropagationCache; edges, - workers = workers(), kwargs... ) - keys = collect(quotientvertices(cache)) - - return ITNNP.dagger_algorithm(keys; keys, workers) do quotient_vertex - subcache = fetch(cache[quotient_vertex]).iterate + return ITNNP.dagger_algorithm(quotientvertices(cache)) do quotient_vertex + substate = fetch(cache[quotient_vertex]) + subcache = substate.iterate subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) @@ -120,8 +134,8 @@ function ITNNP.get_subiterate( end function AIE.set_substate!( - ::BeliefPropagationProblem, - ::AIE.NestedAlgorithm, + problem::BeliefPropagationProblem, + algorithm::AIE.NestedAlgorithm, state::AIE.State, substate::ITNNP.DaggerState ) @@ -129,35 +143,52 @@ function AIE.set_substate!( state.iterate.maxdiff = 0.0 - for remote_result in substate.remote_results - get_maxdiff = dtask -> dtask.iterate.maxdiff - src_maxdiff = fetch(Dagger.@spawn get_maxdiff(remote_result)) - - if src_maxdiff > state.iterate.maxdiff - state.iterate.maxdiff = src_maxdiff - end + maxdiff_dtasks = map(substate.remote_results) do remote_result + return Dagger.spawn(dtask -> dtask.iterate.maxdiff, remote_result) end - function transfer_edges!(dst_chunk, src_chunk, edges) - src_subcache = src_chunk.iterate - dst_subcache = dst_chunk.iterate - for edge in edges - dst_subcache[edge] = src_subcache[edge] - end - return + maxdiff = maximum(fetch, maxdiff_dtasks) + + if maxdiff > state.iterate.maxdiff + state.iterate.maxdiff = maxdiff end transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge src_subcache = dst_cache[src(quotient_edge)] dst_subcache = dst_cache[dst(quotient_edge)] - return Dagger.@spawn transfer_edges!( + + src_subcache = fetch(src_subcache) + + return Dagger.spawn( dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge) - ) + ) do dst, src, edges + src_subcache = src.iterate + dst_subcache = dst.iterate + for edge in edges + dst_subcache[edge] = src_subcache[edge] + end + end end - wait.(transfer_dtasks) + foreach(wait, transfer_dtasks) + + return state +end + +function ITNNP.finalize_state!( + ::BeliefPropagationProblem, + ::BeliefPropagation, + state::ITNNP.DaggerState + ) + dst_cache = state.iterate.iterate + + for quotient_vertex in quotientvertices(dst_cache) + substate = fetch(dst_cache[quotient_vertex]) + subcache = substate.iterate + edge_data(dst_cache) .= edge_data(subcache) + end return state end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 7117641..5d56ea4 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,6 +1,10 @@ module ITensorNetworksNextParallel import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE + +abstract type ParallelAlgorithm{Child} <: AIE.NestedAlgorithm{Child} end +const IterativeParallelAlgorithm{Child <: ParallelAlgorithm} = AIE.NestedAlgorithm{Child} """ get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) @@ -11,6 +15,25 @@ The returned value of this function is then pass to a remote call of `initialize """ get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate +finalize_state!(::AI.Problem, ::AI.Algorithm, state::AI.State) = state + +function AI.is_finished!( + problem::AI.Problem, + algorithm::IterativeParallelAlgorithm, + state::AI.State + ) + c = algorithm.stopping_criterion + st = state.stopping_criterion_state + + isfinished = AI.is_finished!(problem, algorithm, state, c, st) + + if isfinished + finalize_state!(problem, algorithm, state) + end + + return isfinished +end + include("dagger.jl") end # ITensorNetworksNextParallel diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl index 8f79de5..4ed21ad 100644 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -1,31 +1,40 @@ import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI +using Dictionaries: Dictionary using ITensorNetworksNext: AbstractBeliefPropagationCache @kwdef mutable struct DaggerState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, Chunk, DTask, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, } <: AIE.State iterate::Iterate # DaggerBeliefPropagationCache iteration::Int = 0 stopping_criterion_state::StoppingCriterionState - # remote_subiterates::Dict{Int, Chunk} = Dict{Int, Any}() - remote_results::Dict{Int, DTask} = Dict{Int, Any}() + remote_results::Dictionary{Int, DTask} = Dict{Int, Any}() +end + +function initialize_dagger_state(problem, algorithm; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." + ) + ) end @kwdef struct DaggerNestedAlgorithm{ ChildAlgorithm <: AIE.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - KeyType, - } <: AIE.NestedAlgorithm + } <: ParallelAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) - workers::Vector{Int} - keys::Vector{KeyType} = collect(1:length(algorithms)) end -function dagger_algorithm(f::Function, iterable; kwargs...) - return DaggerNestedAlgorithm(f, iterable; kwargs...) +function dagger_algorithm(f, iterable; kwargs...) + throw( + ErrorException( + "Package Dagger not loaded. Please install and load the Dagger package." + ) + ) end # ================================== belief propagation ================================== # diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1a62792..af5aa3f 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -60,7 +60,7 @@ end ChildAlgorithm <: AIE.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm + } <: AIE.NestedAlgorithm{ChildAlgorithm} algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end From 7110108bfb0c8c99936d46093f02e39bbe8bf4cf Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Thu, 21 May 2026 13:57:20 -0400 Subject: [PATCH 59/64] for merge --- .../ITensorNetworksNextDaggerExt.jl | 77 +++++++++---------- src/ITensorNetworksNextParallel/dagger.jl | 30 ++++---- src/beliefpropagation/beliefpropagation.jl | 23 +++++- 3 files changed, 70 insertions(+), 60 deletions(-) diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index 33375c7..c24168b 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -4,60 +4,53 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using ITensorNetworksNext.ITensorNetworksNextParallel: - DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel using Dictionaries: set! +using ITensorNetworksNext.ITensorNetworksNextParallel: + DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, step_dagger! -function ITNNP.DaggerNestedAlgorithm(f, iterable; kwargs...) - return DaggerNestedAlgorithm(; algorithms = map(f, iterable), kwargs...) -end - -function ITNNP.dagger_algorithm(f::Base.Callable, iterable; kwargs...) - return DaggerNestedAlgorithm(f, iterable; kwargs...) -end - -function ITNNP.initialize_dagger_state( - problem::AIE.Problem, algorithm::AIE.Algorithm; iterate - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - +function AI.initialize_state(problem::AI.Problem, algorithm::AI.Algorithm; kwargs...) + substate = AI.initialize_state(problem, algorithm; kwargs...) remote_results = Dictionary{Int, Dagger.DTask}() - - return ITNNP.DaggerState(; - iterate, - remote_results, - stopping_criterion_state - ) -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::ITNNP.DaggerNestedAlgorithm; - kwargs... + return DaggerState(; + substate, + substate.iteration, + substate.stopping_criterion_state, + remote_results ) - return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) end function AI.step!( - problem::AIE.Problem, - algorithm::ITNNP.DaggerNestedAlgorithm, - state::ITNNP.DaggerState; - kwargs... - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] + problem::AI.Problem, + algorithm::Dagger, + state::DaggerState + ) where {Algorithm} + # Forward the "external" stopping info to the internal states. + state.substate.iteration = state.iteration + state.substate.stopping_criterion_state = state.stopping_criterion_state - iterate = ITNNP.get_subiterate(subproblem, subalgorithm, state) - - dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate) + dtask = Dagger.@spawn step_dagger!(problem, algorithm.parent, substate) set!(state.remote_results, state.iteration, dtask) - return state end -include("daggerbeliefpropagation.jl") +function ITNNP.step_dagger!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State + ) + AI.step!(problem, algorithm, state) + + return return state +end + +function AI.finalize_state!(::AI.Problem, ::AI.Algorithm, state::DaggerState) + for dtask in collect(state.remote_results) + wait(dtask) + end + return state.substate.iterate +end + +# include("daggerbeliefpropagation.jl") end # ITensorNetworksNextDaggerExt diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl index 4ed21ad..720e3f0 100644 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ b/src/ITensorNetworksNextParallel/dagger.jl @@ -1,18 +1,22 @@ import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI -using Dictionaries: Dictionary using ITensorNetworksNext: AbstractBeliefPropagationCache +abstract type Driver end + +struct SERIAL <: Driver end +struct DAGGER <: Driver end + @kwdef mutable struct DaggerState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, - } <: AIE.State - iterate::Iterate # DaggerBeliefPropagationCache + Substate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, + } <: AIE.NestedState + substate::Substate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState - remote_results::Dictionary{Int, DTask} = Dict{Int, Any}() + remote_results::Dict{Int, DTask} = Dict{Int, Any}() end -function initialize_dagger_state(problem, algorithm; kwargs...) +function initialize_dagger_state(_problem, _algorithm; _kwargs...) throw( ErrorException( "Package Dagger not loaded. Please install and load the Dagger package." @@ -20,16 +24,14 @@ function initialize_dagger_state(problem, algorithm; kwargs...) ) end -@kwdef struct DaggerNestedAlgorithm{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: ParallelAlgorithm{ChildAlgorithm} - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +@kwdef struct InParallel{Algorithm <: AI.Algorithm} + parent::Algorithm + end -function dagger_algorithm(f, iterable; kwargs...) + + +function step_dagger!(_problem, _algorithm, _state) throw( ErrorException( "Package Dagger not loaded. Please install and load the Dagger package." diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 876e5cc..8bf382b 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -45,16 +45,19 @@ function beliefpropagation( factors, messages; edges = default_beliefpropagation_edges(factors), stopping_criterion = nothing, - message_update_algorithm = nothing + message_update_algorithm = nothing, + driver = nothing ) problem = BeliefPropagationProblem(factors) message_update_algorithm = AIE.select_algorithm( message_update!, message_update_algorithm ) - subalgorithm = BeliefPropagationSweepAlgorithm(; - message_update_algorithm, - stopping_criterion = AI.StopAfterIteration(length(edges)) + subalgorithm = Daggered( + BeliefPropagationSweepAlgorithm(; + message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(edges)) + ) ) stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) algorithm = BeliefPropagationAlgorithm(; edges, subalgorithm, stopping_criterion) @@ -182,6 +185,18 @@ function AI.step!( return state end +function AI.step!( + problem::BeliefPropagationSweepProblem, + algorithm::Dagger{BeliefPropagationSweepAlgorithm}, + state::BeliefPropagationSweepState + ) + edge = problem.edges[state.iteration] + message_update!( + algorithm.message_update_algorithm, state.iterate, problem.factors, edge + ) + return state +end + # === Layer 3: single-edge message update strategy === # Strategy interface: a `MessageUpdateAlgorithm` defines how a single From c6d5a301dbdf6a943220f6083dd7b7cce7a17036 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Jun 2026 17:16:52 -0400 Subject: [PATCH 60/64] Delete legacy files. --- .../daggerbeliefpropagation.jl | 194 ----------- src/ITensorNetworksNextParallel/dagger.jl | 49 --- .../beliefpropagationcache.jl | 164 ---------- .../beliefpropagationproblem.jl | 304 ------------------ 4 files changed, 711 deletions(-) delete mode 100644 ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl delete mode 100644 src/ITensorNetworksNextParallel/dagger.jl delete mode 100644 src/beliefpropagation/beliefpropagationcache.jl delete mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl b/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl deleted file mode 100644 index 547b6fb..0000000 --- a/ext/ITensorNetworksNextDaggerExt/daggerbeliefpropagation.jl +++ /dev/null @@ -1,194 +0,0 @@ -import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP -using Dagger -using DataGraphs: DataGraphs, edge_data, get_edge_data, get_vertex_data, is_edge_assigned, - is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph -using Dictionaries: Dictionary, Indices, getindices -using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices -using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagation, BeliefPropagationCache, - BeliefPropagationProblem, BeliefPropagationState, beliefpropagation, - forest_cover_edge_sequence, select_algorithm, subcache -using NamedGraphs.GraphsExtensions: boundary_edges -using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices -using NamedGraphs: NamedGraphs - -const DaggerBeliefPropagation = BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; - -function ITNNP.DaggerBeliefPropagationCache( - network::AbstractGraph; - workers = nothing, - scopes = nothing - ) - underlying_cache = BeliefPropagationCache(network) - - keys = Indices(quotientvertices(underlying_cache)) - - if isnothing(scopes) - workers = isnothing(workers) ? Dagger.Distributed.workers() : workers - - sorted_workers = Iterators.take(Iterators.cycle(workers), length(keys)) - - scopes = map(Dagger.ProcessScope, collect(sorted_workers)) - else - if length(keys) != length(scopes) - throw( - ArgumentError( - "Number of provided scopes must match the number of vertex partitions of underlying graph" - ) - ) - end - end - - scope_dict = Dictionary(keys, scopes) - - quotient_chunks = map(keys) do quotient_vertex - scope = scope_dict[quotient_vertex] - iterate = subcache(underlying_cache, quotient_vertex) - chunk = Dagger.@mutable scope = scope BeliefPropagationState(; iterate) - return chunk - end - - return ITNNP.DaggerBeliefPropagationCache(underlying_cache, quotient_chunks) -end - -function DataGraphs.underlying_graph(cache::ITNNP.DaggerBeliefPropagationCache) - return underlying_graph(cache.underlying_cache) -end - -function DataGraphs.is_vertex_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) - return is_vertex_assigned(bpc.underlying_cache, vertex) -end -function DataGraphs.is_edge_assigned(bpc::ITNNP.DaggerBeliefPropagationCache, edge) - return is_edge_assigned(bpc.underlying_cache, edge) -end - -function DataGraphs.get_vertex_data(bpc::ITNNP.DaggerBeliefPropagationCache, vertex) - return get_vertex_data(bpc.underlying_cache, vertex) -end -function DataGraphs.get_edge_data( - bpc::ITNNP.DaggerBeliefPropagationCache, - edge::AbstractEdge - ) - return get_edge_data(bpc.underlying_cache, edge) -end - -function DataGraphs.set_vertex_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, vertex) - return set_vertex_data!(bpc.underlying_cache, val, vertex) -end -function DataGraphs.set_edge_data!(bpc::ITNNP.DaggerBeliefPropagationCache, val, edge) - return set_edge_data!(bpc.underlying_cache, val, edge) -end - -NamedGraphs.to_graph_index(::ITNNP.DaggerBeliefPropagationCache, qv::QuotientVertex) = qv -function DataGraphs.get_index_data( - cache::ITNNP.DaggerBeliefPropagationCache, - qv::QuotientVertex - ) - return cache.quotient_chunks[qv] -end - -function ITensorNetworksNext.beliefpropagation_sweep( - cache::ITNNP.DaggerBeliefPropagationCache; - edges, - kwargs... - ) - return ITNNP.dagger_algorithm(quotientvertices(cache)) do quotient_vertex - substate = fetch(cache[quotient_vertex]) - subcache = substate.iterate - - subcache_edges = forest_cover_edge_sequence(subcache) ∩ edges - incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in) - - alg = select_algorithm( - beliefpropagation, - subcache; - # Don't update the incoming messages - edges = setdiff(subcache_edges, incoming_edges), - maxiter = 1, - kwargs... - ) - - return alg - end -end - -function AI.initialize_state( - problem::AIE.Problem, - algorithm::BeliefPropagation{<:ITNNP.DaggerNestedAlgorithm}; - kwargs... - ) - return ITNNP.initialize_dagger_state(problem, algorithm; kwargs...) -end - -function ITNNP.get_subiterate( - ::BeliefPropagationProblem, - ::BeliefPropagation, # Our parallel region runs a small BP - state::ITNNP.DaggerState - ) - cache = state.iterate.iterate - - quotient_vertex = collect(quotientvertices(cache))[state.iteration] - - subiterate = cache[quotient_vertex] - - return subiterate -end - -function AIE.set_substate!( - problem::BeliefPropagationProblem, - algorithm::AIE.NestedAlgorithm, - state::AIE.State, - substate::ITNNP.DaggerState - ) - dst_cache = state.iterate.iterate - - state.iterate.maxdiff = 0.0 - - maxdiff_dtasks = map(substate.remote_results) do remote_result - return Dagger.spawn(dtask -> dtask.iterate.maxdiff, remote_result) - end - - maxdiff = maximum(fetch, maxdiff_dtasks) - - if maxdiff > state.iterate.maxdiff - state.iterate.maxdiff = maxdiff - end - - transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge - src_subcache = dst_cache[src(quotient_edge)] - dst_subcache = dst_cache[dst(quotient_edge)] - - src_subcache = fetch(src_subcache) - - return Dagger.spawn( - dst_subcache, - fetch(src_subcache), - edges(dst_cache, quotient_edge) - ) do dst, src, edges - src_subcache = src.iterate - dst_subcache = dst.iterate - for edge in edges - dst_subcache[edge] = src_subcache[edge] - end - end - end - - foreach(wait, transfer_dtasks) - - return state -end - -function ITNNP.finalize_state!( - ::BeliefPropagationProblem, - ::BeliefPropagation, - state::ITNNP.DaggerState - ) - dst_cache = state.iterate.iterate - - for quotient_vertex in quotientvertices(dst_cache) - substate = fetch(dst_cache[quotient_vertex]) - subcache = substate.iterate - edge_data(dst_cache) .= edge_data(subcache) - end - - return state -end diff --git a/src/ITensorNetworksNextParallel/dagger.jl b/src/ITensorNetworksNextParallel/dagger.jl deleted file mode 100644 index 720e3f0..0000000 --- a/src/ITensorNetworksNextParallel/dagger.jl +++ /dev/null @@ -1,49 +0,0 @@ -import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI -using ITensorNetworksNext: AbstractBeliefPropagationCache - -abstract type Driver end - -struct SERIAL <: Driver end -struct DAGGER <: Driver end - -@kwdef mutable struct DaggerState{ - Substate, StoppingCriterionState <: AI.StoppingCriterionState, DTask, - } <: AIE.NestedState - substate::Substate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState - remote_results::Dict{Int, DTask} = Dict{Int, Any}() -end - -function initialize_dagger_state(_problem, _algorithm; _kwargs...) - throw( - ErrorException( - "Package Dagger not loaded. Please install and load the Dagger package." - ) - ) -end - -@kwdef struct InParallel{Algorithm <: AI.Algorithm} - parent::Algorithm - -end - - - -function step_dagger!(_problem, _algorithm, _state) - throw( - ErrorException( - "Package Dagger not loaded. Please install and load the Dagger package." - ) - ) -end - -# ================================== belief propagation ================================== # - -struct DaggerBeliefPropagationCache{ - V, VD, ED, UC <: AbstractBeliefPropagationCache{V, VD, ED}, Chunks, - } <: AbstractBeliefPropagationCache{V, VD, ED} - underlying_cache::UC - quotient_chunks::Chunks -end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl deleted file mode 100644 index 5d1a31c..0000000 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ /dev/null @@ -1,164 +0,0 @@ -using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, edge_data_type, - set_vertex_data!, underlying_graph, underlying_graph_type, vertex_data, vertex_data_type -using Dictionaries: Dictionary, delete!, getindices, set! -using Graphs: AbstractGraph, connected_components, is_directed, is_tree -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, lazy, parenttype -using NamedGraphs.GraphsExtensions: - default_root_vertex, forest_cover, post_order_dfs_edges, undirected_graph, vertextype -using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, quotient_graph -using NamedGraphs: Vertices, convert_vertextype, parent_graph_indices - -struct BeliefPropagationCache{V, VD, ED, E, G <: AbstractGraph{V}} <: - AbstractBeliefPropagationCache{V, VD, ED} - underlying_graph::G # we only use this for the edges. - factors::Dictionary{V, VD} - messages::Dictionary{E, ED} - function BeliefPropagationCache( - graph::AbstractGraph, - factors::Dictionary, - messages::Dictionary - ) - # Ensure the graph is directed, if not make it directed. - digraph = is_directed(graph) ? graph : directed_graph(graph) - - V = keytype(factors) - VD = eltype(factors) - - E = keytype(messages) - ED = eltype(messages) - - bpc = new{V, VD, ED, E, typeof(digraph)}(digraph, factors, messages) - - for edge in edges(bpc) - get!(() -> default_message(bpc, edge), messages, edge) - end - return bpc - end -end - -DataGraphs.underlying_graph(bpc::BeliefPropagationCache) = bpc.underlying_graph - -function DataGraphs.is_vertex_assigned(bpc::BeliefPropagationCache, vertex) - return haskey(bpc.factors, vertex) -end -DataGraphs.is_edge_assigned(bpc::BeliefPropagationCache, edge) = haskey(bpc.messages, edge) - -DataGraphs.get_vertex_data(bpc::BeliefPropagationCache, vertex) = bpc.factors[vertex] -function DataGraphs.get_edge_data(bpc::BeliefPropagationCache, edge::AbstractEdge) - return bpc.messages[edge] -end - -function DataGraphs.set_vertex_data!(bpc::BeliefPropagationCache, val, vertex) - return set!(bpc.factors, vertex, val) -end -function DataGraphs.set_edge_data!(bpc::BeliefPropagationCache, val, edge) - return set!(bpc.messages, edge, val) -end - -# These two methods assume `network` behaves llike a tensor network -# (could be e.g. a QuotientView) otherwise how would one know what the factors should be. -function BeliefPropagationCache(network::AbstractGraph) - graph = underlying_graph(network) - return BeliefPropagationCache(graph, copy(vertex_data(network))) -end -function BeliefPropagationCache(MT::Type, network::AbstractGraph) - graph = underlying_graph(network) - return BeliefPropagationCache(MT, graph, copy(vertex_data(network))) -end - -function BeliefPropagationCache(graph::AbstractGraph, factors::Dictionary) - MT = vertex_data_type(typeof(graph)) - return BeliefPropagationCache(MT, graph, factors) -end -function BeliefPropagationCache(MT::Type, graph::AbstractGraph, factors::Dictionary) - messages = Dictionary{edgetype(graph), MT}() - return BeliefPropagationCache(graph, factors, messages) -end - -function Base.copy(bp_cache::BeliefPropagationCache) - return BeliefPropagationCache( - copy(bp_cache.underlying_graph), - copy(bp_cache.factors), - copy(bp_cache.messages) - ) -end - -# TODO: This needs to go in GraphsExtensions -function forest_cover_edge_sequence(gi::AbstractGraph; root_vertex = default_root_vertex) - # All we care about are the edges so the type of the graph doesnt matter - g = NamedGraph(vertices(gi)) - add_edges!(g, edges(gi)) - forests = forest_cover(g) - rv = edgetype(g)[] - for forest in forests - trees = [forest[Vertices(vs)] for vs in connected_components(forest)] - for tree in trees - tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) - push!(rv, vcat(tree_edges, reverse(reverse.(tree_edges)))...) - end - end - return rv -end - -function induced_subgraph_bpcache(graph, subvertices) - underlying_subgraph, vlist = - Graphs.induced_subgraph(underlying_graph(graph), subvertices) - - assigned = v -> isassigned(graph, v) - - assigned_subvertices = Iterators.filter(assigned, subvertices) - assigned_subedges = Iterators.filter(assigned, edges(underlying_subgraph)) - - factors = getindices(vertex_data(graph), Indices(assigned_subvertices)) - messages = getindices(edge_data(graph), Indices(assigned_subedges)) - - subgraph = BeliefPropagationCache(underlying_subgraph, factors, messages) - - return subgraph, vlist -end - -function NamedGraphs.induced_subgraph_from_vertices( - graph::BeliefPropagationCache, - subvertices - ) - return induced_subgraph_bpcache(graph, subvertices) -end - -## PartitionedGraphs - -# Take a QuotientView of the underlying graph. -function PartitionedGraphs.quotientview(bpc::BeliefPropagationCache) - graph = underlying_graph(bpc) - - quotient_view = QuotientView(graph) - - factors = map(v -> bpc[QuotientVertex(v)], Indices(vertices(quotient_view))) - messages = map(e -> bpc[QuotientEdge(e)], Indices(edges(quotient_view))) - - return BeliefPropagationCache(quotient_view, factors, messages) -end - -function default_message(bpc::BeliefPropagationCache, edge) - return default_message(message_type(bpc), bpc[src(edge)], bpc[dst(edge)]) -end -function default_message(T::Type, src, dst) - array = ones(Tuple(inds(src) ∩ inds(dst))) - return convert(T, array) -end -function default_message(T::Type{<:LazyNamedDimsArray}, src, dst) - message = default_message(parenttype(T), src, dst) - return convert(T, lazy(message)) -end - -NamedGraphs.to_graph_index(::BeliefPropagationCache, vertex::QuotientVertex) = vertex -# When getting data according the quotient vertices, take a lazy contraction. -function DataGraphs.get_index_data(tn::BeliefPropagationCache, vertex::QuotientVertex) - data = collect(map(v -> tn[v], vertices(tn, vertex))) - return mapreduce(lazy, *, data) -end -function DataGraphs.is_graph_index_assigned( - tn::BeliefPropagationCache, - vertex::QuotientVertex - ) - return isassigned(tn, Vertices(vertices(tn, vertex))) -end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl deleted file mode 100644 index af5aa3f..0000000 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ /dev/null @@ -1,304 +0,0 @@ -import .AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI -using DataGraphs: edge_data -using Graphs: AbstractEdge, edges, has_edge, vertices -using LinearAlgebra: norm, normalize -using NamedDimsArrays: AbstractNamedDimsArray -using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph -using NamedGraphs.PartitionedGraphs: quotientvertices - -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 = 0.0 -end - -@kwdef mutable struct StopWhenConvergedState <: AI.StoppingCriterionState - delta::Float64 = Inf -end - -function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged) - return StopWhenConvergedState() -end - -function AI.initialize_state!( - ::AIE.Problem, - ::AIE.Algorithm, - ::StopWhenConverged, - st::StopWhenConvergedState - ) - st.delta = Inf - return st -end - -function AI.is_finished!( - ::AIE.Problem, - ::AIE.Algorithm, - state::AIE.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - - # maxdiff = 0.0 initially, so skip this the first time. - if state.iteration > 0 - st.delta = state.iterate.maxdiff - end - - return st.delta < c.tol -end - -struct BeliefPropagationProblem{Network} <: AIE.Problem - network::Network -end - -@kwdef mutable struct BeliefPropagationState{Iterate, Diffs} <: - AIE.NonIterativeAlgorithmState - iterate::Iterate - diffs::Diffs = similar(edge_data(iterate), Float64) - maxdiff::Float64 = 0.0 -end - -@kwdef struct BeliefPropagation{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm{ChildAlgorithm} - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagation(f::Function, niterations::Int; kwargs...) - return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) -end - -abstract type AbstractMessageUpdate <: AIE.NonIterativeAlgorithm end - -struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} <: AbstractMessageUpdate - edge::E - kwargs::Kwargs -end - -function SimpleMessageUpdate( - edge; - normalize = true, - contraction_alg = "exact", - compute_diff = false, - kwargs... - ) - return SimpleMessageUpdate( - edge, - (; normalize, contraction_alg, compute_diff, kwargs...) - ) -end - -function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) - if name in (:edge, :kwargs) - return getfield(alg, name) - else - return getproperty(getfield(alg, :kwargs), name) - end -end - -struct BeliefPropagationSweep{ - ChildAlgorithm <: AIE.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::AI.StopAfterIteration - function BeliefPropagationSweep(; algorithms) - stopping_criterion = AI.StopAfterIteration(length(algorithms)) - return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) - end -end - -function BeliefPropagationSweep(f::Function, edges) - return BeliefPropagationSweep(; algorithms = f.(edges)) -end - -function AI.initialize_state( - ::BeliefPropagationProblem, ::AIE.NonIterativeAlgorithm; iterate, kwargs... - ) - diffs = iterate.diffs - maxdiff = iterate.maxdiff - - return BeliefPropagationState(; iterate = iterate.iterate, diffs, maxdiff, kwargs...) -end - -# This gets called at the start of every sweep. -function AI.initialize_state!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - iteration_state::AIE.State - ) - iteration_state.iterate.maxdiff = 0.0 - return iteration_state -end - -function AIE.set_substate!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - sweep_state::AIE.DefaultState, - noniterative_substate::BeliefPropagationState - ) - sweep_state.iterate = noniterative_substate - - return sweep_state -end - -struct MessageUpdateProblem{Factor, Messages} <: AIE.Problem - factor::Factor - messages::Messages -end - -function AI.solve!( - problem::BeliefPropagationProblem, - algorithm::SimpleMessageUpdate, - state::BeliefPropagationState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm) - ) - logger = AI.algorithm_logger() - - cache = state.iterate - edge = algorithm.edge - - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) - ) - - new_message = updated_message(algorithm, cache) - - if !isnothing(algorithm.message_diff_function) - diff = algorithm.message_diff_function(new_message, cache[edge]) - - if diff > state.maxdiff - state.maxdiff = diff - end - - state.diffs[edge] = diff - end - - setmessage!(cache, edge, new_message) - - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PostUpdate) - ) - - return state -end - -default_message_diff_function(m1, m2) = norm(normalize(m1) - normalize(m2)) - -function updated_message(algorithm, cache) - edge = algorithm.edge - - vertex = src(edge) - messages = incoming_messages(cache, vertex; ignore_edges = typeof(edge)[reverse(edge)]) - - update_problem = MessageUpdateProblem(cache[vertex], messages) - - message_state = AI.solve(update_problem, algorithm; iterate = message(cache, edge)) - - return message_state.iterate -end - -function AI.solve!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdate, - state::AIE.DefaultNonIterativeAlgorithmState; - logging_context_prefix = AIE.default_logging_context_prefix(problem, algorithm), - kwargs... - ) - logger = AI.algorithm_logger() - - AI.emit_message( - logger, problem, algorithm, state, Symbol(logging_context_prefix, :PreUpdate) - ) - - state.iterate = - contract_messages(algorithm.contraction_alg, problem.factor, problem.messages) - - AI.emit_message( - logger, problem, algorithm, state, Symbol( - logging_context_prefix, - :PreNormalization - ) - ) - - if algorithm.normalize - # TODO: use `sum` not `norm` - message_norm = LinearAlgebra.norm(state.iterate) - if !iszero(message_norm) - state.iterate /= message_norm - end - end - - AI.emit_message( - logger, problem, algorithm, state, - Symbol(logging_context_prefix, :PostNormalization) - ) - - return state -end - -function contract_messages(alg, factor::AbstractArray, messages) - factors = typeof(factor)[factor] - return contract_network(vcat(factors, messages); alg) -end - -function beliefpropagation(network; kwargs...) - return beliefpropagation(BeliefPropagationCache(network), network; kwargs...) -end -function beliefpropagation( - cache::AbstractBeliefPropagationCache, - network = nothing; - kwargs... - ) - problem = BeliefPropagationProblem(network) - - algorithm = select_algorithm(beliefpropagation, cache; kwargs...) - - # The nested algorithms will wrap and manipulate this object. - base_state = BeliefPropagationState(; iterate = cache) - - state = AI.initialize_state(problem, algorithm; iterate = base_state) - - state = AI.solve!(problem, algorithm, state) - - return state.iterate.iterate -end - -function select_algorithm( - ::typeof(beliefpropagation), - cache::AbstractBeliefPropagationCache; - edges = forest_cover_edge_sequence(cache), - maxiter = is_tree(cache) ? 1 : nothing, - tol = -Inf, - message_diff_function = if tol > -Inf - (m1, m2) -> norm(m1 / norm(m1) - m2 / norm(m2)) - else - nothing - end, - kwargs... - ) - if isnothing(maxiter) - throw(ArgumentError("`maxiter` must be specified for non-tree graphs")) - end - - stopping_criterion = AI.StopAfterIteration(maxiter) - - if tol > -Inf - stopping_criterion = stopping_criterion | StopWhenConverged(tol) - end - - extended_kwargs = extend_columns((; message_diff_function, kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) - - return BeliefPropagation(maxiter; stopping_criterion) do repnum - return beliefpropagation_sweep(cache; edges, edge_kwargs[repnum]...) - end -end - -# A single sweep across the given edges. -function beliefpropagation_sweep(::BeliefPropagationCache; edges, kwargs...) - return BeliefPropagationSweep(edges) do edge - return SimpleMessageUpdate(edge; kwargs...) - end -end From cce227f8de851fb640940d84449f59e9d8bdbc8f Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Jun 2026 17:18:23 -0400 Subject: [PATCH 61/64] Add partitioned from of BP; define generic algorithm state objects. --- .../AlgorithmsInterfaceExtensions.jl | 41 +++++- src/ITensorNetworksNext.jl | 1 + src/beliefpropagation/beliefpropagation.jl | 60 +-------- src/beliefpropagation/partitioned.jl | 126 ++++++++++++++++++ test/test_beliefpropagation.jl | 23 ++++ 5 files changed, 193 insertions(+), 58 deletions(-) create mode 100644 src/beliefpropagation/partitioned.jl diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index a95e0e0..5ee2764 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -2,6 +2,37 @@ module AlgorithmsInterfaceExtensions import AlgorithmsInterface as AI +@kwdef mutable struct GenericAlgorithmState{Iterate, StoppingCriterionState} <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +function AI.initialize_state( + problem::AI.Problem, + algorithm::AI.Algorithm; + iterate, + iteration = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return GenericAlgorithmState(; iterate, iteration, stopping_criterion_state) +end + +function AI.initialize_state!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State; + iteration = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + # ============================ NestedAlgorithm ============================================= abstract type NestedAlgorithm <: AI.Algorithm end @@ -18,7 +49,7 @@ function initialize_subsolve( end function finalize_substate!( - problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State + ::AI.Problem, ::AI.Algorithm, substate::AI.State, state::AI.State ) state.iterate = substate.iterate return state @@ -27,7 +58,7 @@ end function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) AI.solve!(subproblem, subalgorithm, substate) - finalize_substate!(problem, algorithm, state, substate) + finalize_substate!(problem, algorithm, substate, state) return state end @@ -38,6 +69,12 @@ end # state. Subtypes must store the inner state as a field named `substate`. abstract type NestedState <: AI.State end +@kwdef mutable struct GenericNestedState{Substate, StoppingCriterionState} <: NestedState + substate::Substate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + # Use `getfield` on the right-hand side so future edits to this forwarder # can't accidentally recurse through the overload. function Base.getproperty(state::NestedState, name::Symbol) diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 7f315ae..a3dff39 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -15,6 +15,7 @@ include("contract_network.jl") include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagation.jl") +include("beliefpropagation/partitioned.jl") include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index d6dfabc..db134b1 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -95,14 +95,6 @@ end stopping_criterion::StoppingCriterion end -@kwdef mutable struct BeliefPropagationState{ - Substate <: AI.State, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AIE.NestedState - substate::Substate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - function AI.initialize_state( problem::BeliefPropagationProblem, algorithm::BeliefPropagationAlgorithm; @@ -110,29 +102,18 @@ function AI.initialize_state( ) subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, algorithm.subalgorithm; iterate) + stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iteration, stopping_criterion_state, substate) -end -function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state + return AIE.GenericNestedState(; iteration, stopping_criterion_state, substate) end function AIE.initialize_subsolve( problem::BeliefPropagationProblem, algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState + state::AIE.NestedState ) subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) return subproblem, algorithm.subalgorithm, state.substate @@ -153,42 +134,10 @@ end stopping_criterion::StoppingCriterion end -@kwdef mutable struct BeliefPropagationSweepState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -function AI.initialize_state( - problem::BeliefPropagationSweepProblem, - algorithm::BeliefPropagationSweepAlgorithm; - iterate, iteration::Int = 0 - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return BeliefPropagationSweepState(; iterate, iteration, stopping_criterion_state) -end - -function AI.initialize_state!( - problem::BeliefPropagationSweepProblem, - algorithm::BeliefPropagationSweepAlgorithm, - state::BeliefPropagationSweepState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state -end - function AI.step!( problem::BeliefPropagationSweepProblem, algorithm::BeliefPropagationSweepAlgorithm, - state::BeliefPropagationSweepState + state::AI.State ) edge = problem.edges[state.iteration] message_update!( @@ -243,7 +192,6 @@ function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) cache[edge] = new_message return cache end - # === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) diff --git a/src/beliefpropagation/partitioned.jl b/src/beliefpropagation/partitioned.jl new file mode 100644 index 0000000..243f896 --- /dev/null +++ b/src/beliefpropagation/partitioned.jl @@ -0,0 +1,126 @@ +@kwdef struct PartitionedBeliefPropagationSweep{ + Subfactors, + MessageUpdateAlgorithm, + } <: AIE.NestedAlgorithm + partitioned_factors::Vector{Subfactors} + message_update_algorithm::MessageUpdateAlgorithm +end + +function Base.getproperty(alg::PartitionedBeliefPropagationSweep, name::Symbol) + if name === :stopping_criterion + return AI.StopAfterIteration(length(alg.partitioned_factors)) + end + return getfield(alg, name) +end + +function AIE.initialize_subsolve( + problem::BeliefPropagationSweepProblem, + algorithm::PartitionedBeliefPropagationSweep, + state::AI.State + ) + subvertices = algorithm.partitioned_factors[state.iteration] + cache = state.iterate + + incoming_edges = boundary_edges(problem.factors, subvertices; dir = :in) + + factors = subgraph(problem.factors, subvertices) + iterate = subgraph(cache, subvertices) + + + for edge in incoming_edges + add_vertex!(iterate, src(edge)) + add_vertex!(iterate, dst(edge)) + + iterate[edge] = cache[edge] + iterate[reverse(edge)] = cache[reverse(edge)] + end + + # Don't want to update the incoming messages. + subedges = setdiff(edges(iterate), incoming_edges) + + subproblem = BeliefPropagationSweepProblem(factors, subedges) + + subalgorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm = algorithm.message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(subedges)) + ) + + substate = AI.initialize_state(subproblem, subalgorithm; iterate) + + return subproblem, subalgorithm, substate +end + +function AIE.finalize_substate!( + subproblem::BeliefPropagationSweepProblem, + _subalgorithm::BeliefPropagationSweepAlgorithm, + substate::AI.State, + state::AI.State + ) + subcache = substate.iterate + subedges = subproblem.edges + + for edge in subedges + state.iterate[edge] = subcache[edge] + end + + return state +end + +function beliefpropagation( + factors, messages, partitions; + edges = default_beliefpropagation_edges(factors), + stopping_criterion = nothing, + message_update_algorithm = nothing, + sweep_algorithm = nothing + ) + problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) + + # No concrete `edge` value here, so the args tuple uses `edgetype(factors)`. + message_update_algorithm = AIE.select_algorithm( + message_update!, + message_update_algorithm, + Tuple{typeof(cache), typeof(factors), edgetype(factors)} + ) + + if isnothing(sweep_algorithm) + sweep_algorithm = PartitionedBeliefPropagationSweep(; + partitioned_factors = partitions, + message_update_algorithm + ) + end + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; + edges, + subalgorithm = sweep_algorithm, + stopping_criterion + ) + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +end + +# function message_update!(algorithm::KrylovMessageUpdate, cache, factors, path) +# function f(messages) +# temp_cache = copy(cache) # shallow copy. +# +# for (edge, message) in zip(path, messages) +# temp_cache[edge] = message +# end +# +# for (i, edge) in enumerate(path) +# message_update!( +# algorithm.message_update_algorithm, +# temp_cache, +# factors, +# edge +# ) +# messages[i] = temp_cache[edge] +# end +# +# return messages +# end +# +# # do eigsolve step. +# +# return cache +# end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 34cb305..084f728 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -213,5 +213,28 @@ end end end end + @testset "PartitionedBeliefPropagationSweep" begin + n = 4 + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) + + messages = + Dict(edge => rand(Tuple(linkinds(tn, edge))) for edge in all_edges(g)) + + partitions = [ + [ij, ij .+ (0, 1), ij .+ (1, 0), ij .+ (1, 1)] for + ij in [(1, 1), (1, 3), (3, 1), (3, 3)] + ] + + cache = ITensorNetworksNext.beliefpropagation( + tn, messages, partitions; + stopping_criterion = (; maxiter = 10, tol = 1.0e-10) + ) + + z_bp = exp(bethe_free_energy(tn, cache)) + + @test z_bp ≈ 1.5^(n^2) + end end end From 49a49818761820f14d9270e85645638e09831e13 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Jun 2026 17:19:25 -0400 Subject: [PATCH 62/64] Refactor parallel and simplify code for parallelization. --- .../ITensorNetworksNextDaggerExt.jl | 98 ++++++++++++------- src/ITensorNetworksNext.jl | 1 + .../ITensorNetworksNextParallel.jl | 67 ++++++++----- 3 files changed, 108 insertions(+), 58 deletions(-) diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl index c24168b..837a3d3 100644 --- a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -4,53 +4,81 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP using Dagger -using Dictionaries: set! -using ITensorNetworksNext.ITensorNetworksNextParallel: - DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, step_dagger! - -function AI.initialize_state(problem::AI.Problem, algorithm::AI.Algorithm; kwargs...) - substate = AI.initialize_state(problem, algorithm; kwargs...) - remote_results = Dictionary{Int, Dagger.DTask}() - return DaggerState(; - substate, - substate.iteration, - substate.stopping_criterion_state, - remote_results - ) +using Dictionaries: Dictionary, set! + +@kwdef mutable struct DaggerState{Chunk <: Dagger.Chunk} <: AI.State + chunk::Chunk # A chunk living on the host worker. + futures = Dictionary{Int, Dagger.DTask}() # the futures from the remote steps. end -function AI.step!( - problem::AI.Problem, - algorithm::Dagger, - state::DaggerState - ) where {Algorithm} - # Forward the "external" stopping info to the internal states. - state.substate.iteration = state.iteration - state.substate.stopping_criterion_state = state.stopping_criterion_state +function Base.getproperty(state::DaggerState, name::Symbol) + if name in (:chunk, :futures) + return getfield(state, name) + end + return getproperty(fetch(state.chunk), name) +end +function Base.setproperty!(state::DaggerState, name::Symbol, val) + if name === (:chunk, :futures) + return setfield!(state, name, val) + end + fetch(Dagger.@spawn setproperty!(state.chunk, name, val)) + return state +end - dtask = Dagger.@spawn step_dagger!(problem, algorithm.parent, substate) +# ====================================== overloads ======================================= # - set!(state.remote_results, state.iteration, dtask) - return state +function ITNNP.default_workers(::AI.Algorithm, ::ITNNP.AbstractDaggerStrategy) + return Dagger.Distributed.workers() end -function ITNNP.step_dagger!( - problem::AI.Problem, - algorithm::AI.Algorithm, - state::AI.State +function ITNNP.initialize_parallel_state( + problem::AI.Problem, algorithm::AI.Algorithm, + _strategy::ITNNP.GenericDaggerStrategy; kwargs... ) - AI.step!(problem, algorithm, state) + chunk = Dagger.@mutable AI.initialize_state(problem, algorithm; iterate, kwargs...) - return return state + return DaggerState(; chunk) end -function AI.finalize_state!(::AI.Problem, ::AI.Algorithm, state::DaggerState) - for dtask in collect(state.remote_results) - wait(dtask) +function AI.step!(problem::AI.Problem, algorithm::ITNNP.Parallelized, state::DaggerState) + worker_list = algorithm.workers + algorithm = algorithm.parent + + function remote_solve(state) + subsolve = AIE.initialize_subsolve(problem, algorithm, state) + + subproblem, subalgorithm, substate = subsolve + AI.solve!(subproblem, subalgorithm, substate) + + return subsolve end - return state.substate.iterate + + worker = worker_list[mod(state.iteration, 1:length(worker_list))] + + dtask = Dagger.@spawn scope = Dagger.scope(; worker) remote_solve(fetch(state.chunk)) + + # Spawns on the host only (as we do not fetch the chunk before hand.) + dtask = Dagger.spawn(dtask, state.chunk) do subsolve, state + subproblem, subalgorithm, substate = subsolve + + AIE.finalize_substate!(subproblem, subalgorithm, substate, state) + return state + end + + set!(state.futures, state.iteration, dtask) + + return state +end +function AI.finalize_state!(::AI.Problem, ::AI.Algorithm, state::DaggerState) + foreach(fetch, state.futures) + return state.iterate end -# include("daggerbeliefpropagation.jl") +function AIE.finalize_substate!( + problem::AI.Problem, algorithm::AI.Algorithm, substate::DaggerState, state::AI.State + ) + AIE.finalize_substate!(problem, algorithm, fetch(substate.chunk), state) + return state +end end # ITensorNetworksNextDaggerExt diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index a3dff39..4d42c44 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -17,6 +17,7 @@ include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagation.jl") include("beliefpropagation/partitioned.jl") +# lib include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl index 5d56ea4..333edea 100644 --- a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -1,39 +1,60 @@ module ITensorNetworksNextParallel +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -abstract type ParallelAlgorithm{Child} <: AIE.NestedAlgorithm{Child} end -const IterativeParallelAlgorithm{Child <: ParallelAlgorithm} = AIE.NestedAlgorithm{Child} +abstract type AbstractParallelizationStrategy end -""" - get_subiterate(subproblem::AI.Problem, subalgorithm::AI.Algorithm, state::AI.State) +function default_workers end +function initialize_parallel_state end -For a given `subproblem` and `subalgorithm` of a parent nested algorithm, -derive (from the parent state `state`) the iterate to be used in the associated sub state. -The returned value of this function is then pass to a remote call of `initialize_state`. -""" -get_subiterate(::AI.Problem, ::AI.Algorithm, state::AI.State) = state.iterate +@kwdef struct Parallelized{Strategy, Workers, Algorithm <: AI.Algorithm} <: AI.Algorithm + parent::Algorithm + strategy::Strategy + workers::Workers = default_workers(parent, strategy) +end -finalize_state!(::AI.Problem, ::AI.Algorithm, state::AI.State) = state +function Base.getproperty(algorithm::Parallelized, name::Symbol) + if name in (:parent, :strategy, :workers) + return getfield(algorithm, name) + end + return getproperty(getfield(algorithm, :parent), name) +end -function AI.is_finished!( - problem::AI.Problem, - algorithm::IterativeParallelAlgorithm, - state::AI.State +function AI.initialize_state(problem::AI.Problem, algorithm::Parallelized; kwargs...) + return initialize_parallel_state( + problem, + algorithm.parent, + algorithm.strategy; + kwargs... ) - c = algorithm.stopping_criterion - st = state.stopping_criterion_state +end - isfinished = AI.is_finished!(problem, algorithm, state, c, st) +# ====================================== Dagger.jl ======================================= # - if isfinished - finalize_state!(problem, algorithm, state) - end +abstract type AbstractDaggerStrategy <: AbstractParallelizationStrategy end +struct GenericDaggerStrategy <: AbstractDaggerStrategy end - return isfinished +function initialize_parallel_state( + _problem, + _algorithm, + strategy::AbstractDaggerStrategy; + _kwargs... + ) + throw( + ArgumentError( + "package Dagger.jl not loaded; please install and load Dagger.jl to use \ + strategy of type $(typeof(strategy))." + ) + ) end -include("dagger.jl") +function default_workers(algorithm, strategy::AbstractDaggerStrategy) + @warn( + "package Dagger.jl may not be loaded; please install and load Dagger.jl to use \ + strategy of type `$(typeof(strategy))`" + ) + throw(MethodError(default_workers, (algorithm, strategy))) +end end # ITensorNetworksNextParallel From 798d34420dca8ab534adbd93449418fa68b65e96 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Jun 2026 17:40:45 -0400 Subject: [PATCH 63/64] Fix algorithms interface using legacy `finalize_substate` signature. --- .../AlgorithmsInterfaceExtensions.jl | 2 +- test/test_algorithmsinterfaceextensions.jl | 2 +- test/test_beliefpropagation.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 5ee2764..82b3536 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -58,7 +58,7 @@ end function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) AI.solve!(subproblem, subalgorithm, substate) - finalize_substate!(problem, algorithm, substate, state) + finalize_substate!(subproblem, subalgorithm, substate, state) return state end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index f580826..6ef4ce8 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -115,7 +115,7 @@ end # `finalize_substate!` copies the substate's iterate back into the # parent state. substate = AI.initialize_state(problem, algorithm; iterate = [42.0]) - AIE.finalize_substate!(problem, algorithm, state, substate) + AIE.finalize_substate!(problem, algorithm, substate, state) @test state.iterate == [42.0] end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 084f728..1610c72 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -229,7 +229,7 @@ end cache = ITensorNetworksNext.beliefpropagation( tn, messages, partitions; - stopping_criterion = (; maxiter = 10, tol = 1.0e-10) + stopping_criterion = (; maxiter = 30, tol = 1.0e-10) ) z_bp = exp(bethe_free_energy(tn, cache)) From 38004adf2852eb053f66a74a234336b373bfeef6 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Mon, 1 Jun 2026 17:41:02 -0400 Subject: [PATCH 64/64] Add Dagger.jl compat entry. --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 33a0a3f..5c8f666 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" +Dagger = "0.19.0" DataGraphs = "0.4" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5"