diff --git a/Project.toml b/Project.toml index c8358d0..5c8f666 100644 --- a/Project.toml +++ b/Project.toml @@ -30,9 +30,11 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [weakdeps] TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" [extensions] ITensorNetworksNextTensorOperationsExt = "TensorOperations" +ITensorNetworksNextDaggerExt = "Dagger" [compat] AbstractTrees = "0.4.5" @@ -40,6 +42,7 @@ Adapt = "4.3" AlgorithmsInterface = "0.1" BackendSelection = "0.1.6" Combinatorics = "1" +Dagger = "0.19.0" DataGraphs = "0.4" DiagonalArrays = "0.3.31" Dictionaries = "0.4.5" diff --git a/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl new file mode 100644 index 0000000..837a3d3 --- /dev/null +++ b/ext/ITensorNetworksNextDaggerExt/ITensorNetworksNextDaggerExt.jl @@ -0,0 +1,84 @@ +module ITensorNetworksNextDaggerExt + +import AlgorithmsInterface as AI +import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import ITensorNetworksNext.ITensorNetworksNextParallel as ITNNP +using Dagger +using Dictionaries: Dictionary, set! + +@kwdef mutable struct DaggerState{Chunk <: Dagger.Chunk} <: AI.State + chunk::Chunk # A chunk living on the host worker. + futures = Dictionary{Int, Dagger.DTask}() # the futures from the remote steps. +end + +function Base.getproperty(state::DaggerState, name::Symbol) + if name in (:chunk, :futures) + return getfield(state, name) + end + return getproperty(fetch(state.chunk), name) +end +function Base.setproperty!(state::DaggerState, name::Symbol, val) + if name === (:chunk, :futures) + return setfield!(state, name, val) + end + fetch(Dagger.@spawn setproperty!(state.chunk, name, val)) + return state +end + +# ====================================== overloads ======================================= # + +function ITNNP.default_workers(::AI.Algorithm, ::ITNNP.AbstractDaggerStrategy) + return Dagger.Distributed.workers() +end + +function ITNNP.initialize_parallel_state( + problem::AI.Problem, algorithm::AI.Algorithm, + _strategy::ITNNP.GenericDaggerStrategy; kwargs... + ) + chunk = Dagger.@mutable AI.initialize_state(problem, algorithm; iterate, kwargs...) + + return DaggerState(; chunk) +end + +function AI.step!(problem::AI.Problem, algorithm::ITNNP.Parallelized, state::DaggerState) + worker_list = algorithm.workers + algorithm = algorithm.parent + + function remote_solve(state) + subsolve = AIE.initialize_subsolve(problem, algorithm, state) + + subproblem, subalgorithm, substate = subsolve + AI.solve!(subproblem, subalgorithm, substate) + + return subsolve + end + + worker = worker_list[mod(state.iteration, 1:length(worker_list))] + + dtask = Dagger.@spawn scope = Dagger.scope(; worker) remote_solve(fetch(state.chunk)) + + # Spawns on the host only (as we do not fetch the chunk before hand.) + dtask = Dagger.spawn(dtask, state.chunk) do subsolve, state + subproblem, subalgorithm, substate = subsolve + + AIE.finalize_substate!(subproblem, subalgorithm, substate, state) + return state + end + + set!(state.futures, state.iteration, dtask) + + return state +end +function AI.finalize_state!(::AI.Problem, ::AI.Algorithm, state::DaggerState) + foreach(fetch, state.futures) + return state.iterate +end + +function AIE.finalize_substate!( + problem::AI.Problem, algorithm::AI.Algorithm, substate::DaggerState, state::AI.State + ) + AIE.finalize_substate!(problem, algorithm, fetch(substate.chunk), state) + return state +end + +end # ITensorNetworksNextDaggerExt diff --git a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl index a95e0e0..82b3536 100644 --- a/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl +++ b/src/AlgorithmsInterfaceExtensions/AlgorithmsInterfaceExtensions.jl @@ -2,6 +2,37 @@ module AlgorithmsInterfaceExtensions import AlgorithmsInterface as AI +@kwdef mutable struct GenericAlgorithmState{Iterate, StoppingCriterionState} <: AI.State + iterate::Iterate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + +function AI.initialize_state( + problem::AI.Problem, + algorithm::AI.Algorithm; + iterate, + iteration = 0 + ) + stopping_criterion_state = AI.initialize_state( + problem, algorithm, algorithm.stopping_criterion; iterate + ) + return GenericAlgorithmState(; iterate, iteration, stopping_criterion_state) +end + +function AI.initialize_state!( + problem::AI.Problem, + algorithm::AI.Algorithm, + state::AI.State; + iteration = 0 + ) + state.iteration = iteration + AI.initialize_state!( + problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state + ) + return state +end + # ============================ NestedAlgorithm ============================================= abstract type NestedAlgorithm <: AI.Algorithm end @@ -18,7 +49,7 @@ function initialize_subsolve( end function finalize_substate!( - problem::AI.Problem, algorithm::AI.Algorithm, state::AI.State, substate::AI.State + ::AI.Problem, ::AI.Algorithm, substate::AI.State, state::AI.State ) state.iterate = substate.iterate return state @@ -27,7 +58,7 @@ end function AI.step!(problem::AI.Problem, algorithm::NestedAlgorithm, state::AI.State) subproblem, subalgorithm, substate = initialize_subsolve(problem, algorithm, state) AI.solve!(subproblem, subalgorithm, substate) - finalize_substate!(problem, algorithm, state, substate) + finalize_substate!(subproblem, subalgorithm, substate, state) return state end @@ -38,6 +69,12 @@ end # state. Subtypes must store the inner state as a field named `substate`. abstract type NestedState <: AI.State end +@kwdef mutable struct GenericNestedState{Substate, StoppingCriterionState} <: NestedState + substate::Substate + iteration::Int = 0 + stopping_criterion_state::StoppingCriterionState +end + # Use `getfield` on the right-hand side so future edits to this forwarder # can't accidentally recurse through the overload. function Base.getproperty(state::NestedState, name::Symbol) diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 41ce78e..4d42c44 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -15,5 +15,9 @@ include("contract_network.jl") include("beliefpropagation/messagecache.jl") include("beliefpropagation/beliefpropagation.jl") +include("beliefpropagation/partitioned.jl") + +# lib +include("ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl") end diff --git a/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl new file mode 100644 index 0000000..333edea --- /dev/null +++ b/src/ITensorNetworksNextParallel/ITensorNetworksNextParallel.jl @@ -0,0 +1,60 @@ +module ITensorNetworksNextParallel + +import ..ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE +import AlgorithmsInterface as AI + +abstract type AbstractParallelizationStrategy end + +function default_workers end +function initialize_parallel_state end + +@kwdef struct Parallelized{Strategy, Workers, Algorithm <: AI.Algorithm} <: AI.Algorithm + parent::Algorithm + strategy::Strategy + workers::Workers = default_workers(parent, strategy) +end + +function Base.getproperty(algorithm::Parallelized, name::Symbol) + if name in (:parent, :strategy, :workers) + return getfield(algorithm, name) + end + return getproperty(getfield(algorithm, :parent), name) +end + +function AI.initialize_state(problem::AI.Problem, algorithm::Parallelized; kwargs...) + return initialize_parallel_state( + problem, + algorithm.parent, + algorithm.strategy; + kwargs... + ) +end + +# ====================================== Dagger.jl ======================================= # + +abstract type AbstractDaggerStrategy <: AbstractParallelizationStrategy end +struct GenericDaggerStrategy <: AbstractDaggerStrategy end + +function initialize_parallel_state( + _problem, + _algorithm, + strategy::AbstractDaggerStrategy; + _kwargs... + ) + throw( + ArgumentError( + "package Dagger.jl not loaded; please install and load Dagger.jl to use \ + strategy of type $(typeof(strategy))." + ) + ) +end + +function default_workers(algorithm, strategy::AbstractDaggerStrategy) + @warn( + "package Dagger.jl may not be loaded; please install and load Dagger.jl to use \ + strategy of type `$(typeof(strategy))`" + ) + throw(MethodError(default_workers, (algorithm, strategy))) +end + +end # ITensorNetworksNextParallel diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index 172ec08..44bae0a 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(symname, dims) return lazy(nameddims(SymbolicArray(symname, denamed.(dims)), name.(dims))) end +function symnameddims(name, ndarray::AbstractNamedDimsArray) + return symnameddims(name, Tuple(inds(ndarray))) +end symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 121073d..802e8b8 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -34,6 +34,7 @@ Base.copy(::AbstractTensorNetwork) = not_implemented() # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) + Base.keys(tn::AbstractTensorNetwork) = vertices(tn) # TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..a810c9a --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,161 @@ +using DataGraphs: AbstractDataGraph, edge_data, edge_data_type, vertex_data +using Graphs: AbstractEdge, AbstractGraph +using NamedGraphs.GraphsExtensions: boundary_edges +using NamedGraphs.PartitionedGraphs: QuotientEdge, QuotientView, parent, QuotientVertex + +messages(bp_cache::AbstractGraph) = edge_data(bp_cache) +messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges] + +message(bp_cache::AbstractGraph, edge::AbstractEdge) = messages(bp_cache)[edge] + +deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented() +function deletemessage!(bp_cache::AbstractDataGraph, edge) + ms = messages(bp_cache) + delete!(ms, edge) + return bp_cache +end + +function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache)) + for e in edges + deletemessage!(bp_cache, e) + end + return bp_cache +end + +setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented() +function setmessage!(bp_cache::AbstractDataGraph, edge, message) + setindex!(bp_cache, message, edge) + return bp_cache +end +function setmessage!(bp_cache::QuotientView, edge, message) + setmessages!(parent(bp_cache), QuotientEdge(edge), message) + return bp_cache +end + +function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message) + for e in edges(bp_cache, edge) + setmessage!(parent(bp_cache), e, message[e]) + end + return bp_cache +end +function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges) + for e in edges + setmessage!(bpc_dst, e, message(bpc_src, e)) + end + return bpc_dst +end + +factors(bpc::AbstractGraph) = vertex_data(bpc) +factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices] +factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex]) + +factor(bpc::AbstractGraph, vertex) = bpc[vertex] + +setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented() +function setfactor!(bpc::AbstractDataGraph, vertex, factor) + fs = factors(bpc) + setindex!(fs, vertex, factor) + return bpc +end + +function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::AbstractGraph, vertex) + messages = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + + return (reduce(*, messages) * reduce(*, state))[] +end + +message_type(bpc::AbstractGraph) = message_type(typeof(bpc)) +message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G)) +message_type(type::Type{<:AbstractDataGraph}) = edge_data_type(type) + +function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache)) + return map(v -> region_scalar(bp_cache, v), vertices) +end + +function edge_scalars( + bp_cache::AbstractGraph, + edges = edges(undirected_graph(underlying_graph(bp_cache))) + ) + return map(e -> region_scalar(bp_cache, e), edges) +end + +function scalar_factors_quotient(bp_cache::AbstractGraph) + return vertex_scalars(bp_cache), edge_scalars(bp_cache) +end + +function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = []) + b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in) + b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges + return messages(bp_cache, b_edges) +end + +default_messages(::AbstractGraph) = not_implemented() + +#Adapt interface for changing device +map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es) +function map_messages!(f, bp_cache, es = edges(bp_cache)) + for e in es + setmessage!(bp_cache, e, f(message(bp_cache, e))) + end + return bp_cache +end + +map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs) +function map_factors!(f, bp_cache, vs = vertices(bp_cache)) + for v in vs + setfactor!(bp_cache, v, f(factor(bp_cache, v))) + end + return bp_cache +end + +adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es) +adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs) + +abstract type AbstractBeliefPropagationCache{V, VD, ED} <: AbstractDataGraph{V, VD, ED} end + +factor_type(bpc::AbstractBeliefPropagationCache) = typeof(bpc) +factor_type(::Type{<:AbstractBeliefPropagationCache{<:Any, VD}}) where {VD} = VD + +message_type(bpc::AbstractBeliefPropagationCache) = message_type(typeof(bpc)) +message_type(::Type{<:AbstractBeliefPropagationCache{<:Any, <:Any, ED}}) where {ED} = ED + +function free_energy(bp_cache::AbstractBeliefPropagationCache) + numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache) + + if any(t -> real(t) < 0, numerator_terms) + numerator_terms = complex.(numerator_terms) + end + if any(t -> real(t) < 0, denominator_terms) + denominator_terms = complex.(denominator_terms) + end + + if any(iszero, denominator_terms) + return -Inf + end + + return sum(log.(numerator_terms)) - sum(log.((denominator_terms))) +end +partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache)) + +function subcache(cache::AbstractBeliefPropagationCache, vertex::QuotientVertex) + return subcache(cache, vertices(cache, vertex)) +end +function subcache(cache::AbstractBeliefPropagationCache, vertices) + subcache = subgraph(cache, vertices) + + for vertex in vertices + for neighbor_vertex in neighbors(cache, vertex) + add_vertex!(subcache, neighbor_vertex) + # Add in necessary messages. + subcache[vertex => neighbor_vertex] = cache[vertex => neighbor_vertex] + subcache[neighbor_vertex => vertex] = cache[neighbor_vertex => vertex] + end + end + + return subcache +end diff --git a/src/beliefpropagation/beliefpropagation.jl b/src/beliefpropagation/beliefpropagation.jl index d6dfabc..db134b1 100644 --- a/src/beliefpropagation/beliefpropagation.jl +++ b/src/beliefpropagation/beliefpropagation.jl @@ -95,14 +95,6 @@ end stopping_criterion::StoppingCriterion end -@kwdef mutable struct BeliefPropagationState{ - Substate <: AI.State, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AIE.NestedState - substate::Substate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - function AI.initialize_state( problem::BeliefPropagationProblem, algorithm::BeliefPropagationAlgorithm; @@ -110,29 +102,18 @@ function AI.initialize_state( ) subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) substate = AI.initialize_state(subproblem, algorithm.subalgorithm; iterate) + stopping_criterion_state = AI.initialize_state( problem, algorithm, algorithm.stopping_criterion; iterate ) - return BeliefPropagationState(; iteration, stopping_criterion_state, substate) -end -function AI.initialize_state!( - problem::BeliefPropagationProblem, - algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state + return AIE.GenericNestedState(; iteration, stopping_criterion_state, substate) end function AIE.initialize_subsolve( problem::BeliefPropagationProblem, algorithm::BeliefPropagationAlgorithm, - state::BeliefPropagationState + state::AIE.NestedState ) subproblem = BeliefPropagationSweepProblem(problem.factors, algorithm.edges) return subproblem, algorithm.subalgorithm, state.substate @@ -153,42 +134,10 @@ end stopping_criterion::StoppingCriterion end -@kwdef mutable struct BeliefPropagationSweepState{ - Iterate, StoppingCriterionState <: AI.StoppingCriterionState, - } <: AI.State - iterate::Iterate - iteration::Int = 0 - stopping_criterion_state::StoppingCriterionState -end - -function AI.initialize_state( - problem::BeliefPropagationSweepProblem, - algorithm::BeliefPropagationSweepAlgorithm; - iterate, iteration::Int = 0 - ) - stopping_criterion_state = AI.initialize_state( - problem, algorithm, algorithm.stopping_criterion; iterate - ) - return BeliefPropagationSweepState(; iterate, iteration, stopping_criterion_state) -end - -function AI.initialize_state!( - problem::BeliefPropagationSweepProblem, - algorithm::BeliefPropagationSweepAlgorithm, - state::BeliefPropagationSweepState; - iteration::Int = 0 - ) - state.iteration = iteration - AI.initialize_state!( - problem, algorithm, algorithm.stopping_criterion, state.stopping_criterion_state - ) - return state -end - function AI.step!( problem::BeliefPropagationSweepProblem, algorithm::BeliefPropagationSweepAlgorithm, - state::BeliefPropagationSweepState + state::AI.State ) edge = problem.edges[state.iteration] message_update!( @@ -243,7 +192,6 @@ function message_update!(algorithm::SimpleMessageUpdate, cache, factors, edge) cache[edge] = new_message return cache end - # === `iterate_diff` for `MessageCache` (used by `AIE.StopWhenConverged`) === function AIE.iterate_diff(cache1::MessageCache, cache2::MessageCache) diff --git a/src/beliefpropagation/partitioned.jl b/src/beliefpropagation/partitioned.jl new file mode 100644 index 0000000..243f896 --- /dev/null +++ b/src/beliefpropagation/partitioned.jl @@ -0,0 +1,126 @@ +@kwdef struct PartitionedBeliefPropagationSweep{ + Subfactors, + MessageUpdateAlgorithm, + } <: AIE.NestedAlgorithm + partitioned_factors::Vector{Subfactors} + message_update_algorithm::MessageUpdateAlgorithm +end + +function Base.getproperty(alg::PartitionedBeliefPropagationSweep, name::Symbol) + if name === :stopping_criterion + return AI.StopAfterIteration(length(alg.partitioned_factors)) + end + return getfield(alg, name) +end + +function AIE.initialize_subsolve( + problem::BeliefPropagationSweepProblem, + algorithm::PartitionedBeliefPropagationSweep, + state::AI.State + ) + subvertices = algorithm.partitioned_factors[state.iteration] + cache = state.iterate + + incoming_edges = boundary_edges(problem.factors, subvertices; dir = :in) + + factors = subgraph(problem.factors, subvertices) + iterate = subgraph(cache, subvertices) + + + for edge in incoming_edges + add_vertex!(iterate, src(edge)) + add_vertex!(iterate, dst(edge)) + + iterate[edge] = cache[edge] + iterate[reverse(edge)] = cache[reverse(edge)] + end + + # Don't want to update the incoming messages. + subedges = setdiff(edges(iterate), incoming_edges) + + subproblem = BeliefPropagationSweepProblem(factors, subedges) + + subalgorithm = BeliefPropagationSweepAlgorithm(; + message_update_algorithm = algorithm.message_update_algorithm, + stopping_criterion = AI.StopAfterIteration(length(subedges)) + ) + + substate = AI.initialize_state(subproblem, subalgorithm; iterate) + + return subproblem, subalgorithm, substate +end + +function AIE.finalize_substate!( + subproblem::BeliefPropagationSweepProblem, + _subalgorithm::BeliefPropagationSweepAlgorithm, + substate::AI.State, + state::AI.State + ) + subcache = substate.iterate + subedges = subproblem.edges + + for edge in subedges + state.iterate[edge] = subcache[edge] + end + + return state +end + +function beliefpropagation( + factors, messages, partitions; + edges = default_beliefpropagation_edges(factors), + stopping_criterion = nothing, + message_update_algorithm = nothing, + sweep_algorithm = nothing + ) + problem = BeliefPropagationProblem(factors) + cache = MessageCache(messages) + + # No concrete `edge` value here, so the args tuple uses `edgetype(factors)`. + message_update_algorithm = AIE.select_algorithm( + message_update!, + message_update_algorithm, + Tuple{typeof(cache), typeof(factors), edgetype(factors)} + ) + + if isnothing(sweep_algorithm) + sweep_algorithm = PartitionedBeliefPropagationSweep(; + partitioned_factors = partitions, + message_update_algorithm + ) + end + stopping_criterion = select_beliefpropagation_stopping_criterion(stopping_criterion) + algorithm = BeliefPropagationAlgorithm(; + edges, + subalgorithm = sweep_algorithm, + stopping_criterion + ) + + return AI.solve(problem, algorithm; iterate = cache) # -> typeof(cache) +end + +# function message_update!(algorithm::KrylovMessageUpdate, cache, factors, path) +# function f(messages) +# temp_cache = copy(cache) # shallow copy. +# +# for (edge, message) in zip(path, messages) +# temp_cache[edge] = message +# end +# +# for (i, edge) in enumerate(path) +# message_update!( +# algorithm.message_update_algorithm, +# temp_cache, +# factors, +# edge +# ) +# messages[i] = temp_cache[edge] +# end +# +# return messages +# end +# +# # do eigsolve step. +# +# return cache +# end diff --git a/test/test_algorithmsinterfaceextensions.jl b/test/test_algorithmsinterfaceextensions.jl index f580826..6ef4ce8 100644 --- a/test/test_algorithmsinterfaceextensions.jl +++ b/test/test_algorithmsinterfaceextensions.jl @@ -115,7 +115,7 @@ end # `finalize_substate!` copies the substate's iterate back into the # parent state. substate = AI.initialize_state(problem, algorithm; iterate = [42.0]) - AIE.finalize_substate!(problem, algorithm, state, substate) + AIE.finalize_substate!(problem, algorithm, substate, state) @test state.iterate == [42.0] end diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 34cb305..1610c72 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -213,5 +213,28 @@ end end end end + @testset "PartitionedBeliefPropagationSweep" begin + n = 4 + dims = (n, n) + g = named_grid(dims; periodic = true) + tn = spin_ice_tensornetwork(g) + + messages = + Dict(edge => rand(Tuple(linkinds(tn, edge))) for edge in all_edges(g)) + + partitions = [ + [ij, ij .+ (0, 1), ij .+ (1, 0), ij .+ (1, 1)] for + ij in [(1, 1), (1, 3), (3, 1), (3, 3)] + ] + + cache = ITensorNetworksNext.beliefpropagation( + tn, messages, partitions; + stopping_criterion = (; maxiter = 30, tol = 1.0e-10) + ) + + z_bp = exp(bethe_free_energy(tn, cache)) + + @test z_bp ≈ 1.5^(n^2) + end end end