From f2f9011a51aa37b02803046baddf7924e39b077b Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 15:47:35 -0400 Subject: [PATCH 01/16] Refactor `NestedAlgorithm` hooks: `initialize_subsolve` + `finalize_substate!` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename the `NestedAlgorithm` step hooks from `get_subproblem` / `set_substate!` to `initialize_subsolve` / `finalize_substate!`, and move the indexed-list dispatch (`algorithm.algorithms[state.iteration]`) out of the abstract type's default into `DefaultNestedAlgorithm`'s own `initialize_subsolve` method. The abstract `NestedAlgorithm` now has no default `initialize_subsolve` impl (throws `MethodError`), so subtypes must provide their own — the previous default silently assumed every subtype carried an `algorithms` vector, which is too narrow. `BeliefPropagation` and `BeliefPropagationSweep` get an explicit `AIE.initialize_subsolve` override mirroring the indexed-list shape; the existing `AIE.set_substate!` override on `BeliefPropagationSweep` is renamed to `AIE.finalize_substate!`. Co-Authored-By: Claude Opus 4.7 --- Project.toml | 2 +- .../AlgorithmsInterfaceExtensions.jl | 37 +++++++++++-------- .../beliefpropagationproblem.jl | 17 ++++++++- test/test_algorithmsinterfaceextensions.jl | 12 +++--- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 301fb17..c8358d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" -version = "0.4.2" +version = "0.4.3" authors = ["ITensor developers and contributors"] [workspace] diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index d9edb0d..9eb930f 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -103,32 +103,28 @@ end max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) -function get_subproblem( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State +# Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it +# returns the `(subproblem, subalgorithm, substate)` tuple that the next +# inner `AI.solve!` call consumes. The default `finalize_substate!` copies +# the substate's iterate back into the parent state; subtypes can override +# when more is required. +function initialize_subsolve( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + return throw(MethodError(initialize_subsolve, (problem, algorithm, state))) end -function set_substate!( - problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State, substate::AI.State +function finalize_substate!( + problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State ) state.iterate = substate.iterate return state end function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) - # Get the subproblem, subalgorithm, and substate. - subproblem, subalgorithm, substate = get_subproblem(problem, algorithm, state) - - # Solve the subproblem with the subalgorithm. + subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) AI.solve!(subproblem, subalgorithm, substate) - - # Update the state with the substate. - set_substate!(problem, algorithm, state, substate) - + finalize_substate!(problem, algorithm, state, substate) return state end @@ -150,6 +146,15 @@ function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) end +function initialize_subsolve( + problem::AI.Problem, algorithm::DefaultNestedAlgorithm, state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + # ============================ FlattenedAlgorithm ========================================== # Flatten a nested algorithm. diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 004e449..e0c65b1 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -139,7 +139,22 @@ function BeliefPropagationSweep(f::Function, edges) return BeliefPropagationSweep(; algorithms = f.(edges)) end -function AIE.set_substate!( +# `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of +# child algorithms, mirroring `AIE.DefaultNestedAlgorithm`. Each step picks +# the child algorithm by the current iteration index and reuses the parent +# problem. +function AIE.initialize_subsolve( + problem::BeliefPropagationProblem, + algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, + state::AI.State + ) + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end + +function AIE.finalize_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, state::AIE.DefaultState, diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 6f80527..60118da 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -213,8 +213,8 @@ end @test state.iteration == 2 end - @testset "get_subproblem and set_substate!" begin - # Test get_subproblem + @testset "initialize_subsolve and finalize_substate!" begin + # Test initialize_subsolve problem = TestProblem([1.0, 2.0]) nested_alg = AIE.nested_algorithm(2) do i return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) @@ -229,17 +229,19 @@ end stopping_criterion_state ) - subproblem, subalgorithm, substate = AIE.get_subproblem(problem, nested_alg, state) + subproblem, subalgorithm, substate = AIE.initialize_subsolve( + problem, nested_alg, state + ) @test subproblem === problem @test subalgorithm === nested_alg.algorithms[1] @test substate.iterate ≈ [5.0, 10.0] - # Test set_substate! + # Test finalize_substate! new_substate = AIE.DefaultState(; iterate = [100.0, 200.0], substate.stopping_criterion_state ) - AIE.set_substate!(problem, nested_alg, state, new_substate) + AIE.finalize_substate!(problem, nested_alg, state, new_substate) @test state.iterate ≈ [100.0, 200.0] end From 314c29485578a4bb075d08910ad4add77cd3f923 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:11:22 -0400 Subject: [PATCH 02/16] Strip AIE down to minimal NestedAlgorithm + abstract scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete from AIE: - `DefaultNestedAlgorithm` (struct, constructor, and its `initialize_subsolve` indexed-list impl). - `nested_algorithm` factory function + `max_iterations`. - All `FlattenedAlgorithm` machinery (struct, state, helper). - `AlgorithmIterator` / `DefaultAlgorithmIterator` / `algorithm_iterator`. - `with_algorithmlogger`. - `NonIterativeAlgorithm` / `DefaultNonIterativeAlgorithmState`. Keep the minimum BP / nested-solve still needs: `Problem`, `Algorithm`, `State`, `DefaultState`, `AI.initialize_state` / `initialize_state!` / `increment!`, and the minimal `NestedAlgorithm` (abstract type + `initialize_subsolve` + `finalize_substate!` + `AI.step!`). Future features can be added back as concrete subtypes when actually needed. Delete the placeholder sweep / DMRG scaffolding (`src/sweeping/`, `test/test_dmrg.jl`, `test/test_sweeping.jl`): `EigsolveRegion` / `EigenProblem` / `dmrg` / `select_algorithm` were not actually wired up (the `solve!` body threw `not implemented yet`), and the tests covered only construction shape. Inline the kwarg-expansion helpers (`extend_columns`, `rows`, ...) that lived in `sweeping/utils.jl` into `beliefpropagation/beliefpropagationproblem.jl` — its only remaining consumer. Prune `test/test_algorithmsinterfaceextensions.jl` to the surface that still exists; add a `TestNestedAlgorithm` concrete subtype mirroring how `BeliefPropagation` shapes itself on top of the new minimal interface, and a test that the bare `initialize_subsolve` default throws. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 177 --------- src/ITensorNetworksNext.jl | 2 - .../beliefpropagationproblem.jl | 18 +- src/sweeping/eigenproblem.jl | 44 --- src/sweeping/utils.jl | 12 - test/test_algorithmsinterfaceextensions.jl | 365 +++--------------- test/test_dmrg.jl | 34 -- test/test_sweeping.jl | 65 ---- 8 files changed, 58 insertions(+), 659 deletions(-) delete mode 100644 src/sweeping/eigenproblem.jl delete mode 100644 src/sweeping/utils.jl delete mode 100644 test/test_dmrg.jl delete mode 100644 test/test_sweeping.jl diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index 9eb930f..c91e24e 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -47,62 +47,10 @@ function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) return AI.increment!(state) end -# ============================ AlgorithmIterator =========================================== - -abstract type AlgorithmIterator end - -function algorithm_iterator( - problem::Problem, algorithm::Algorithm, state::State - ) - return DefaultAlgorithmIterator(problem, algorithm, state) -end - -function AI.is_finished!(iterator::AlgorithmIterator) - return AI.is_finished!(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.is_finished(iterator::AlgorithmIterator) - return AI.is_finished(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.increment!(iterator::AlgorithmIterator) - return AI.increment!(iterator.problem, iterator.algorithm, iterator.state) -end -function AI.step!(iterator::AlgorithmIterator) - return AI.step!(iterator.problem, iterator.algorithm, iterator.state) -end -function Base.iterate(iterator::AlgorithmIterator, init = nothing) - AI.is_finished!(iterator) && return nothing - AI.increment!(iterator) - AI.step!(iterator) - return iterator.state, nothing -end - -struct DefaultAlgorithmIterator{Problem, Algorithm, State} <: AlgorithmIterator - problem::Problem - algorithm::Algorithm - state::State -end - -# ============================ with_algorithmlogger ======================================== - -# Allow passing functions, not just CallbackActions. -@inline function with_algorithmlogger(f, args::Pair{Symbol, AI.LoggingAction}...) - return AI.with_algorithmlogger(f, args...) -end -@inline function with_algorithmlogger(f, args::Pair{Symbol}...) - return AI.with_algorithmlogger(f, (first.(args) .=> AI.CallbackAction.(last.(args)))...) -end - # ============================ NestedAlgorithm ============================================= abstract type NestedAlgorithm <: Algorithm end -nested_algorithm(f::Function, int::Int; kwargs...) = nested_algorithm(f, 1:int; kwargs...) -function nested_algorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(f, iterable; kwargs...) -end - -max_iterations(algorithm::NestedAlgorithm) = length(algorithm.algorithms) - # Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it # returns the `(subproblem, subalgorithm, substate)` tuple that the next # inner `AI.solve!` call consumes. The default `finalize_substate!` copies @@ -128,129 +76,4 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end -#= - DefaultNestedAlgorithm(sweeps::AbstractVector{<:Algorithm}) - -An algorithm that consists of running an algorithm at each iteration -from a list of stored algorithms. -=# -@kwdef struct DefaultNestedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end -function DefaultNestedAlgorithm(f::Function, iterable; kwargs...) - return DefaultNestedAlgorithm(; algorithms = f.(iterable), kwargs...) -end - -function initialize_subsolve( - problem::AI.Problem, algorithm::DefaultNestedAlgorithm, state::AI.State - ) - subproblem = problem - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate -end - -# ============================ FlattenedAlgorithm ========================================== - -# Flatten a nested algorithm. -abstract type FlattenedAlgorithm <: Algorithm end -abstract type FlattenedAlgorithmState <: State end - -function flattened_algorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(f, nalgorithms; kwargs...) -end - -function AI.initialize_state( - problem::Problem, algorithm::FlattenedAlgorithm; kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - return DefaultFlattenedAlgorithmState(; stopping_criterion_state, kwargs...) -end -function AI.increment!( - problem::Problem, algorithm::Algorithm, state::FlattenedAlgorithmState - ) - # Increment the total iteration count. - state.iteration += 1 - # TODO: Use `is_finished!` instead? - if state.child_iteration ≥ max_iterations(algorithm.algorithms[state.parent_iteration]) - # We're on the last iteration of the child algorithm, so move to the next - # child algorithm. - state.parent_iteration += 1 - state.child_iteration = 1 - else - # Iterate the child algorithm. - state.child_iteration += 1 - end - return state -end -function AI.step!( - problem::AI.Problem, algorithm::FlattenedAlgorithm, state::FlattenedAlgorithmState - ) - algorithm_sweep = algorithm.algorithms[state.parent_iteration] - state_sweep = AI.initialize_state( - problem, algorithm_sweep; - state.iterate, iteration = state.child_iteration - ) - AI.step!(problem, algorithm_sweep, state_sweep) - state.iterate = state_sweep.iterate - return state -end - -@kwdef struct DefaultFlattenedAlgorithm{ - ChildAlgorithm <: Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: FlattenedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = - AI.StopAfterIteration(sum(max_iterations, algorithms)) -end -function DefaultFlattenedAlgorithm(f::Function, nalgorithms::Int; kwargs...) - return DefaultFlattenedAlgorithm(; algorithms = f.(1:nalgorithms), kwargs...) -end - -@kwdef mutable struct DefaultFlattenedAlgorithmState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: FlattenedAlgorithmState - iterate::Iterate - iteration::Int = 0 - parent_iteration::Int = 1 - child_iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -# ============================ NonIterativeAlgorithm ======================================= - -# Algorithm that only performs a single step. -abstract type NonIterativeAlgorithm <: Algorithm end -abstract type NonIterativeAlgorithmState <: State end - -function AI.initialize_state(problem::Problem, algorithm::NonIterativeAlgorithm; kwargs...) - return DefaultNonIterativeAlgorithmState(; kwargs...) -end - -function AI.initialize_state!( - problem::Problem, - algorithm::NonIterativeAlgorithm, - state::NonIterativeAlgorithmState - ) - return state -end - -function AI.solve_loop!(problem::Problem, algorithm::NonIterativeAlgorithm, state::State) - return throw(MethodError(AI.solve_loop!, (problem, algorithm, state))) -end - -@kwdef mutable struct DefaultNonIterativeAlgorithmState{Iterate} <: - NonIterativeAlgorithmState - iterate::Iterate -end - end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 0b2b898..394546a 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -12,8 +12,6 @@ include("abstracttensornetwork.jl") include("tensornetwork.jl") include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") -include("sweeping/utils.jl") -include("sweeping/eigenproblem.jl") include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagationproblem.jl") diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index e0c65b1..8f88ff4 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -8,6 +8,19 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices +# Utility functions for processing keyword arguments. +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] +end + @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 end @@ -140,9 +153,8 @@ function BeliefPropagationSweep(f::Function, edges) end # `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of -# child algorithms, mirroring `AIE.DefaultNestedAlgorithm`. Each step picks -# the child algorithm by the current iteration index and reuses the parent -# problem. +# child algorithms. Each step picks the child algorithm by the current +# iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, diff --git a/src/sweeping/eigenproblem.jl b/src/sweeping/eigenproblem.jl deleted file mode 100644 index 8fefbd0..0000000 --- a/src/sweeping/eigenproblem.jl +++ /dev/null @@ -1,44 +0,0 @@ -import .AlgorithmsInterfaceExtensions as AIE -import AlgorithmsInterface as AI - -function dmrg(operator, algorithm, state) - problem = EigenProblem(operator) - return AI.solve(problem, algorithm; iterate = state).iterate -end -function dmrg(operator, state; kwargs...) - problem = EigenProblem(operator) - algorithm = select_algorithm(dmrg, operator, state; kwargs...) - return AI.solve(problem, algorithm; iterate = state).iterate -end - -# TODO: Allow specifying the region algorithm type? -function select_algorithm(::typeof(dmrg), operator, state; nsweeps, regions, kwargs...) - extended_kwargs = extend_columns((; kwargs...), nsweeps) - region_kwargs = rows(extended_kwargs) - return AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion(regions[j]; region_kwargs[i]...) - end - end -end -#= - EigenProblem(operator) - -Represents the problem we are trying to solve and minimal algorithm-independent -information, so for an eigenproblem it is the operator we want the eigenvector of. -=# -struct EigenProblem{Operator} <: AIE.Problem - operator::Operator -end - -struct EigsolveRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -EigsolveRegion(region; kwargs...) = EigsolveRegion(region, (; kwargs...)) - -function AI.solve!( - problem::EigenProblem, algorithm::EigsolveRegion, state::AIE.State; kwargs... - ) - return error("EigsolveRegion step for EigenProblem not implemented yet.") -end diff --git a/src/sweeping/utils.jl b/src/sweeping/utils.jl deleted file mode 100644 index 39e09e4..0000000 --- a/src/sweeping/utils.jl +++ /dev/null @@ -1,12 +0,0 @@ -# Utility functions for processing keyword arguments. -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 60118da..0ec2d16 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -1,6 +1,6 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset +using Test: @test, @test_throws, @testset # Define test problems, algorithms, and states for testing struct TestProblem <: AIE.Problem @@ -11,10 +11,6 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) end -@kwdef struct TestAlgorithmStep{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(5) -end - function AI.step!( problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState ) @@ -22,16 +18,29 @@ function AI.step!( return state end -function AI.step!( - problem::TestProblem, algorithm::TestAlgorithmStep, state::AIE.DefaultState +# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms +# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes +# itself on top of the new minimal `AIE.NestedAlgorithm`. +@kwdef struct TestNestedAlgorithm{ + ChildAlgorithm <: AIE.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) +end + +function AIE.initialize_subsolve( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::AI.State ) - state.iterate .+= 2 # Different increment step - return state + subproblem = problem + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate end @testset "AlgorithmsInterfaceExtensions" begin @testset "DefaultState" begin - # Test DefaultState construction iterate = [1.0, 2.0, 3.0] stopping_criterion_state = AI.initialize_state( TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion @@ -41,13 +50,11 @@ end @test state.iteration == 0 @test state.stopping_criterion_state isa AI.StoppingCriterionState - # Test DefaultState with custom iteration state.iteration = 5 @test state.iteration == 5 end @testset "initialize_state!" begin - # Test initialize_state! with iterate kwarg problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() stopping_criterion_state = AI.initialize_state( @@ -63,17 +70,14 @@ end end @testset "initialize_state" begin - # Test initialize_state without exclamation problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() - state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) @test state isa AIE.DefaultState @test state.iteration == 0 end @testset "increment!" begin - # Test increment! with problem and algorithm problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm() stopping_criterion_state = AI.initialize_state( @@ -81,352 +85,69 @@ end ) state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - # Increment and verify iteration counter increases AI.increment!(problem, algorithm, state) @test state.iteration == 1 - AI.increment!(problem, algorithm, state) @test state.iteration == 2 end @testset "solve! and solve" begin - # Test solve! with simple problem problem = TestProblem([1.0, 2.0]) algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) + state = AI.initialize_state(problem, algorithm; iterate = [10.0, 20.0]) - initial_iterate = [10.0, 20.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - - # Solve with custom initial iterate initial_iterate = [5.0, 10.0] final_iterate = AI.solve!( problem, algorithm, state; iterate = copy(initial_iterate) ) - @test state.iteration == 3 @test final_iterate == state.iterate # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] @test state.iterate ≈ [8.0, 13.0] - # Test solve without exclamation problem2 = TestProblem([1.0, 2.0]) algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate2 = [5.0, 10.0] - - final_iterate2 = AI.solve(problem2, algorithm2; iterate = copy(initial_iterate2)) + final_iterate2 = AI.solve(problem2, algorithm2; iterate = [5.0, 10.0]) @test final_iterate2 ≈ [7.0, 12.0] end - @testset "DefaultAlgorithmIterator" begin - # Test algorithm iterator creation - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - @test iterator isa AIE.DefaultAlgorithmIterator - @test iterator.problem === problem - @test iterator.algorithm === algorithm - @test iterator.state === state - - # Test iteration interface - @test !AI.is_finished!(iterator) - - # Step through iterator - state_out, _ = iterate(iterator) - @test state_out.iteration == 1 - @test state_out.iterate ≈ [1.0, 1.0] # Incremented by step! - - state_out, _ = iterate(iterator) - @test state_out.iteration == 2 - - @test AI.is_finished!(iterator) - end - - @testset "DefaultNestedAlgorithm" begin - # Test creating nested algorithm with function - nested_alg = AIE.nested_algorithm(3) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - @test nested_alg isa AIE.DefaultNestedAlgorithm - @test length(nested_alg.algorithms) == 3 - @test AIE.max_iterations(nested_alg) == 3 - - # Test stepping through nested algorithm - problem = TestProblem([1.0, 2.0]) - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - initial_iterate = [0.0, 0.0] - AI.solve!( - problem, nested_alg, state; iterate = copy(initial_iterate) - ) - - @test state.iteration == 3 - # Each nested algorithm runs once with 2 steps, incrementing by 2 - # Total: 3 algorithms × 2 iterations × 2 increment = 12 - @test state.iterate ≈ [12.0, 12.0] - end - - @testset "NestedAlgorithm basic tests" begin - # Test basic nested algorithm functionality - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - problem = TestProblem([1.0, 2.0]) - - # Test state initialization - state_nested = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - - @test state_nested isa AIE.DefaultState - @test state_nested.iteration == 0 - @test AIE.max_iterations(nested_alg) == 2 - end - - @testset "increment! for nested algorithms" begin - # Test increment! logic for nested algorithm state + @testset "NestedAlgorithm defaults" begin + # The bare `initialize_subsolve` default throws a `MethodError`, + # forcing concrete subtypes to provide their own override. problem = TestProblem([1.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) - state = AIE.DefaultState(; - iterate = [0.0], - stopping_criterion_state = stopping_criterion_state - ) - - # Test progression through iterations - @test state.iteration == 0 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 1 - - AI.increment!(problem, nested_alg, state) - @test state.iteration == 2 - end - - @testset "initialize_subsolve and finalize_substate!" begin - # Test initialize_subsolve - problem = TestProblem([1.0, 2.0]) - nested_alg = AIE.nested_algorithm(2) do i - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(1)) - end - - stopping_criterion_state = AI.initialize_state( - problem, nested_alg, nested_alg.stopping_criterion - ) + algorithm = TestAlgorithm() state = AIE.DefaultState(; - iterate = [5.0, 10.0], - iteration = 1, - stopping_criterion_state - ) - - subproblem, subalgorithm, substate = AIE.initialize_subsolve( - problem, nested_alg, state - ) - @test subproblem === problem - @test subalgorithm === nested_alg.algorithms[1] - @test substate.iterate ≈ [5.0, 10.0] - - # Test finalize_substate! - new_substate = AIE.DefaultState(; - iterate = [100.0, 200.0], - substate.stopping_criterion_state - ) - AIE.finalize_substate!(problem, nested_alg, state, new_substate) - @test state.iterate ≈ [100.0, 200.0] - end - - @testset "DefaultFlattenedAlgorithm" begin - # Create nested algorithms that support max_iterations - nested_algs = map(1:3) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(6) # 3 algorithms × 2 iterations each - ) - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 3 - - # Test state initialization - problem = TestProblem([1.0, 2.0]) - state_flat = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - @test state_flat isa AIE.DefaultFlattenedAlgorithmState - @test state_flat.iteration == 0 - @test state_flat.parent_iteration == 1 - @test state_flat.child_iteration == 0 - end - - @testset "DefaultFlattenedAlgorithmState increment!" begin - # Create nested algorithms for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) - ) - - problem = TestProblem([1.0]) - stopping_criterion_state = AI.initialize_state( - problem, flattened_alg, flattened_alg.stopping_criterion - ) - state = AIE.DefaultFlattenedAlgorithmState(; iterate = [0.0], - stopping_criterion_state = stopping_criterion_state + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion + ) ) + @test_throws MethodError AIE.initialize_subsolve(problem, algorithm, state) - # Test initial state - @test state.iteration == 0 - @test state.parent_iteration == 1 - @test state.child_iteration == 0 - - # First increment - should increment child_iteration - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 1 - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - # Second increment - should increment child_iteration again - AI.increment!(problem, flattened_alg, state) - @test state.iteration == 2 - @test state.parent_iteration == 2 # Should move to next parent - @test state.child_iteration == 1 - end - - @testset "FlattenedAlgorithm step!" begin - # Test individual step! calls for flattened algorithm - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - flattened_alg = AIE.DefaultFlattenedAlgorithm(; - algorithms = nested_algs, - stopping_criterion = AI.StopAfterIteration(4) + # `finalize_substate!` copies the substate's iterate back into the + # parent state. + substate = AIE.DefaultState(; + iterate = [42.0], + stopping_criterion_state = state.stopping_criterion_state ) - - problem = TestProblem([1.0, 2.0]) - state = AI.initialize_state(problem, flattened_alg; iterate = [0.0, 0.0]) - - # Manually step through to test step! functionality - AI.increment!(problem, flattened_alg, state) - @test state.parent_iteration == 1 - @test state.child_iteration == 1 - - AI.step!(problem, flattened_alg, state) - # The nested algorithm runs TestAlgorithmStep with 2 iterations, each incrementing by 2 - @test state.iterate ≈ [4.0, 4.0] - end - - @testset "flattened_algorithm helper" begin - # Test the flattened_algorithm helper function - nested_algs = map(1:2) do i - return AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - # Using the helper function - flattened_alg = AIE.flattened_algorithm(2) do i - AIE.nested_algorithm(1) do j - return TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)) - end - end - - @test flattened_alg isa AIE.DefaultFlattenedAlgorithm - @test length(flattened_alg.algorithms) == 2 + AIE.finalize_substate!(problem, algorithm, state, substate) + @test state.iterate == [42.0] end - @testset "AlgorithmIterator is_finished (without !)" begin - # Test is_finished without mutation + @testset "TestNestedAlgorithm" begin problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Before any iterations - @test !AI.is_finished(iterator) - - # Run the algorithm - AI.solve!(problem, algorithm, state; iterate = copy(initial_iterate)) - - # After completion - @test AI.is_finished(iterator) - end - - @testset "AlgorithmIterator step!" begin - # Test step! method for iterator - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - initial_iterate = [0.0, 0.0] - state = AI.initialize_state(problem, algorithm; iterate = copy(initial_iterate)) - iterator = AIE.algorithm_iterator(problem, algorithm, state) - - # Step the iterator - AI.step!(iterator) - @test iterator.state.iterate ≈ [1.0, 1.0] - - AI.step!(iterator) - @test iterator.state.iterate ≈ [2.0, 2.0] - end - - @testset "NestedAlgorithm with different sub-algorithms" begin - # Test nested algorithm with varying sub-algorithms - nested_alg = AIE.DefaultNestedAlgorithm(; + nested_alg = TestNestedAlgorithm(; algorithms = [ TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), - TestAlgorithmStep(; stopping_criterion = AI.StopAfterIteration(2)), - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), ] ) + @test nested_alg isa AIE.NestedAlgorithm - @test AIE.max_iterations(nested_alg) == 3 - @test length(nested_alg.algorithms) == 3 - - problem = TestProblem([1.0, 2.0]) state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) - - # First algorithm: 1 iteration × 1 increment = 1 - # Second algorithm: 2 iterations × 2 increment = 4 - # Third algorithm: 1 iteration × 1 increment = 1 - # Total: 1 + 4 + 1 = 6 - @test state.iterate ≈ [6.0, 6.0] - @test state.iteration == 3 - end - - @testset "Edge cases" begin - # Test with single nested algorithm - nested_alg = AIE.nested_algorithm(1) do i - return TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)) - end - - problem = TestProblem([1.0]) - state = AI.initialize_state(problem, nested_alg; iterate = [0.0]) - AI.solve!(problem, nested_alg, state; iterate = [0.0]) - - @test state.iterate ≈ [1.0] - @test state.iteration == 1 + # Two child algorithms: 1 iter + 2 iter = 3 inner increments total. + @test state.iteration == 2 + @test state.iterate ≈ [3.0, 3.0] end end diff --git a/test/test_dmrg.jl b/test/test_dmrg.jl deleted file mode 100644 index dba2570..0000000 --- a/test/test_dmrg.jl +++ /dev/null @@ -1,34 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using ITensorNetworksNext: EigsolveRegion, dmrg, select_algorithm -using Test: @test, @testset - -@testset "select_algorithm(dmrg, ...)" begin - operator = "operator" - init = "init" - nsweeps = 3 - regions = ["region1", "region2"] - maxdim = [10, 20] - cutoff = 1.0e-7 - algorithm = select_algorithm(dmrg, operator, init; nsweeps, regions, maxdim, cutoff) - @test algorithm isa AIE.NestedAlgorithm - @test length(algorithm.algorithms) == nsweeps - - maxdims = [10, 20, 20] - cutoffs = [1.0e-7, 1.0e-7, 1.0e-7] - algorithm′ = AIE.nested_algorithm(nsweeps) do i - return AIE.nested_algorithm(length(regions)) do j - return EigsolveRegion( - regions[j]; - maxdim = maxdims[i], - cutoff = cutoffs[i] - ) - end - end - for i in 1:nsweeps - for j in 1:length(regions) - @test algorithm.algorithms[i].algorithms[j] == - algorithm′.algorithms[i].algorithms[j] - end - end -end diff --git a/test/test_sweeping.jl b/test/test_sweeping.jl deleted file mode 100644 index 01881d9..0000000 --- a/test/test_sweeping.jl +++ /dev/null @@ -1,65 +0,0 @@ -import AlgorithmsInterface as AI -import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE -using Test: @test, @testset - -struct TestProblem <: AIE.Problem -end - -struct TestRegion{R, Kwargs <: NamedTuple} <: AIE.NonIterativeAlgorithm - region::R - kwargs::Kwargs -end -TestRegion(region; kwargs...) = TestRegion(region, (; kwargs...)) - -function AI.solve_loop!(problem::TestProblem, algorithm::TestRegion, state::AIE.State) - new_iterate = (; algorithm.region, algorithm.kwargs.foo, algorithm.kwargs.bar) - state.iterate = [state.iterate; [new_iterate]] - return state -end - -@testset "Sweeping" begin - @testset "TestRegion" begin - algorithm = TestRegion("region"; foo = 1, bar = 2) - @test algorithm isa AIE.NonIterativeAlgorithm - @test algorithm isa AIE.Algorithm - @test algorithm isa AI.Algorithm - @test algorithm.region == "region" - @test algorithm.kwargs == (; foo = 1, bar = 2) - - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [(; region = "region", foo = 1, bar = 2)] - end - @testset "Sweep" begin - algorithm = AIE.nested_algorithm(3) do i - return TestRegion("region$i"; foo = i, bar = 2i) - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "region1", foo = 1, bar = 2), - (; region = "region2", foo = 2, bar = 4), - (; region = "region3", foo = 3, bar = 6), - ] - end - @testset "Sweeping" begin - algorithm = AIE.nested_algorithm(2) do i - AIE.nested_algorithm(3) do j - return TestRegion("sweep$i, region$j"; foo = (i, j), bar = (2i, 2j)) - end - end - problem = TestProblem() - iterate = [] - iterate = AI.solve(problem, algorithm; iterate) - @test iterate == [ - (; region = "sweep1, region1", foo = (1, 1), bar = (2, 2)), - (; region = "sweep1, region2", foo = (1, 2), bar = (2, 4)), - (; region = "sweep1, region3", foo = (1, 3), bar = (2, 6)), - (; region = "sweep2, region1", foo = (2, 1), bar = (4, 2)), - (; region = "sweep2, region2", foo = (2, 2), bar = (4, 4)), - (; region = "sweep2, region3", foo = (2, 3), bar = (4, 6)), - ] - end -end From e35b1a20edc60424d68b9a962ca1e66108a59c8e Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:19:43 -0400 Subject: [PATCH 03/16] Drop AIE `Problem` / `Algorithm` / `State` / `DefaultState` scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `AlgorithmsInterfaceExtensions` is now just the `NestedAlgorithm` abstract type plus `initialize_subsolve` / `finalize_substate!` / `AI.step!`. The `Problem` / `Algorithm` / `State` abstract types and `DefaultState` are gone, along with the `AI.initialize_state` / `initialize_state!` / `increment!` overloads that hung off them. Belief propagation grows its own state machinery instead of leaning on AIE: a `BeliefPropagationState <: AI.State` (mutable, `iterate` / `iteration` / `stopping_criterion_state`), plus per-algorithm `AI.initialize_state` / `initialize_state!` / `increment!` overloads on `Union{BeliefPropagation, BeliefPropagationSweep}`. This mirrors the pattern `BPApplyGate` already uses in the apply-operator path — each algorithm owns its state, no shared default. `BeliefPropagationProblem`, `BeliefPropagation`, and `BeliefPropagationSweep` now subtype the bare `AI.Problem` / `AI.Algorithm` types; the `StopWhenConverged` dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State` since `StopWhenConverged` itself is the unique type doing the disambiguation. `test_algorithmsinterfaceextensions.jl` is rewritten to define its own `TestProblem` / `TestChildAlgorithm` / `TestChildState` / `TestNestedAlgorithm` directly on `AI.*`, so the test exercises the same "each algorithm owns its state" pattern. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 47 +---- .../beliefpropagationproblem.jl | 80 ++++++-- test/test_algorithmsinterfaceextensions.jl | 182 ++++++++---------- 3 files changed, 151 insertions(+), 158 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index c91e24e..ddf1083 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -2,54 +2,9 @@ module AlgorithmsInterfaceExtensions import AlgorithmsInterface as AI -# ========================== Patches for AlgorithmsInterface.jl ============================ - -abstract type Problem <: AI.Problem end -abstract type Algorithm <: AI.Algorithm end -abstract type State <: AI.State end - -function AI.initialize_state!( - problem::Problem, algorithm::Algorithm, state::State; iteration = 0, kwargs... - ) - for (k, v) in pairs(kwargs) - setproperty!(state, k, v) - end - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state -end - -function AI.initialize_state( - problem::Problem, algorithm::Algorithm; iterate, kwargs... - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return DefaultState(; iterate, stopping_criterion_state, kwargs...) -end - -# ============================ DefaultState ================================================ - -@kwdef mutable struct DefaultState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -# ============================ increment! ================================================== - -# Custom version of `increment!` that also takes the problem and algorithm as arguments. -function AI.increment!(problem::Problem, algorithm::Algorithm, state::State) - return AI.increment!(state) -end - # ============================ NestedAlgorithm ============================================= -abstract type NestedAlgorithm <: Algorithm end +abstract type NestedAlgorithm <: AI.Algorithm end # Subtypes of `NestedAlgorithm` must override `initialize_subsolve` — it # returns the `(subproblem, subalgorithm, substate)` tuple that the next diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 8f88ff4..91402ff 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -31,13 +31,13 @@ end previous_iterate::Iterate end -function AI.initialize_state(::AIE.Problem, ::AIE.Algorithm, ::StopWhenConverged; iterate) +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) return StopWhenConvergedState(; previous_iterate = copy(iterate)) end function AI.initialize_state!( - ::AIE.Problem, - ::AIE.Algorithm, + ::AI.Problem, + ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) @@ -46,9 +46,9 @@ function AI.initialize_state!( end function AI.is_finished!( - problem::AIE.Problem, - algorithm::AIE.Algorithm, - state::AIE.State, + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, c::StopWhenConverged, st::StopWhenConvergedState ) @@ -73,19 +73,30 @@ function AI.is_finished!( end function AI.is_finished( - ::AIE.Problem, - ::AIE.Algorithm, - ::AIE.State, + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, c::StopWhenConverged, st::StopWhenConvergedState ) return st.delta < c.tol end -struct BeliefPropagationProblem{Factors} <: AIE.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors end +# Shared state type for the BP NestedAlgorithm subtypes +# (`BeliefPropagation` / `BeliefPropagationSweep`). Mirrors the small +# state shape `BPApplyGate` uses in the apply-operator path. +@kwdef mutable struct BeliefPropagationState{ + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + function iterate_diff( cache1::MessageCache, cache2::MessageCache @@ -98,7 +109,7 @@ function iterate_diff( end @kwdef struct BeliefPropagation{ - ChildAlgorithm <: AIE.Algorithm, + ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm @@ -152,12 +163,53 @@ function BeliefPropagationSweep(f::Function, edges) return BeliefPropagationSweep(; algorithms = f.(edges)) end +# State construction / reset / increment for the BP NestedAlgorithm +# pair. Both `BeliefPropagation` and `BeliefPropagationSweep` carry a +# `stopping_criterion`, so the standard AlgorithmsInterface state chain +# resolves the same way. +const BeliefPropagationLikeAlgorithm = Union{BeliefPropagation, BeliefPropagationSweep} + +function AI.initialize_state( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm; + iterate, kwargs... + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationState(; iterate, stopping_criterion_state, kwargs...) +end + +function AI.initialize_state!( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm, + state::BeliefPropagationState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.increment!( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationLikeAlgorithm, + state::BeliefPropagationState + ) + return AI.increment!(state) +end + # `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of # child algorithms. Each step picks the child algorithm by the current # iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, - algorithm::Union{BeliefPropagation, BeliefPropagationSweep}, + algorithm::BeliefPropagationLikeAlgorithm, state::AI.State ) subproblem = problem @@ -169,7 +221,7 @@ end function AIE.finalize_substate!( ::BeliefPropagationProblem, ::BeliefPropagationSweep, - state::AIE.DefaultState, + state::BeliefPropagationState, cache::MessageCache ) state.iterate = cache @@ -211,7 +263,7 @@ function beliefpropagation( if isnothing(maxiter) throw( ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when + "`maxiter` must be specified for non-tree graphs, even when `stopping_criterion` is provided." ) ) diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 0ec2d16..1345bfd 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -2,27 +2,60 @@ import AlgorithmsInterface as AI import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE using Test: @test, @test_throws, @testset -# Define test problems, algorithms, and states for testing -struct TestProblem <: AIE.Problem - data::Vector{Float64} +# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms +# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes +# itself on top of the minimal `AIE.NestedAlgorithm`. +struct TestProblem <: AI.Problem end + +@kwdef struct TestChildAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AI.Algorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(2) +end + +@kwdef mutable struct TestChildState{SCState <: AI.StoppingCriterionState} <: AI.State + iterate::Vector{Float64} + iteration::Int = 0 + stopping_criterion_state::SCState +end + +function AI.initialize_state( + problem::TestProblem, algorithm::TestChildAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) +end + +function AI.initialize_state!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state end -@kwdef struct TestAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AIE.Algorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(10) +function AI.increment!( + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState + ) + return AI.increment!(state) end function AI.step!( - problem::TestProblem, algorithm::TestAlgorithm, state::AIE.DefaultState + problem::TestProblem, algorithm::TestChildAlgorithm, state::TestChildState ) - state.iterate .+= 1 # Simple increment step + state.iterate .+= 1 return state end -# Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms -# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes -# itself on top of the new minimal `AIE.NestedAlgorithm`. @kwdef struct TestNestedAlgorithm{ - ChildAlgorithm <: AIE.Algorithm, + ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm @@ -30,6 +63,37 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end +# Reuse the child-state shape for the parent algorithm too. +function AI.initialize_state( + problem::TestProblem, algorithm::TestNestedAlgorithm; + iterate, kwargs... + ) + sc_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return TestChildState(; iterate, stopping_criterion_state = sc_state, kwargs...) +end + +function AI.initialize_state!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState; + iteration = 0, kwargs... + ) + for (k, v) in pairs(kwargs) + setproperty!(state, k, v) + end + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + +function AI.increment!( + problem::TestProblem, algorithm::TestNestedAlgorithm, state::TestChildState + ) + return AI.increment!(state) +end + function AIE.initialize_subsolve( problem::TestProblem, algorithm::TestNestedAlgorithm, state::AI.State ) @@ -40,113 +104,35 @@ function AIE.initialize_subsolve( end @testset "AlgorithmsInterfaceExtensions" begin - @testset "DefaultState" begin - iterate = [1.0, 2.0, 3.0] - stopping_criterion_state = AI.initialize_state( - TestProblem([1.0]), TestAlgorithm(), TestAlgorithm().stopping_criterion - ) - state = AIE.DefaultState(; iterate = copy(iterate), stopping_criterion_state) - @test state.iterate == iterate - @test state.iteration == 0 - @test state.stopping_criterion_state isa AI.StoppingCriterionState - - state.iteration = 5 - @test state.iteration == 5 - end - - @testset "initialize_state!" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; - iteration = 2, iterate = [0.0, 0.0], stopping_criterion_state - ) - AI.initialize_state!(problem, algorithm, state) - @test state.iterate == [0.0, 0.0] - @test state.iteration == 0 - @test state.stopping_criterion_state == stopping_criterion_state - end - - @testset "initialize_state" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - state = AI.initialize_state(problem, algorithm; iterate = [0.0, 0.0]) - @test state isa AIE.DefaultState - @test state.iteration == 0 - end - - @testset "increment!" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm() - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - state = AIE.DefaultState(; iterate = [0.0, 0.0], stopping_criterion_state) - - AI.increment!(problem, algorithm, state) - @test state.iteration == 1 - AI.increment!(problem, algorithm, state) - @test state.iteration == 2 - end - - @testset "solve! and solve" begin - problem = TestProblem([1.0, 2.0]) - algorithm = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(3)) - state = AI.initialize_state(problem, algorithm; iterate = [10.0, 20.0]) - - initial_iterate = [5.0, 10.0] - final_iterate = AI.solve!( - problem, algorithm, state; iterate = copy(initial_iterate) - ) - @test state.iteration == 3 - @test final_iterate == state.iterate - # Each step increments by 1, so after 3 steps: [5, 10] + 3 = [8, 13] - @test state.iterate ≈ [8.0, 13.0] - - problem2 = TestProblem([1.0, 2.0]) - algorithm2 = TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)) - final_iterate2 = AI.solve(problem2, algorithm2; iterate = [5.0, 10.0]) - @test final_iterate2 ≈ [7.0, 12.0] - end - @testset "NestedAlgorithm defaults" begin # The bare `initialize_subsolve` default throws a `MethodError`, # forcing concrete subtypes to provide their own override. - problem = TestProblem([1.0]) - algorithm = TestAlgorithm() - state = AIE.DefaultState(; - iterate = [0.0], - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion - ) - ) + problem = TestProblem() + algorithm = TestChildAlgorithm() + state = AI.initialize_state(problem, algorithm; iterate = [0.0]) @test_throws MethodError AIE.initialize_subsolve(problem, algorithm, state) # `finalize_substate!` copies the substate's iterate back into the # parent state. - substate = AIE.DefaultState(; - iterate = [42.0], - stopping_criterion_state = state.stopping_criterion_state - ) + substate = AI.initialize_state(problem, algorithm; iterate = [42.0]) AIE.finalize_substate!(problem, algorithm, state, substate) @test state.iterate == [42.0] end @testset "TestNestedAlgorithm" begin - problem = TestProblem([1.0, 2.0]) + problem = TestProblem() nested_alg = TestNestedAlgorithm(; algorithms = [ - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), - TestAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(1)), + TestChildAlgorithm(; stopping_criterion = AI.StopAfterIteration(2)), ] ) @test nested_alg isa AIE.NestedAlgorithm state = AI.initialize_state(problem, nested_alg; iterate = [0.0, 0.0]) AI.solve!(problem, nested_alg, state; iterate = [0.0, 0.0]) - # Two child algorithms: 1 iter + 2 iter = 3 inner increments total. + # Two child algorithms: 1 inner step + 2 inner steps = 3 total + # `state.iterate .+= 1` calls. @test state.iteration == 2 @test state.iterate ≈ [3.0, 3.0] end From b1821ce82ba6ac3fa6dd0e2cf8c0fc7677fc9d52 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:41:41 -0400 Subject: [PATCH 04/16] Refactor BP into three problem/algorithm/state triples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Layer 1 — outer BP loop: `BeliefPropagationProblem` / `BeliefPropagationAlgorithm` / `BeliefPropagationState` (iterative, `<: AIE.NestedAlgorithm`) Layer 2 — one sweep over edges: `BeliefPropagationSweepProblem` / `BeliefPropagationSweepAlgorithm` / `BeliefPropagationSweepState` (iterative, `<: AIE.NestedAlgorithm`) Layer 3 — single-edge message update: `MessageUpdateProblem` / `SimpleMessageUpdateAlgorithm` / `MessageUpdateState` (non-iterative, overrides `AI.solve_loop!` — same shape `BPApplyGate` uses in the gate-apply code) Every layer subtypes `AI.Problem` / `AI.Algorithm` / `AI.State` directly. No `Union{...}` shortcuts — each layer's `initialize_state` / `initialize_state!` / `initialize_subsolve` is its own method. The `StopWhenConverged` dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State` since `StopWhenConverged` itself is the unique disambiguating type. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 253 ++++++++++-------- test/test_algorithmsinterfaceextensions.jl | 4 +- 2 files changed, 140 insertions(+), 117 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 91402ff..1202e2b 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -21,6 +21,8 @@ function rows(nt::NamedTuple, len::Int = rowlength(nt)) return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end +# === `StopWhenConverged` stopping criterion === + @kwdef struct StopWhenConverged <: AI.StoppingCriterion tol::Float64 end @@ -36,10 +38,7 @@ function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; end function AI.initialize_state!( - ::AI.Problem, - ::AI.Algorithm, - ::StopWhenConverged, - st::StopWhenConvergedState + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) st.delta = Inf return st @@ -59,7 +58,7 @@ function AI.is_finished!( st.previous_iterate = copy(iterate) - # maxdiff = 0.0 initially, so skip this the first time. + # delta = 0 initially, so skip this the first time. state.iteration == 0 && return false st.delta = delta @@ -82,33 +81,78 @@ function AI.is_finished( return st.delta < c.tol end -struct BeliefPropagationProblem{Factors} <: AI.Problem +function iterate_diff(cache1::MessageCache, cache2::MessageCache) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + end +end + +# === Layer 3: single-edge message update (non-iterative) === + +struct MessageUpdateProblem{Factors} <: AI.Problem factors::Factors end -# Shared state type for the BP NestedAlgorithm subtypes -# (`BeliefPropagation` / `BeliefPropagationSweep`). Mirrors the small -# state shape `BPApplyGate` uses in the apply-operator path. -@kwdef mutable struct BeliefPropagationState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State +@kwdef struct SimpleMessageUpdateAlgorithm{ + E <: AbstractEdge, ContractionAlg, + } <: AI.Algorithm + edge::E + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" +end + +@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState end -function iterate_diff( - cache1::MessageCache, - cache2::MessageCache +function AI.initialize_state( + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate ) - return maximum(edges(cache1)) do edge - m1 = cache1[edge] - m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + return MessageUpdateState(; iterate) +end + +# Non-iterative algorithm: no per-call state to reset. +function AI.initialize_state!( + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState + ) + return state +end + +# Non-iterative algorithm: bypass the step!/stopping-criterion loop. +function AI.solve_loop!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdateAlgorithm, + state::MessageUpdateState + ) + cache = state.iterate + edge = algorithm.edge + + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] + + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end end + + cache[edge] = new_message + + return state +end + +# === Layer 2: one sweep over edges (iterative) === + +struct BeliefPropagationSweepProblem{Factors} <: AI.Problem + factors::Factors end -@kwdef struct BeliefPropagation{ +@kwdef struct BeliefPropagationSweepAlgorithm{ ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, @@ -117,78 +161,101 @@ end stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagation(f::Function, niterations::Int; kwargs...) - return BeliefPropagation(; algorithms = f.(1:niterations), kwargs...) +function BeliefPropagationSweepAlgorithm(f::Function, edges) + return BeliefPropagationSweepAlgorithm(; algorithms = f.(edges)) end -struct SimpleMessageUpdate{E <: AbstractEdge, Kwargs <: NamedTuple} - edge::E - kwargs::Kwargs +@kwdef mutable struct BeliefPropagationSweepState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState end -function SimpleMessageUpdate( - edge; - normalize = true, - contraction_alg = Algorithm"exact", - kwargs... +function AI.initialize_state( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm; + iterate, iteration::Int = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate ) - return SimpleMessageUpdate( - edge, - (; normalize, contraction_alg, kwargs...) + return BeliefPropagationSweepState(; + iterate, iteration, stopping_criterion_state ) end -function Base.getproperty(alg::SimpleMessageUpdate, name::Symbol) - if name in (:edge, :kwargs) - return getfield(alg, name) - else - return getproperty(getfield(alg, :kwargs), name) - end +function AI.initialize_state!( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state end -AI.initialize_state(::BeliefPropagationProblem, ::SimpleMessageUpdate; iterate) = iterate +function AIE.initialize_subsolve( + problem::BeliefPropagationSweepProblem, + algorithm::BeliefPropagationSweepAlgorithm, + state::BeliefPropagationSweepState + ) + subproblem = MessageUpdateProblem(problem.factors) + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate +end -struct BeliefPropagationSweep{ - ChildAlgorithm, Algorithms <: AbstractVector{ChildAlgorithm}, +# === Layer 1: BP outer loop (iterative) === + +struct BeliefPropagationProblem{Factors} <: AI.Problem + factors::Factors +end + +@kwdef struct BeliefPropagationAlgorithm{ + ChildAlgorithm <: AI.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm algorithms::Algorithms - stopping_criterion::AI.StopAfterIteration - function BeliefPropagationSweep(; algorithms) - stopping_criterion = AI.StopAfterIteration(length(algorithms)) - return new{eltype(algorithms), typeof(algorithms)}(algorithms, stopping_criterion) - end + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagationSweep(f::Function, edges) - return BeliefPropagationSweep(; algorithms = f.(edges)) +function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) + return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) end -# State construction / reset / increment for the BP NestedAlgorithm -# pair. Both `BeliefPropagation` and `BeliefPropagationSweep` carry a -# `stopping_criterion`, so the standard AlgorithmsInterface state chain -# resolves the same way. -const BeliefPropagationLikeAlgorithm = Union{BeliefPropagation, BeliefPropagationSweep} +@kwdef mutable struct BeliefPropagationState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState +end function AI.initialize_state( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm; - iterate, kwargs... + algorithm::BeliefPropagationAlgorithm; + iterate, iteration::Int = 0 ) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iterate, stopping_criterion_state, kwargs...) + return BeliefPropagationState(; + iterate, iteration, stopping_criterion_state + ) end function AI.initialize_state!( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, + algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState; - iteration = 0, kwargs... + iteration::Int = 0 ) - for (k, v) in pairs(kwargs) - setproperty!(state, k, v) - end state.iteration = iteration AI.initialize_state!( problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state @@ -196,62 +263,18 @@ function AI.initialize_state!( return state end -function AI.increment!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, - state::BeliefPropagationState - ) - return AI.increment!(state) -end - -# `BeliefPropagation` and `BeliefPropagationSweep` carry a flat list of -# child algorithms. Each step picks the child algorithm by the current -# iteration index and reuses the parent problem. function AIE.initialize_subsolve( problem::BeliefPropagationProblem, - algorithm::BeliefPropagationLikeAlgorithm, - state::AI.State + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState ) - subproblem = problem + subproblem = BeliefPropagationSweepProblem(problem.factors) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end -function AIE.finalize_substate!( - ::BeliefPropagationProblem, - ::BeliefPropagationSweep, - state::BeliefPropagationState, - cache::MessageCache - ) - state.iterate = cache - - return state -end - -function AI.solve!( - problem::BeliefPropagationProblem, - algorithm::SimpleMessageUpdate, - cache::MessageCache - ) - edge = algorithm.edge - - messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] - - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) - - if algorithm.normalize - message_norm = sum(new_message) - if !iszero(message_norm) - new_message /= message_norm - end - end - - cache[edge] = new_message - - return cache -end +# === Top-level user entry point === function beliefpropagation( factors, messages; @@ -287,9 +310,9 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - algorithm = BeliefPropagation(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweep(edges) do edge - return SimpleMessageUpdate(edge; edge_kwargs[repnum]...) + algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + return BeliefPropagationSweepAlgorithm(edges) do edge + return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) end end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index 1345bfd..f580826 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -3,8 +3,8 @@ import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE using Test: @test, @test_throws, @testset # Concrete `NestedAlgorithm` subtype: holds a flat list of child algorithms -# and picks them by iteration index. Mirrors how `BeliefPropagation` shapes -# itself on top of the minimal `AIE.NestedAlgorithm`. +# and picks them by iteration index. Mirrors how `BeliefPropagationAlgorithm` +# shapes itself on top of the minimal `AIE.NestedAlgorithm`. struct TestProblem <: AI.Problem end @kwdef struct TestChildAlgorithm{StoppingCriterion <: AI.StoppingCriterion} <: AI.Algorithm From fdb3c04ab4f188c6494de432044e67eb75694085 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 16:46:48 -0400 Subject: [PATCH 05/16] Reorder BP source top-to-bottom MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit API first, then the three layers from outer (BP) to inner (message update), then the supporting `StopWhenConverged` stopping criterion + `iterate_diff` helper, then the low-level kwarg utilities at the bottom. Reads as "what's the API → how it's composed → supporting pieces" the way the file is most likely to be skimmed. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 355 +++++++++--------- 1 file changed, 178 insertions(+), 177 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 1202e2b..44b42ed 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -8,142 +8,115 @@ using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph using NamedGraphs.PartitionedGraphs: quotientvertices -# Utility functions for processing keyword arguments. -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end - -# === `StopWhenConverged` stopping criterion === - -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 -end - -@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState - delta::Float64 = Inf - at_iteration::Int = -1 - previous_iterate::Iterate -end - -function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) - return StopWhenConvergedState(; previous_iterate = copy(iterate)) -end - -function AI.initialize_state!( - ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState - ) - st.delta = Inf - return st -end +# === Top-level user entry point === -function AI.is_finished!( - problem::AI.Problem, - algorithm::AI.Algorithm, - state::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState +function beliefpropagation( + factors, messages; + edges = nothing, + maxiter = is_tree(factors) ? 1 : nothing, + stopping_criterion = nothing, + kwargs... ) - iterate = state.iterate - previous_iterate = st.previous_iterate + if isnothing(maxiter) + throw( + ArgumentError( + "`maxiter` must be specified for non-tree graphs, even when + `stopping_criterion` is provided." + ) + ) + end - delta = iterate_diff(iterate, previous_iterate) + cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) - st.previous_iterate = copy(iterate) + ## Algorithm construction: - # delta = 0 initially, so skip this the first time. - state.iteration == 0 && return false + edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - st.delta = delta + base_stopping_criterion = AI.StopAfterIteration(maxiter) - if AI.is_finished(problem, algorithm, state, c, st) - st.at_iteration = state.iteration - return true + if !isnothing(stopping_criterion) + base_stopping_criterion |= stopping_criterion end - return false -end + stopping_criterion = base_stopping_criterion -function AI.is_finished( - ::AI.Problem, - ::AI.Algorithm, - ::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - return st.delta < c.tol -end + extended_kwargs = extend_columns((; kwargs...), maxiter) + edge_kwargs = rows(extended_kwargs, maxiter) -function iterate_diff(cache1::MessageCache, cache2::MessageCache) - return maximum(edges(cache1)) do edge - m1 = cache1[edge] - m2 = cache2[edge] - return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) + algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + return BeliefPropagationSweepAlgorithm(edges) do edge + return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) + end end + + ## + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end -# === Layer 3: single-edge message update (non-iterative) === +# === Layer 1: BP outer loop (iterative) === -struct MessageUpdateProblem{Factors} <: AI.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors end -@kwdef struct SimpleMessageUpdateAlgorithm{ - E <: AbstractEdge, ContractionAlg, - } <: AI.Algorithm - edge::E - normalize::Bool = true - contraction_alg::ContractionAlg = Algorithm"exact" +@kwdef struct BeliefPropagationAlgorithm{ + ChildAlgorithm <: AI.Algorithm, + Algorithms <: AbstractVector{ChildAlgorithm}, + StoppingCriterion <: AI.StoppingCriterion, + } <: AIE.NestedAlgorithm + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State +function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) + return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +end + +@kwdef mutable struct BeliefPropagationState{ + Iterate, SCState <: AI.StoppingCriterionState, + } <: AI.State iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::SCState end function AI.initialize_state( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm; + iterate, iteration::Int = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return BeliefPropagationState(; + iterate, iteration, stopping_criterion_state ) - return MessageUpdateState(; iterate) end -# Non-iterative algorithm: no per-call state to reset. function AI.initialize_state!( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState; + iteration::Int = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state ) return state end -# Non-iterative algorithm: bypass the step!/stopping-criterion loop. -function AI.solve_loop!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdateAlgorithm, - state::MessageUpdateState +function AIE.initialize_subsolve( + problem::BeliefPropagationProblem, + algorithm::BeliefPropagationAlgorithm, + state::BeliefPropagationState ) - cache = state.iterate - edge = algorithm.edge - - messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] - - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) - - if algorithm.normalize - message_norm = sum(new_message) - if !iszero(message_norm) - new_message /= message_norm - end - end - - cache[edge] = new_message - - return state + subproblem = BeliefPropagationSweepProblem(problem.factors) + subalgorithm = algorithm.algorithms[state.iteration] + substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) + return subproblem, subalgorithm, substate end # === Layer 2: one sweep over edges (iterative) === @@ -210,113 +183,141 @@ function AIE.initialize_subsolve( return subproblem, subalgorithm, substate end -# === Layer 1: BP outer loop (iterative) === +# === Layer 3: single-edge message update (non-iterative) === -struct BeliefPropagationProblem{Factors} <: AI.Problem +struct MessageUpdateProblem{Factors} <: AI.Problem factors::Factors end -@kwdef struct BeliefPropagationAlgorithm{ - ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, - StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) - return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +@kwdef struct SimpleMessageUpdateAlgorithm{ + E <: AbstractEdge, ContractionAlg, + } <: AI.Algorithm + edge::E + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" end -@kwdef mutable struct BeliefPropagationState{ - Iterate, SCState <: AI.StoppingCriterionState, - } <: AI.State +@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::SCState end function AI.initialize_state( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm; - iterate, iteration::Int = 0 - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return BeliefPropagationState(; - iterate, iteration, stopping_criterion_state + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate ) + return MessageUpdateState(; iterate) end +# Non-iterative algorithm: no per-call state to reset. function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState ) return state end -function AIE.initialize_subsolve( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState +# Non-iterative algorithm: bypass the step!/stopping-criterion loop. +function AI.solve_loop!( + problem::MessageUpdateProblem, + algorithm::SimpleMessageUpdateAlgorithm, + state::MessageUpdateState ) - subproblem = BeliefPropagationSweepProblem(problem.factors) - subalgorithm = algorithm.algorithms[state.iteration] - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + cache = state.iterate + edge = algorithm.edge + + messages = collect(incoming_messages(cache, edge)) + factor = problem.factors[src(edge)] + + new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + + if algorithm.normalize + message_norm = sum(new_message) + if !iszero(message_norm) + new_message /= message_norm + end + end + + cache[edge] = new_message + + return state end -# === Top-level user entry point === +# === `StopWhenConverged` stopping criterion === -function beliefpropagation( - factors, messages; - edges = nothing, - maxiter = is_tree(factors) ? 1 : nothing, - stopping_criterion = nothing, - kwargs... +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 +end + +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate +end + +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) +end + +function AI.initialize_state!( + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState ) - if isnothing(maxiter) - throw( - ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when - `stopping_criterion` is provided." - ) - ) - end + st.delta = Inf + return st +end - cache = MessageCache(messages) - problem = BeliefPropagationProblem(factors) +function AI.is_finished!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate - ## Algorithm construction: + delta = iterate_diff(iterate, previous_iterate) - edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges + st.previous_iterate = copy(iterate) - base_stopping_criterion = AI.StopAfterIteration(maxiter) + # delta = 0 initially, so skip this the first time. + state.iteration == 0 && return false - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true end - stopping_criterion = base_stopping_criterion + return false +end - extended_kwargs = extend_columns((; kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) +function AI.is_finished( + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end - algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do edge - return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) - end +function iterate_diff(cache1::MessageCache, cache2::MessageCache) + return maximum(edges(cache1)) do edge + m1 = cache1[edge] + m2 = cache2[edge] + return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) end +end - ## +# === Utility functions for processing keyword arguments === - return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +function repeat_last(v::AbstractVector, len::Int) + return [v; fill(v[end], max(len - length(v), 0))] +end +repeat_last(v, len::Int) = fill(v, len) +function extend_columns(nt::NamedTuple, len::Int) + return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) +end +rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) +function rows(nt::NamedTuple, len::Int = rowlength(nt)) + return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] end From 56b2d82570c6cc02551be85e7811d66557b2173d Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:10:38 -0400 Subject: [PATCH 06/16] Move StopWhenConverged + iterate_diff verb into AIE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `StopWhenConverged` is fully generic — its dispatches accept `AI.Problem` / `AI.Algorithm` / `AI.State`, with `StopWhenConverged` itself as the unique disambiguating type. The only piece that needed to live in BP was the metric: the BP `iterate_diff(::MessageCache, ::MessageCache)` override. AIE now owns: `StopWhenConverged`, `StopWhenConvergedState`, the four `AI.*` overloads, and a bare `iterate_diff(a, b)` verb that throws by default. BP provides its concrete `AIE.iterate_diff(::MessageCache, ::MessageCache)`. The top-level `ITensorNetworksNext` namespace re-exports `StopWhenConverged` + `iterate_diff` for callers that use the package-qualified names. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 66 +++++++++++++++++++ .../beliefpropagationproblem.jl | 63 +----------------- 2 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index ddf1083..c7ac727 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -31,4 +31,70 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end +# ============================ StopWhenConverged =========================================== + +# Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`. +# Concrete iterate types must supply an `iterate_diff` method. +function iterate_diff(a, b) + return throw(MethodError(iterate_diff, (a, b))) +end + +@kwdef struct StopWhenConverged <: AI.StoppingCriterion + tol::Float64 +end + +@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState + delta::Float64 = Inf + at_iteration::Int = -1 + previous_iterate::Iterate +end + +function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) + return StopWhenConvergedState(; previous_iterate = copy(iterate)) +end + +function AI.initialize_state!( + ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState + ) + st.delta = Inf + return st +end + +function AI.is_finished!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + iterate = state.iterate + previous_iterate = st.previous_iterate + + delta = iterate_diff(iterate, previous_iterate) + + st.previous_iterate = copy(iterate) + + # delta = 0 initially, so skip this the first time. + state.iteration == 0 && return false + + st.delta = delta + + if AI.is_finished(problem, algorithm, state, c, st) + st.at_iteration = state.iteration + return true + end + + return false +end + +function AI.is_finished( + ::AI.Problem, + ::AI.Algorithm, + ::AI.State, + c::StopWhenConverged, + st::StopWhenConvergedState + ) + return st.delta < c.tol +end + end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 44b42ed..32e00ad 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -1,5 +1,6 @@ import .AlgorithmsInterfaceExtensions as AIE import AlgorithmsInterface as AI +using .AlgorithmsInterfaceExtensions: StopWhenConverged, iterate_diff using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: edge_data using Graphs: AbstractEdge, edges, has_edge, vertices @@ -240,67 +241,9 @@ function AI.solve_loop!( return state end -# === `StopWhenConverged` stopping criterion === +# === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === -@kwdef struct StopWhenConverged <: AI.StoppingCriterion - tol::Float64 -end - -@kwdef mutable struct StopWhenConvergedState{Iterate} <: AI.StoppingCriterionState - delta::Float64 = Inf - at_iteration::Int = -1 - previous_iterate::Iterate -end - -function AI.initialize_state(::AI.Problem, ::AI.Algorithm, ::StopWhenConverged; iterate) - return StopWhenConvergedState(; previous_iterate = copy(iterate)) -end - -function AI.initialize_state!( - ::AI.Problem, ::AI.Algorithm, ::StopWhenConverged, st::StopWhenConvergedState - ) - st.delta = Inf - return st -end - -function AI.is_finished!( - problem::AI.Problem, - algorithm::AI.Algorithm, - state::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - iterate = state.iterate - previous_iterate = st.previous_iterate - - delta = iterate_diff(iterate, previous_iterate) - - st.previous_iterate = copy(iterate) - - # delta = 0 initially, so skip this the first time. - state.iteration == 0 && return false - - st.delta = delta - - if AI.is_finished(problem, algorithm, state, c, st) - st.at_iteration = state.iteration - return true - end - - return false -end - -function AI.is_finished( - ::AI.Problem, - ::AI.Algorithm, - ::AI.State, - c::StopWhenConverged, - st::StopWhenConvergedState - ) - return st.delta < c.tol -end - -function iterate_diff(cache1::MessageCache, cache2::MessageCache) +function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) return maximum(edges(cache1)) do edge m1 = cache1[edge] m2 = cache2[edge] From 4fd8a47689e23598e8968de3b726b424ec95755e Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:36:34 -0400 Subject: [PATCH 07/16] Move `edge` field from `SimpleMessageUpdateAlgorithm` to `MessageUpdateProblem` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The edge is per-step data; it belongs on the problem side, not the algorithm side. After the move: - `MessageUpdateProblem{Factors, Edge <: AbstractEdge}` carries `factors` + `edge` (type parameter renamed from `E` to `Edge`). - `SimpleMessageUpdateAlgorithm{ContractionAlg}` is now just `normalize` + `contraction_alg` — no per-edge data. - `BeliefPropagationSweepProblem` gains an `edges` field; its `initialize_subsolve` picks `edge = problem.edges[state.iteration]` and threads it into the `MessageUpdateProblem`. - `BeliefPropagationProblem` gains an `edges` field too, and the outer `initialize_subsolve` threads it down to the sweep subproblem. - The entry-point `do edge` closure becomes `do _` since the algorithm no longer needs the edge. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 32e00ad..640f8e0 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -28,12 +28,13 @@ function beliefpropagation( end cache = MessageCache(messages) - problem = BeliefPropagationProblem(factors) ## Algorithm construction: edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges + problem = BeliefPropagationProblem(factors, edges) + base_stopping_criterion = AI.StopAfterIteration(maxiter) if !isnothing(stopping_criterion) @@ -46,8 +47,8 @@ function beliefpropagation( edge_kwargs = rows(extended_kwargs, maxiter) algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do edge - return SimpleMessageUpdateAlgorithm(; edge, edge_kwargs[repnum]...) + return BeliefPropagationSweepAlgorithm(edges) do _ + return SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]...) end end @@ -58,8 +59,9 @@ end # === Layer 1: BP outer loop (iterative) === -struct BeliefPropagationProblem{Factors} <: AI.Problem +struct BeliefPropagationProblem{Factors, Edges} <: AI.Problem factors::Factors + edges::Edges end @kwdef struct BeliefPropagationAlgorithm{ @@ -114,7 +116,7 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subproblem = BeliefPropagationSweepProblem(problem.factors) + subproblem = BeliefPropagationSweepProblem(problem.factors, problem.edges) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -122,8 +124,9 @@ end # === Layer 2: one sweep over edges (iterative) === -struct BeliefPropagationSweepProblem{Factors} <: AI.Problem +struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem factors::Factors + edges::Edges end @kwdef struct BeliefPropagationSweepAlgorithm{ @@ -178,7 +181,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationSweepAlgorithm, state::BeliefPropagationSweepState ) - subproblem = MessageUpdateProblem(problem.factors) + edge = problem.edges[state.iteration] + subproblem = MessageUpdateProblem(problem.factors, edge) subalgorithm = algorithm.algorithms[state.iteration] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -186,14 +190,12 @@ end # === Layer 3: single-edge message update (non-iterative) === -struct MessageUpdateProblem{Factors} <: AI.Problem +struct MessageUpdateProblem{Factors, Edge <: AbstractEdge} <: AI.Problem factors::Factors + edge::Edge end -@kwdef struct SimpleMessageUpdateAlgorithm{ - E <: AbstractEdge, ContractionAlg, - } <: AI.Algorithm - edge::E +@kwdef struct SimpleMessageUpdateAlgorithm{ContractionAlg} <: AI.Algorithm normalize::Bool = true contraction_alg::ContractionAlg = Algorithm"exact" end @@ -222,7 +224,7 @@ function AI.solve_loop!( state::MessageUpdateState ) cache = state.iterate - edge = algorithm.edge + edge = problem.edge messages = collect(incoming_messages(cache, edge)) factor = problem.factors[src(edge)] From 81f7b4f8863d660f28527ff2757e3d48f5fa5102 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 17:46:49 -0400 Subject: [PATCH 08/16] Move BP edge ordering from `BeliefPropagationProblem` to `BeliefPropagationSweepAlgorithm` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit At the outer BP layer the edge update order is an algorithmic choice (which edges to sweep and in what order, potentially varying per rep), not problem data. So edges now live on `BeliefPropagationSweepAlgorithm` — one sweep algorithm is "do a sweep with these edges". The outer `initialize_subsolve` transfers them into a `BeliefPropagationSweepProblem` when stepping down a layer. Side effect: `BeliefPropagationSweepAlgorithm` no longer needs an `algorithms::Vector` field. The per-edge `SimpleMessageUpdateAlgorithm`s were always identical copies (they no longer carry an edge), so it collapses to a single `message_update_algorithm` template — matching the shape `ApplyOperators` uses in the gate-apply code. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 640f8e0..887b354 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -28,13 +28,12 @@ function beliefpropagation( end cache = MessageCache(messages) + problem = BeliefPropagationProblem(factors) ## Algorithm construction: edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - problem = BeliefPropagationProblem(factors, edges) - base_stopping_criterion = AI.StopAfterIteration(maxiter) if !isnothing(stopping_criterion) @@ -47,9 +46,12 @@ function beliefpropagation( edge_kwargs = rows(extended_kwargs, maxiter) algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum - return BeliefPropagationSweepAlgorithm(edges) do _ - return SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]...) - end + return BeliefPropagationSweepAlgorithm(; + edges, + message_update_algorithm = SimpleMessageUpdateAlgorithm(; + edge_kwargs[repnum]... + ) + ) end ## @@ -59,9 +61,8 @@ end # === Layer 1: BP outer loop (iterative) === -struct BeliefPropagationProblem{Factors, Edges} <: AI.Problem +struct BeliefPropagationProblem{Factors} <: AI.Problem factors::Factors - edges::Edges end @kwdef struct BeliefPropagationAlgorithm{ @@ -116,8 +117,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subproblem = BeliefPropagationSweepProblem(problem.factors, problem.edges) subalgorithm = algorithm.algorithms[state.iteration] + subproblem = BeliefPropagationSweepProblem(problem.factors, subalgorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -130,16 +131,13 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ + Edges, ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationSweepAlgorithm(f::Function, edges) - return BeliefPropagationSweepAlgorithm(; algorithms = f.(edges)) + edges::Edges + message_update_algorithm::ChildAlgorithm + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(edges)) end @kwdef mutable struct BeliefPropagationSweepState{ @@ -183,7 +181,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.algorithms[state.iteration] + subalgorithm = algorithm.message_update_algorithm substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end From 6dbf365b1b4205816664e545ef3340595478beb9 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:39:29 -0400 Subject: [PATCH 09/16] Store edges on `BeliefPropagationAlgorithm`, not the sweep algorithm Edges shouldn't be stored at both the outer BP algorithm level (as the algorithmic edge-ordering choice) and the sweep algorithm level (as runtime data that gets copied into the sweep problem). Keep them only on `BeliefPropagationAlgorithm`; the outer `initialize_subsolve` then transfers them into the `BeliefPropagationSweepProblem` when stepping down. `BeliefPropagationSweepAlgorithm` is now just `message_update_algorithm` + `stopping_criterion` (the stopping criterion is sized to the edges at construction time in `beliefpropagation()`). Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 887b354..95a4e08 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -45,12 +45,13 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - algorithm = BeliefPropagationAlgorithm(maxiter; stopping_criterion) do repnum + sweep_stopping_criterion = AI.StopAfterIteration(length(edges)) + algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum return BeliefPropagationSweepAlgorithm(; - edges, message_update_algorithm = SimpleMessageUpdateAlgorithm(; edge_kwargs[repnum]... - ) + ), + stopping_criterion = sweep_stopping_criterion ) end @@ -66,16 +67,18 @@ struct BeliefPropagationProblem{Factors} <: AI.Problem end @kwdef struct BeliefPropagationAlgorithm{ + Edges, ChildAlgorithm <: AI.Algorithm, Algorithms <: AbstractVector{ChildAlgorithm}, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm + edges::Edges algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end -function BeliefPropagationAlgorithm(f::Function, niterations::Int; kwargs...) - return BeliefPropagationAlgorithm(; algorithms = f.(1:niterations), kwargs...) +function BeliefPropagationAlgorithm(f::Function, niterations::Int; edges, kwargs...) + return BeliefPropagationAlgorithm(; edges, algorithms = f.(1:niterations), kwargs...) end @kwdef mutable struct BeliefPropagationState{ @@ -118,7 +121,7 @@ function AIE.initialize_subsolve( state::BeliefPropagationState ) subalgorithm = algorithm.algorithms[state.iteration] - subproblem = BeliefPropagationSweepProblem(problem.factors, subalgorithm.edges) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -131,13 +134,11 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - Edges, ChildAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - edges::Edges message_update_algorithm::ChildAlgorithm - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(edges)) + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationSweepState{ From 0d871c20f5400dc0269f07baaec3c89b98df15ba Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:54:35 -0400 Subject: [PATCH 10/16] Index per-edge BP algorithms by edge; drop `AbstractVector` constraints Make `BeliefPropagationSweepAlgorithm.algorithms` indexable by edge (default: a `Dict{Edge, MessageUpdateAlgorithm}` populated with the same template at construction time), and drop the `Algorithms <: AbstractVector{ChildAlgorithm}` constraint from both `BeliefPropagationAlgorithm` and `BeliefPropagationSweepAlgorithm` so any indexable container works. Edges remain sourced from `BeliefPropagationSweepProblem` (the source of truth); the sweep `initialize_subsolve` looks up `algorithm.algorithms[edge]` to pick the per-edge update algorithm. Co-Authored-By: Claude Opus 4.7 --- .../beliefpropagationproblem.jl | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl index 95a4e08..2dd7b32 100644 --- a/src/beliefpropagation/beliefpropagationproblem.jl +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -45,13 +45,12 @@ function beliefpropagation( extended_kwargs = extend_columns((; kwargs...), maxiter) edge_kwargs = rows(extended_kwargs, maxiter) - sweep_stopping_criterion = AI.StopAfterIteration(length(edges)) algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum + message_update_algorithm = SimpleMessageUpdateAlgorithm(; + edge_kwargs[repnum]... + ) return BeliefPropagationSweepAlgorithm(; - message_update_algorithm = SimpleMessageUpdateAlgorithm(; - edge_kwargs[repnum]... - ), - stopping_criterion = sweep_stopping_criterion + algorithms = Dict(edge => message_update_algorithm for edge in edges) ) end @@ -68,11 +67,11 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - ChildAlgorithm <: AI.Algorithm, - Algorithms <: AbstractVector{ChildAlgorithm}, + Algorithms, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges + # Indexable by iteration count (e.g. `Vector` or `Dict{Int, ...}`). algorithms::Algorithms stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end @@ -134,11 +133,14 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - ChildAlgorithm <: AI.Algorithm, + Algorithms, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - message_update_algorithm::ChildAlgorithm - stopping_criterion::StoppingCriterion + # Indexable by edge (e.g. `Dict{Edge, MessageUpdateAlgorithm}`); the + # default constructor in `beliefpropagation()` builds one with the same + # template copied across every edge. + algorithms::Algorithms + stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) end @kwdef mutable struct BeliefPropagationSweepState{ @@ -182,7 +184,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.message_update_algorithm + subalgorithm = algorithm.algorithms[edge] substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end From e5a58d05dc177a9dba07a4fd7ac17d3929c17d31 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 18:57:53 -0400 Subject: [PATCH 11/16] Rename `beliefpropagationproblem.jl` to `beliefpropagation.jl` The file now defines three problem/algorithm/state triples plus the top-level `beliefpropagation` entry point, so the old name is misleading. Co-Authored-By: Claude Opus 4.7 --- src/ITensorNetworksNext.jl | 2 +- .../{beliefpropagationproblem.jl => beliefpropagation.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/beliefpropagation/{beliefpropagationproblem.jl => beliefpropagation.jl} (100%) diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 394546a..41ce78e 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -14,6 +14,6 @@ include("TensorNetworkGenerators/TensorNetworkGenerators.jl") include("contract_network.jl") include("beliefpropagation/messagecache.jl") -include("beliefpropagation/beliefpropagationproblem.jl") +include("beliefpropagation/beliefpropagation.jl") end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagation.jl similarity index 100% rename from src/beliefpropagation/beliefpropagationproblem.jl rename to src/beliefpropagation/beliefpropagation.jl From 2e8ee3b26a070fc1ce33d89ceafcbbac056f562f Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 21:03:05 -0400 Subject: [PATCH 12/16] Simplify BP API: single child algorithms + select_* selectors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Public-API changes: * `beliefpropagation` no longer takes a `maxiter` kwarg; pass either `stopping_criterion = (; maxiter = 10)` or a full criterion such as `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`. * The `message_update_algorithm` kwarg accepts either a NamedTuple of keyword arguments forwarded to `SimpleMessageUpdateAlgorithm`, or a full `AI.Algorithm`. * The `edges` kwarg defaults to `default_beliefpropagation_edges(factors)` rather than being computed inside the body. * Per-iteration kwarg broadcasting (`extend_columns`, `rows`, `rowlength`, `repeat_last`) is removed. Internals: * `BeliefPropagationSweepAlgorithm` now holds a single `message_update_algorithm` (was a `Dict{Edge, ...}` of per-edge algorithms). Edge-dependent updates are expressed by defining a new `AI.Algorithm` subtype and dispatching on `problem.edge` in `solve_loop!`. * `BeliefPropagationAlgorithm` holds a single `sweep_algorithm` (was a `Vector` of per-iteration sweep algorithms). Iteration-dependent sweep behavior is expressed by defining a new sweep algorithm subtype and varying the inner algorithm in `AIE.initialize_subsolve` using `state.iteration`. * New `select_*` selectors `select_beliefpropagation_stopping_criterion` and `select_message_update_algorithm`, modeled on `MatrixAlgebraKit.select_truncation`, normalize NamedTuple-or-object inputs into algorithm/criterion objects. * `default_beliefpropagation_edges` and `default_message_update_algorithm` expose the defaults as named functions. * Stopping-criterion default removal: the entry point no longer constructs a `StopAfterIteration(maxiter)` itself — choice of criterion is entirely the caller's responsibility, except for the inner sweep, which still defaults to `StopAfterIteration(length(edges))`. * Type parameter `SCState` renamed to `StoppingCriterionState` on both BP state structs. * Helper definitions reordered so default/select helpers sit above `beliefpropagation` for readability. Co-Authored-By: Claude Opus 4.7 --- src/beliefpropagation/beliefpropagation.jl | 131 ++++++++++----------- test/test_beliefpropagation.jl | 16 +-- 2 files changed, 73 insertions(+), 74 deletions(-) diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 2dd7b32..667debd 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -11,50 +11,69 @@ using NamedGraphs.PartitionedGraphs: quotientvertices # === Top-level user entry point === -function beliefpropagation( - factors, messages; - edges = nothing, - maxiter = is_tree(factors) ? 1 : nothing, - stopping_criterion = nothing, - kwargs... +default_beliefpropagation_edges(graph) = forest_cover_edge_sequence(graph) + +default_message_update_algorithm(; kwargs...) = SimpleMessageUpdateAlgorithm(; kwargs...) + +select_message_update_algorithm(algorithm::AI.Algorithm) = algorithm +function select_message_update_algorithm(kwargs::NamedTuple) + return default_message_update_algorithm(; kwargs...) +end + +function select_beliefpropagation_stopping_criterion( + stopping_criterion::AI.StoppingCriterion + ) + return stopping_criterion +end +function select_beliefpropagation_stopping_criterion(::Nothing) + return throw( + ArgumentError( + "`stopping_criterion` must be specified, e.g.\n" * + " `stopping_criterion = (; maxiter = 10)` or\n" * + " `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`." + ) ) +end +function select_beliefpropagation_stopping_criterion(kwargs::NamedTuple) + return select_beliefpropagation_stopping_criterion(; kwargs...) +end +function select_beliefpropagation_stopping_criterion(; maxiter = nothing, kwargs...) if isnothing(maxiter) + throw(ArgumentError("`maxiter` must be specified in `stopping_criterion`.")) + end + if !isempty(kwargs) throw( ArgumentError( - "`maxiter` must be specified for non-tree graphs, even when - `stopping_criterion` is provided." + "Unrecognized `stopping_criterion` kwargs: $(keys(kwargs)). " * + "Only `maxiter` is currently supported." ) ) end + return AI.StopAfterIteration(maxiter) +end - cache = MessageCache(messages) +function beliefpropagation( + factors, messages; + edges = default_beliefpropagation_edges(factors), + stopping_criterion = nothing, + message_update_algorithm = default_message_update_algorithm() + ) problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) - ## Algorithm construction: - - edges = isnothing(edges) ? forest_cover_edge_sequence(cache) : edges - - base_stopping_criterion = AI.StopAfterIteration(maxiter) - - if !isnothing(stopping_criterion) - base_stopping_criterion |= stopping_criterion - end - - stopping_criterion = base_stopping_criterion - - extended_kwargs = extend_columns((; kwargs...), maxiter) - edge_kwargs = rows(extended_kwargs, maxiter) + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - algorithm = BeliefPropagationAlgorithm(maxiter; edges, stopping_criterion) do repnum - message_update_algorithm = SimpleMessageUpdateAlgorithm(; - edge_kwargs[repnum]... - ) - return BeliefPropagationSweepAlgorithm(; - algorithms = Dict(edge => message_update_algorithm for edge in edges) - ) - end + sweep_algorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(edges)) + ) - ## + algorithm = BeliefPropagationAlgorithm(; + edges, + sweep_algorithm, + stopping_criterion + ) return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end @@ -67,25 +86,20 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - Algorithms, + SweepAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges - # Indexable by iteration count (e.g. `Vector` or `Dict{Int, ...}`). - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) -end - -function BeliefPropagationAlgorithm(f::Function, niterations::Int; edges, kwargs...) - return BeliefPropagationAlgorithm(; edges, algorithms = f.(1:niterations), kwargs...) + sweep_algorithm::SweepAlgorithm + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationState{ - Iterate, SCState <: AI.StoppingCriterionState, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, } <: AI.State iterate::Iterate iteration::Int = 0 - stopping_criterion_state::SCState + stopping_criterion_state::StoppingCriterionState end function AI.initialize_state( @@ -119,7 +133,7 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subalgorithm = algorithm.algorithms[state.iteration] + subalgorithm = algorithm.sweep_algorithm subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate @@ -133,22 +147,19 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - Algorithms, + MessageUpdateAlgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm - # Indexable by edge (e.g. `Dict{Edge, MessageUpdateAlgorithm}`); the - # default constructor in `beliefpropagation()` builds one with the same - # template copied across every edge. - algorithms::Algorithms - stopping_criterion::StoppingCriterion = AI.StopAfterIteration(length(algorithms)) + message_update_algorithm::MessageUpdateAlgorithm + stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationSweepState{ - Iterate, SCState <: AI.StoppingCriterionState, + Iterate, StoppingCriterionState <: AI.StoppingCriterionState, } <: AI.State iterate::Iterate iteration::Int = 0 - stopping_criterion_state::SCState + stopping_criterion_state::StoppingCriterionState end function AI.initialize_state( @@ -184,7 +195,7 @@ function AIE.initialize_subsolve( ) edge = problem.edges[state.iteration] subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.algorithms[edge] + subalgorithm = algorithm.message_update_algorithm substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) return subproblem, subalgorithm, substate end @@ -230,7 +241,7 @@ function AI.solve_loop!( messages = collect(incoming_messages(cache, edge)) factor = problem.factors[src(edge)] - new_message = contract_network(vcat(messages, [factor]); algorithm.contraction_alg) + new_message = contract_network([messages; [factor]]; algorithm.contraction_alg) if algorithm.normalize message_norm = sum(new_message) @@ -253,17 +264,3 @@ function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) return 1 - abs2(LinearAlgebra.dot(normalize(m1), normalize(m2))) end end - -# === Utility functions for processing keyword arguments === - -function repeat_last(v::AbstractVector, len::Int) - return [v; fill(v[end], max(len - length(v), 0))] -end -repeat_last(v, len::Int) = fill(v, len) -function extend_columns(nt::NamedTuple, len::Int) - return (; (keys(nt) .=> map(v -> repeat_last(v, len), values(nt)))...) -end -rowlength(nt::NamedTuple) = only(unique(length.(values(nt)))) -function rows(nt::NamedTuple, len::Int = rowlength(nt)) - return [(; (keys(nt) .=> map(v -> v[i], values(nt)))...) for i in 1:len] -end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 01ca6e7..4362f7e 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -167,7 +167,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -184,7 +186,9 @@ end messages = Dict(edge => onet(tn, edge) for edge in all_edges(g)) - cache = ITensorNetworksNext.beliefpropagation(tn, messages; maxiter = 1) + cache = ITensorNetworksNext.beliefpropagation( + tn, messages; stopping_criterion = (; maxiter = 1) + ) z_bp = exp(bethe_free_energy(tn, cache)) z_exact = reduce(*, [tn[v] for v in vertices(g)])[] @test z_bp ≈ z_exact @@ -198,13 +202,11 @@ end messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) - stopping_criterion = StopWhenConverged(tol = 1.0e-10) + stopping_criterion = + AI.StopAfterIteration(10) | StopWhenConverged(tol = 1.0e-10) cache = ITensorNetworksNext.beliefpropagation( - tn, - messages; - maxiter = 10, - stopping_criterion + tn, messages; stopping_criterion ) z_bp = exp(bethe_free_energy(tn, cache)) From c8cafb9d78de3620a31876d609484a88534e9c31 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 19 May 2026 21:11:28 -0400 Subject: [PATCH 13/16] Minor reorg and line-collapse cleanup in `beliefpropagation` Reorder the body of `beliefpropagation` so the algorithm construction flows top-to-bottom and `cache` is constructed alongside the `AI.solve` call, and collapse a few short multi-line forms onto single lines where the formatter is happy with it. Co-Authored-By: Claude Opus 4.7 --- src/beliefpropagation/beliefpropagation.jl | 25 ++++++---------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 667debd..b9104ff 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -20,11 +20,7 @@ function select_message_update_algorithm(kwargs::NamedTuple) return default_message_update_algorithm(; kwargs...) end -function select_beliefpropagation_stopping_criterion( - stopping_criterion::AI.StoppingCriterion - ) - return stopping_criterion -end +select_beliefpropagation_stopping_criterion(c::AI.StoppingCriterion) = c function select_beliefpropagation_stopping_criterion(::Nothing) return throw( ArgumentError( @@ -59,21 +55,16 @@ function beliefpropagation( message_update_algorithm = default_message_update_algorithm() ) problem = BeliefPropagationProblem(factors) - cache = MessageCache(messages) - stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - sweep_algorithm = BeliefPropagationSweepAlgorithm(; message_update_algorithm, stopping_criterion = AI.StopAfterIteration(length(edges)) ) + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; edges, sweep_algorithm, stopping_criterion) - algorithm = BeliefPropagationAlgorithm(; - edges, - sweep_algorithm, - stopping_criterion - ) + cache = MessageCache(messages) return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end @@ -110,9 +101,7 @@ function AI.initialize_state( stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; - iterate, iteration, stopping_criterion_state - ) + return BeliefPropagationState(; iterate, iteration, stopping_criterion_state) end function AI.initialize_state!( @@ -170,9 +159,7 @@ function AI.initialize_state( stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationSweepState(; - iterate, iteration, stopping_criterion_state - ) + return BeliefPropagationSweepState(; iterate, iteration, stopping_criterion_state) end function AI.initialize_state!( From b0a1e90c3bb9119e180490575612ee6f036c5fe1 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 20 May 2026 13:04:54 -0400 Subject: [PATCH 14/16] Collapse BP message-update AI layer; introduce `MessageUpdateAlgorithm` strategy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The third `AlgorithmsInterface` layer (`MessageUpdateProblem` / `SimpleMessageUpdateAlgorithm` / `MessageUpdateState`) was a non-iterative function call dressed in problem/algorithm/state ceremony; this commit drops that framing. The per-edge update body lives in `AI.step!(::BeliefPropagationSweepProblem, ::BeliefPropagationSweepAlgorithm, ::BeliefPropagationSweepState)` and delegates to a new strategy interface: - `abstract type MessageUpdateAlgorithm end` - `message_update!(algorithm, cache, factors, edge)` - `SimpleMessageUpdate <: MessageUpdateAlgorithm` (the default; holds `normalize` + `contraction_alg`) `BeliefPropagationSweepAlgorithm` now holds a `message_update_algorithm` field (default `SimpleMessageUpdate()`) and is no longer a `NestedAlgorithm`. Algorithm selection follows `MatrixAlgebraKit.select_algorithm`. Generic `default_algorithm(f; kwargs...)` and `select_algorithm(f, alg; kwargs...)` helpers live in `AlgorithmsInterfaceExtensions`; operations register a default by overloading `default_algorithm(::typeof(f); kwargs...)`. BP overloads both for `message_update!`. Callers can pass either an explicit `MessageUpdateAlgorithm` instance, a `NamedTuple` of keyword arguments forwarded to the default algorithm, or `nothing` plus flat kwargs (also forwarded to the default). A convenience signature `message_update!(cache, factors, edge; alg = nothing, kwargs...)` routes through `AIE.select_algorithm` so the same dispatch pattern is reachable from a free-standing call. The `message_update_algorithm` kwarg on `beliefpropagation` uses the same selector. To keep the message-store iterate persistent across outer-loop iterations without duplicating it on the outer state, introduce an `AIE.NestedState` abstract type with generic `getproperty` / `setproperty!` / `propertynames` forwarders for `:iterate` through a `:substate` field. `BeliefPropagationState <: AIE.NestedState` now wraps a `BeliefPropagationSweepState` as `substate`; the outer state holds no `iterate` field of its own. The default `AIE.finalize_substate!` (which copies `substate.iterate` back to `state.iterate`) becomes a self-write through the forwarder — harmless and removes the need for a BP-specific override. Rename: `BeliefPropagationAlgorithm.sweep_algorithm` field → `subalgorithm` (matches the generic name already used in `AIE.initialize_subsolve`'s return tuple). Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 45 +++++++ src/beliefpropagation/beliefpropagation.jl | 112 +++++++++--------- 2 files changed, 99 insertions(+), 58 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index c7ac727..eac073f 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -31,6 +31,51 @@ function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.Sta return state end +# ============================ NestedState ================================================= + +# State that wraps an inner `substate` and forwards `:iterate` accesses to it, +# so the inner-loop iterate is shared without duplicating storage on the outer +# state. Subtypes must store the inner state as a field named `substate`. +abstract type NestedState <: AI.State end + +# Use `getfield` on the right-hand side so future edits to this forwarder +# can't accidentally recurse through the overload. +function Base.getproperty(state::NestedState, name::Symbol) + name === :iterate && return getfield(state, :substate).iterate + return getfield(state, name) +end +function Base.setproperty!(state::NestedState, name::Symbol, value) + name === :iterate && return (getfield(state, :substate).iterate = value) + return setfield!(state, name, value) +end +function Base.propertynames(state::NestedState) + return (fieldnames(typeof(state))..., :iterate) +end + +# ============================ select_algorithm / default_algorithm ======================== + +# Modeled on `MatrixAlgebraKit.select_algorithm` / `default_algorithm`. An +# operation `f` (a function) declares its default algorithm strategy by +# overloading `default_algorithm(::typeof(f); kwargs...)`. The strategy can +# then be selected at a call site via `select_algorithm(f, alg; kwargs...)`, +# accepting either `nothing` (use defaults), a `NamedTuple` of keyword +# arguments forwarded to the default constructor, or an explicit algorithm +# instance — the last dispatched per-operation (e.g. on a strategy +# supertype). +function default_algorithm(f; kwargs...) + return throw(MethodError(default_algorithm, (f,))) +end + +select_algorithm(f, ::Nothing; kwargs...) = default_algorithm(f; kwargs...) +function select_algorithm(f, alg::NamedTuple; kwargs...) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is a `NamedTuple`." + ) + ) + return default_algorithm(f; alg...) +end + # ============================ StopWhenConverged =========================================== # Stopping criterion that fires once `iterate_diff(iterate, previous_iterate) < tol`. diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index b9104ff..876e5cc 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -13,13 +13,6 @@ using NamedGraphs.PartitionedGraphs: quotientvertices default_beliefpropagation_edges(graph) = forest_cover_edge_sequence(graph) -default_message_update_algorithm(; kwargs...) = SimpleMessageUpdateAlgorithm(; kwargs...) - -select_message_update_algorithm(algorithm::AI.Algorithm) = algorithm -function select_message_update_algorithm(kwargs::NamedTuple) - return default_message_update_algorithm(; kwargs...) -end - select_beliefpropagation_stopping_criterion(c::AI.StoppingCriterion) = c function select_beliefpropagation_stopping_criterion(::Nothing) return throw( @@ -52,17 +45,19 @@ function beliefpropagation( factors, messages; edges = default_beliefpropagation_edges(factors), stopping_criterion = nothing, - message_update_algorithm = default_message_update_algorithm() + message_update_algorithm = nothing ) problem = BeliefPropagationProblem(factors) - message_update_algorithm = select_message_update_algorithm(message_update_algorithm) - sweep_algorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm = AIE.select_algorithm( + message_update!, message_update_algorithm + ) + subalgorithm = BeliefPropagationSweepAlgorithm(; message_update_algorithm, stopping_criterion = AI.StopAfterIteration(length(edges)) ) stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) - algorithm = BeliefPropagationAlgorithm(; edges, sweep_algorithm, stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; edges, subalgorithm, stopping_criterion) cache = MessageCache(messages) @@ -77,18 +72,18 @@ end @kwdef struct BeliefPropagationAlgorithm{ Edges, - SweepAlgorithm <: AI.Algorithm, + Subalgorithm <: AI.Algorithm, StoppingCriterion <: AI.StoppingCriterion, } <: AIE.NestedAlgorithm edges::Edges - sweep_algorithm::SweepAlgorithm + subalgorithm::Subalgorithm stopping_criterion::StoppingCriterion end @kwdef mutable struct BeliefPropagationState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State - iterate::Iterate + Substate <: AI.State, StoppingCriterionState <: AI.StoppingCriterionState, + } <: AIE.NestedState + substate::Substate iteration::Int = 0 stopping_criterion_state::StoppingCriterionState end @@ -98,10 +93,12 @@ function AI.initialize_state( algorithm::BeliefPropagationAlgorithm; iterate, iteration::Int = 0 ) + subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) + substate = AI.initialize_state(subproblem, algorithm.subalgorithm; iterate) stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iterate, iteration, stopping_criterion_state) + return BeliefPropagationState(; iteration, stopping_criterion_state, substate) end function AI.initialize_state!( @@ -122,10 +119,8 @@ function AIE.initialize_subsolve( algorithm::BeliefPropagationAlgorithm, state::BeliefPropagationState ) - subalgorithm = algorithm.sweep_algorithm subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + return subproblem, algorithm.subalgorithm, state.substate end # === Layer 2: one sweep over edges (iterative) === @@ -136,10 +131,10 @@ struct BeliefPropagationSweepProblem{Factors, Edges} <: AI.Problem end @kwdef struct BeliefPropagationSweepAlgorithm{ - MessageUpdateAlgorithm <: AI.Algorithm, + MessageUpdateAlgorithm, StoppingCriterion <: AI.StoppingCriterion, - } <: AIE.NestedAlgorithm - message_update_algorithm::MessageUpdateAlgorithm + } <: AI.Algorithm + message_update_algorithm::MessageUpdateAlgorithm = SimpleMessageUpdate() stopping_criterion::StoppingCriterion end @@ -175,58 +170,60 @@ function AI.initialize_state!( return state end -function AIE.initialize_subsolve( +function AI.step!( problem::BeliefPropagationSweepProblem, algorithm::BeliefPropagationSweepAlgorithm, state::BeliefPropagationSweepState ) edge = problem.edges[state.iteration] - subproblem = MessageUpdateProblem(problem.factors, edge) - subalgorithm = algorithm.message_update_algorithm - substate = AI.initialize_state(subproblem, subalgorithm; state.iterate) - return subproblem, subalgorithm, substate + message_update!( + algorithm.message_update_algorithm, state.iterate, problem.factors, edge + ) + return state end -# === Layer 3: single-edge message update (non-iterative) === +# === Layer 3: single-edge message update strategy === -struct MessageUpdateProblem{Factors, Edge <: AbstractEdge} <: AI.Problem - factors::Factors - edge::Edge -end +# Strategy interface: a `MessageUpdateAlgorithm` defines how a single +# message is computed and written back into the message store. Plug in a +# new strategy by subtyping `MessageUpdateAlgorithm` and overloading +# `message_update!(strategy, cache, factors, edge)`. +abstract type MessageUpdateAlgorithm end -@kwdef struct SimpleMessageUpdateAlgorithm{ContractionAlg} <: AI.Algorithm - normalize::Bool = true - contraction_alg::ContractionAlg = Algorithm"exact" -end +function message_update! end -@kwdef mutable struct MessageUpdateState{Iterate} <: AI.State - iterate::Iterate +# Algorithm selection (MAK-style; see `AIE.select_algorithm`). +function AIE.default_algorithm(::typeof(message_update!); kwargs...) + return SimpleMessageUpdate(; kwargs...) end - -function AI.initialize_state( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm; iterate +function AIE.select_algorithm( + ::typeof(message_update!), alg::MessageUpdateAlgorithm; kwargs... ) - return MessageUpdateState(; iterate) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is a `MessageUpdateAlgorithm` instance." + ) + ) + return alg end -# Non-iterative algorithm: no per-call state to reset. -function AI.initialize_state!( - ::MessageUpdateProblem, ::SimpleMessageUpdateAlgorithm, state::MessageUpdateState +# Convenience entry: pick the strategy via `AIE.select_algorithm` +# (accepts either `alg = ::MessageUpdateAlgorithm` / `::NamedTuple`, or flat +# kwargs forwarded to the default algorithm), then dispatch. +function message_update!(cache, factors, edge; alg = nothing, kwargs...) + return message_update!( + AIE.select_algorithm(message_update!, alg; kwargs...), cache, factors, edge ) - return state end -# Non-iterative algorithm: bypass the step!/stopping-criterion loop. -function AI.solve_loop!( - problem::MessageUpdateProblem, - algorithm::SimpleMessageUpdateAlgorithm, - state::MessageUpdateState - ) - cache = state.iterate - edge = problem.edge +@kwdef struct SimpleMessageUpdate{ContractionAlg} <: MessageUpdateAlgorithm + normalize::Bool = true + contraction_alg::ContractionAlg = Algorithm"exact" +end +function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) messages = collect(incoming_messages(cache, edge)) - factor = problem.factors[src(edge)] + factor = factors[src(edge)] new_message = contract_network([messages; [factor]]; algorithm.contraction_alg) @@ -238,8 +235,7 @@ function AI.solve_loop!( end cache[edge] = new_message - - return state + return nothing end # === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === From 3c9c2d8cd9f98f6646a7554d2d77175cfcd4d759 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 20 May 2026 14:00:06 -0400 Subject: [PATCH 15/16] Add args tuple + AbstractAlgorithm supertype to `select_algorithm` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refine the MAK-style algorithm selector introduced in the previous commit: * `select_algorithm` / `default_algorithm` now take an `args` tuple as their final positional argument (after `alg`). A generic value→type wrapper forwards `(::Tuple)` to `(::Type{<:Tuple})`, so operations dispatch on the input-tuple type. Wrapping in a tuple keeps the value and type domains disjoint — `(1.2,)` is unambiguously a value tuple, `Tuple{Float64}` is unambiguously the type form. This matters at sites like `beliefpropagation` where not every input has a concrete value yet (no specific `edge`), and the type-form call simply passes `edgetype(factors)` in its slot. * New `AIE.AbstractAlgorithm` supertype carries the generic "passthrough an explicit algorithm instance" overload of `select_algorithm`, so each operation doesn't have to repeat that method. `MessageUpdateAlgorithm <: AIE.AbstractAlgorithm`. * BP call sites updated: `beliefpropagation` constructs the cache earlier and uses `Tuple{typeof(cache), typeof(factors), edgetype(factors)}` as the args type. The per-call `message_update!(cache, factors, edge; ...)` uses the value tuple `(cache, factors, edge)`. Both describe the same call shape. Co-Authored-By: Claude Opus 4.7 --- .../AlgorithmsInterfaceExtensions.jl | 44 ++++++++++++------- src/beliefpropagation/beliefpropagation.jl | 29 +++++------- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index eac073f..a95e0e0 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -54,26 +54,40 @@ end # ============================ select_algorithm / default_algorithm ======================== -# Modeled on `MatrixAlgebraKit.select_algorithm` / `default_algorithm`. An -# operation `f` (a function) declares its default algorithm strategy by -# overloading `default_algorithm(::typeof(f); kwargs...)`. The strategy can -# then be selected at a call site via `select_algorithm(f, alg; kwargs...)`, -# accepting either `nothing` (use defaults), a `NamedTuple` of keyword -# arguments forwarded to the default constructor, or an explicit algorithm -# instance — the last dispatched per-operation (e.g. on a strategy -# supertype). -function default_algorithm(f; kwargs...) - return throw(MethodError(default_algorithm, (f,))) -end - -select_algorithm(f, ::Nothing; kwargs...) = default_algorithm(f; kwargs...) -function select_algorithm(f, alg::NamedTuple; kwargs...) +# Like `MatrixAlgebraKit.select_algorithm` / `default_algorithm`, but +# selection-relevant inputs are packed into an `args` tuple so the value +# and type domains stay disjoint: `(1.2,)` vs `Tuple{Float64}`. Strategy +# types subtype `AbstractAlgorithm` so the passthrough overload is generic. +abstract type AbstractAlgorithm end + +function default_algorithm(f, ::Type{Args}; kwargs...) where {Args <: Tuple} + return throw(MethodError(default_algorithm, (f, Args))) +end +function default_algorithm(f, args::Tuple; kwargs...) + return default_algorithm(f, typeof(args); kwargs...) +end + +function select_algorithm(f, alg, args::Tuple; kwargs...) + return select_algorithm(f, alg, typeof(args); kwargs...) +end +function select_algorithm(f, ::Nothing, ::Type{Args}; kwargs...) where {Args <: Tuple} + return default_algorithm(f, Args; kwargs...) +end +function select_algorithm(f, alg::NamedTuple, ::Type{Args}; kwargs...) where {Args <: Tuple} isempty(kwargs) || throw( ArgumentError( "Additional keyword arguments are not allowed when `alg` is a `NamedTuple`." ) ) - return default_algorithm(f; alg...) + return default_algorithm(f, Args; alg...) +end +function select_algorithm(f, alg::AbstractAlgorithm, ::Type{<:Tuple}; kwargs...) + isempty(kwargs) || throw( + ArgumentError( + "Additional keyword arguments are not allowed when `alg` is an `AbstractAlgorithm` instance." + ) + ) + return alg end # ============================ StopWhenConverged =========================================== diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index 876e5cc..fbb80fc 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -3,7 +3,7 @@ import AlgorithmsInterface as AI using .AlgorithmsInterfaceExtensions: StopWhenConverged, iterate_diff using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: edge_data -using Graphs: AbstractEdge, edges, has_edge, vertices +using Graphs: AbstractEdge, edges, edgetype, has_edge, vertices using LinearAlgebra: norm, normalize using NamedDimsArrays: AbstractNamedDimsArray using NamedGraphs.GraphsExtensions: add_edges!, boundary_edges, subgraph @@ -48,9 +48,13 @@ function beliefpropagation( message_update_algorithm = nothing ) problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) + # No concrete `edge` value here, so the args tuple uses `edgetype(factors)`. message_update_algorithm = AIE.select_algorithm( - message_update!, message_update_algorithm + message_update!, + message_update_algorithm, + Tuple{typeof(cache), typeof(factors), edgetype(factors)} ) subalgorithm = BeliefPropagationSweepAlgorithm(; message_update_algorithm, @@ -59,8 +63,6 @@ function beliefpropagation( stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) algorithm = BeliefPropagationAlgorithm(; edges, subalgorithm, stopping_criterion) - cache = MessageCache(messages) - return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) end @@ -188,31 +190,22 @@ end # message is computed and written back into the message store. Plug in a # new strategy by subtyping `MessageUpdateAlgorithm` and overloading # `message_update!(strategy, cache, factors, edge)`. -abstract type MessageUpdateAlgorithm end +abstract type MessageUpdateAlgorithm <: AIE.AbstractAlgorithm end function message_update! end -# Algorithm selection (MAK-style; see `AIE.select_algorithm`). -function AIE.default_algorithm(::typeof(message_update!); kwargs...) +# `args` tuple mirrors the `message_update!(cache, factors, edge)` call shape. +function AIE.default_algorithm(::typeof(message_update!), ::Type{<:Tuple}; kwargs...) return SimpleMessageUpdate(; kwargs...) end -function AIE.select_algorithm( - ::typeof(message_update!), alg::MessageUpdateAlgorithm; kwargs... - ) - isempty(kwargs) || throw( - ArgumentError( - "Additional keyword arguments are not allowed when `alg` is a `MessageUpdateAlgorithm` instance." - ) - ) - return alg -end # Convenience entry: pick the strategy via `AIE.select_algorithm` # (accepts either `alg = ::MessageUpdateAlgorithm` / `::NamedTuple`, or flat # kwargs forwarded to the default algorithm), then dispatch. function message_update!(cache, factors, edge; alg = nothing, kwargs...) return message_update!( - AIE.select_algorithm(message_update!, alg; kwargs...), cache, factors, edge + AIE.select_algorithm(message_update!, alg, (cache, factors, edge); kwargs...), + cache, factors, edge ) end From 5d56f9f32611012966459515d8b8f8f5c7564bd7 Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Wed, 20 May 2026 18:01:52 -0400 Subject: [PATCH 16/16] Return mutated cache from `message_update!`; accept `tol` in stopping_criterion `message_update!` now follows the convention of mutating-functions returning the object they modify (the message cache), so callers can chain it. `select_beliefpropagation_stopping_criterion` accepts a `tol` keyword in the NamedTuple form alongside the existing `maxiter`. Either or both may be specified; when both are given they are combined with `|` (stop on whichever fires first). The spin-ice test now uses the NamedTuple form `stopping_criterion = (; maxiter = 10, tol = 1.0e-10)` to exercise this path, in place of the explicit `StopAfterIteration | StopWhenConverged` construction. Co-Authored-By: Claude Opus 4.7 --- src/beliefpropagation/beliefpropagation.jl | 29 ++++++++++++++++------ test/test_beliefpropagation.jl | 6 ++--- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index fbb80fc..d6dfabc 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -18,7 +18,8 @@ function select_beliefpropagation_stopping_criterion(::Nothing) return throw( ArgumentError( "`stopping_criterion` must be specified, e.g.\n" * - " `stopping_criterion = (; maxiter = 10)` or\n" * + " `stopping_criterion = (; maxiter = 10)`,\n" * + " `stopping_criterion = (; maxiter = 10, tol = 1.0e-10)`, or\n" * " `stopping_criterion = AI.StopAfterIteration(10) | StopWhenConverged(1.0e-10)`." ) ) @@ -26,19 +27,31 @@ end function select_beliefpropagation_stopping_criterion(kwargs::NamedTuple) return select_beliefpropagation_stopping_criterion(; kwargs...) end -function select_beliefpropagation_stopping_criterion(; maxiter = nothing, kwargs...) - if isnothing(maxiter) - throw(ArgumentError("`maxiter` must be specified in `stopping_criterion`.")) - end +function select_beliefpropagation_stopping_criterion(; + maxiter = nothing, tol = nothing, kwargs... + ) if !isempty(kwargs) throw( ArgumentError( "Unrecognized `stopping_criterion` kwargs: $(keys(kwargs)). " * - "Only `maxiter` is currently supported." + "Supported: `maxiter`, `tol`." ) ) end - return AI.StopAfterIteration(maxiter) + if isnothing(maxiter) && isnothing(tol) + throw( + ArgumentError("At least one of `maxiter` or `tol` must be specified.") + ) + end + criterion = nothing + if !isnothing(maxiter) + criterion = AI.StopAfterIteration(maxiter) + end + if !isnothing(tol) + converged = StopWhenConverged(; tol) + criterion = isnothing(criterion) ? converged : criterion | converged + end + return criterion end function beliefpropagation( @@ -228,7 +241,7 @@ function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) end cache[edge] = new_message - return nothing + return cache end # === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4362f7e..34cb305 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -202,11 +202,9 @@ end messages = Dict(edge => randt(tn, edge) for edge in all_edges(g)) - stopping_criterion = - AI.StopAfterIteration(10) | StopWhenConverged(tol = 1.0e-10) - cache = ITensorNetworksNext.beliefpropagation( - tn, messages; stopping_criterion + tn, messages; + stopping_criterion = (; maxiter = 10, tol = 1.0e-10) ) z_bp = exp(bethe_free_energy(tn, cache))