From a1e9facc2a35065a8ef215d9960323a719410498 Mon Sep 17 00:00:00 2001 From: Johannes Terblanche Date: Fri, 3 Nov 2023 08:20:22 +0200 Subject: [PATCH 1/3] wip parametric tree solve --- .../services/CliqueStateMachine.jl | 2 +- .../src/Factors/GenericFunctions.jl | 76 +++++++ .../services/ParametricCSMFunctions.jl | 194 +++++++++++++++++- .../src/services/TreeMessageUtils.jl | 22 ++ 4 files changed, 282 insertions(+), 12 deletions(-) diff --git a/IncrementalInference/src/CliqueStateMachine/services/CliqueStateMachine.jl b/IncrementalInference/src/CliqueStateMachine/services/CliqueStateMachine.jl index d15f4d01..1cab9498 100644 --- a/IncrementalInference/src/CliqueStateMachine/services/CliqueStateMachine.jl +++ b/IncrementalInference/src/CliqueStateMachine/services/CliqueStateMachine.jl @@ -956,7 +956,7 @@ function updateFromSubgraph_StateMachine(csmc::CliqStateMachineContainer) logCSM( csmc, "CSM-5 Clique $(csmc.cliq.id) finished, solveKey=$(csmc.solveKey)"; - loglevel = Logging.Info, + loglevel = Logging.Debug, ) return IncrementalInference.exitStateMachine end diff --git a/IncrementalInference/src/Factors/GenericFunctions.jl b/IncrementalInference/src/Factors/GenericFunctions.jl index b8a2e69d..7cd59527 100644 --- a/IncrementalInference/src/Factors/GenericFunctions.jl +++ b/IncrementalInference/src/Factors/GenericFunctions.jl @@ -94,6 +94,82 @@ function (cf::CalcFactor{<:ManifoldFactor})(X, p, q) return measurement_residual(cf.factor.M, X, p, q) end +## ====================================================================================== +## adjoint factor - adjoint action applied to the measurement +## ====================================================================================== +function Ad(::Union{typeof(SpecialEuclidean(2)), typeof(SpecialEuclidean(3))}, p, X) + t = p.x[1] + R = p.x[2] + v = X.x[1] + Ω = X.x[2] + ArrayPartition(-R*Ω*R'*t + R*v, R*Ω*R') +end + +function Ad(::typeof(SpecialEuclidean(3)), p) + t = p.x[1] + R = p.x[2] + vcat( + hcat(R, skew(t)*R), + hcat(zero(SMatrix{3,3,Float64}), R) + ) +end + +function Ad(::typeof(SpecialEuclidean(2)), p) + t = p.x[1] + R = p.x[2] + vcat( + hcat(R, -SA[0 -1; 1 0]*t), + SA[0 0 1] + ) +end + +struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize + factor::F +end + +function (cf::CalcFactor{<:AdFactor})(Xϵ, p, q) + # M = getManifold(cf.factor) + # p,q ∈ M + # Xϵ ∈ TϵM + # ϵ = identity_element(M) + # transform measurement from TϵM to TpM (global to local coordinates) + # Adₚ⁻¹ = AdjointMatrix(M, p)⁻¹ = AdjointMatrix(M, p⁻¹) + # Xp = Adₚ⁻¹ * Xϵᵛ + # ad = Ad(M, inv(M, p)) + # Xp = Ad(M, inv(M, p), Xϵ) + # Xp = adjoint_action(M, inv(M, p), Xϵ) + #TODO is vector transport supposed to be the same? + # Xp = vector_transport_to(M, ϵ, Xϵ, p) + + # Transform measurement covariance + # ᵉΣₚ = Adₚ ᵖΣₚ Adₚᵀ + #TODO test if transforming sqrt_iΣ is the same as Σ + # Σ = ad * inv(cf.sqrt_iΣ^2) * ad' + # sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), sqrt(inv(Σ))) + # sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), ad * cf.sqrt_iΣ * ad') + Xp = Xϵ + + child_cf = CalcFactorResidual( + cf.faclbl, + cf.factor.factor, + cf.varOrder, + cf.varOrderIdxs, + cf.meas, + cf.sqrt_iΣ, + cf.cache, + ) + return child_cf(Xp, p, q) +end + +getMeasurementParametric(f::AdFactor) = getMeasurementParametric(f.factor) + +getManifold(f::AdFactor) = getManifold(f.factor) +function getSample(cf::CalcFactor{<:AdFactor}) + M = getManifold(cf) + return sampleTangent(M, cf.factor.factor.Z) +end + + ## ====================================================================================== ## adjoint factor - adjoint action applied to the measurement ## ====================================================================================== diff --git a/IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl b/IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl index 16c692eb..73a24135 100644 --- a/IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl +++ b/IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl @@ -5,7 +5,7 @@ Notes - Parametric state machine function nr. 3 """ -function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer) +function solveUp_ParametricStateMachine_Old(csmc::CliqStateMachineContainer) infocsm(csmc, "Par-3, Solving Up") setCliqueDrawColor!(csmc.cliq, "red") @@ -96,6 +96,145 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer) return waitForDown_StateMachine end +# solve relatives ignoring any priors keeping `from` at ϵ +# if clique has priors : solve to get a prior on `from` +# send messages as factors or just the beliefs? for now factors +function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer) + infocsm(csmc, "Par-3, Solving Up") + + setCliqueDrawColor!(csmc.cliq, "red") + # csmc.drawtree ? drawTree(csmc.tree, show=false, filepath=joinpath(getSolverParams(csmc.dfg).logpath,"bt.pdf")) : nothing + + msgfcts = Symbol[] + + for (idx, upmsg) in getMessageBuffer(csmc.cliq).upRx #get cached messages taken from children saved in this clique + child_factors = addMsgFactors_Parametric!(csmc.cliqSubFg, upmsg, UpwardPass) + append!(msgfcts, getLabel.(child_factors)) # addMsgFactors_Parametric! + end + logCSM(csmc, "length mgsfcts=$(length(msgfcts))") + infocsm(csmc, "length mgsfcts=$(length(msgfcts))") + + # store the cliqSubFg for later debugging + _dbgCSMSaveSubFG(csmc, "fg_beforeupsolve") + + subfg = csmc.cliqSubFg + + frontals = getCliqFrontalVarIds(csmc.cliq) + separators = getCliqSeparatorVarIds(csmc.cliq) + + # if its a root do full solve + if length(getParent(csmc.tree, csmc.cliq)) == 0 + # M, vartypeslist, lm_r, Σ = solve_RLM(subfg; is_sparse=false, finiteDiffCovariance=true) + autoinitParametric!(subfg) + M, vartypeslist, lm_r, Σ = solveGraphParametric!(subfg; is_sparse=false, finiteDiffCovariance=true, damping_term_min=1e-18) + + else + + # select first seperator as constant reference at the identity element + isempty(separators) && @warn "empty separators solving cliq $(csmc.cliq.id.value)" ls(subfg) lsf(subfg) + from = first(separators) + from_v = getVariable(subfg, from) + getSolverData(from_v, :parametric).val[1] = getPointIdentity(getVariableType(from_v)) + + #TODO handle priors + # Variables that are free to move + free_vars = [frontals; separators[2:end]] + # Solve for the free variables + + @assert !isempty(lsf(subfg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(subfg)) lsf=$(lsf(subfg))" + + # M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from];) + M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from]; finiteDiffCovariance=false, damping_term_min=1e-18) + + end + + # FIXME check solve convergence + if !true + @error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve" result + # propagate error to cleanly exit all cliques + putErrorUp(csmc) + if length(getParent(csmc.tree, csmc.cliq)) == 0 + putErrorDown(csmc) + return IncrementalInference.exitStateMachine + end + + return waitForDown_StateMachine + end + + logCSM(csmc, "$(csmc.cliq.id): subfg solve converged sending messages") + + # Pack results in massage factors + + sigmas = extractMarginalsAP(M, vartypeslist, Σ) + + # FIXME fix MsgRelativeType + relative_message_factors = MsgRelativeType(); + for (i, to) in enumerate(vartypeslist) + if to in separators + #assume full dim factor + factype = selectFactorType(subfg, from, to) + # make S symetrical + # S = sigmas[i] # FIXME for some reason SMatrix is not invertable even though it is!!!!!!!! + S = Matrix(sigmas[i])# FIXME + S = (S + S') / 2 + # @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical" + !all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical") + # @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag" + !all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag") + + + M_to = getManifold(getVariableType(subfg, to)) + ϵ = getPointIdentity(M_to) + μ = vee(M_to, ϵ, log(M_to, ϵ, lm_r[i])) + + message_factor = AdFactor(factype(MvNormal(μ, S))) + + + # logCSM(csmc, "$(csmc.cliq.id): Z=$(getMeasurementParametric(message_factor))"; loglevel = Logging.Warn) + + push!(relative_message_factors, (variables=[from, to], likelihood=message_factor)) + end + end + + # Done with solve delete factors + #TODO confirm, maybe don't delete mesage factors on subgraph, maybe delete if its priors, but not conditionals + # deleteMsgFactors!(csmc.cliqSubFg) + + # store the cliqSubFg for later debugging + _dbgCSMSaveSubFG(csmc, "fg_afterupsolve") + + # cliqueLikelihood = calculateMarginalCliqueLikelihood(vardict, Σ, varIds, cliqSeparatorVarIds) + + #Fill in CliqueLikelihood + beliefMsg = LikelihoodMessage(; + sender = (; id = csmc.cliq.id.value, step = csmc._csm_iter), + status = UPSOLVED, + variableOrder = separators, + # cliqueLikelihood, + jointmsg = _MsgJointLikelihood(;relatives=relative_message_factors), + msgType = ParametricMessage(), + ) + + # @assert length(separators) <= 2 "TODO length(separators) = $(length(separators)) > 2 in clique $(csmc.cliq.id.value)" + @assert isempty(lsfPriors(csmc.cliqSubFg)) || csmc.cliq.id.value == 1 "TODO priors in clique $(csmc.cliq.id.value)" + # if length(lsfPriors(csmc.cliqSubFg)) > 0 || length(separators) > 2 + # for si in cliqSeparatorVarIds + # vnd = getSolverData(getVariable(csmc.cliqSubFg, si), :parametric) + # beliefMsg.belief[si] = TreeBelief(deepcopy(vnd)) + # end + # end + + for e in getEdgesParent(csmc.tree, csmc.cliq) + logCSM(csmc, "$(csmc.cliq.id): put! on edge $(e)") + getMessageBuffer(csmc.cliq).upTx = deepcopy(beliefMsg) + putBeliefMessageUp!(csmc.tree, e, beliefMsg) + end + + return waitForDown_StateMachine +end + +global g_n = nothing + """ $SIGNATURES @@ -120,6 +259,14 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer) logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)") DFG.refMeans(vnd)[1] = belief.val[1] #FIXME 🦨 shares data structure in belief DFG.refCovariances(vnd)[1] = belief.bw + p = belief.val[1] + + S = belief.bw + S = (S + S') / 2 + # vnd.bw .= S + + nd = MvNormal(getCoordinates(Main.Pose2, p), S) + addFactor!(csmc.cliqSubFg, [msym], Main.PriorPose2(nd)) end end end @@ -132,23 +279,48 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer) #only down solve if its not a root if length(getParent(csmc.tree, csmc.cliq)) != 0 frontals = getCliqFrontalVarIds(csmc.cliq) - vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals) + # vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals) #TEMP testing difference # vardict, result = solveGraphParametric(csmc.cliqSubFg) # Pack all results in variables - if Optim.g_converged(result) || Optim.f_converged(result) + @assert !isempty(lsf(csmc.cliqSubFg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(csmc.cliqSubFg)) lsf=$(lsf(csmc.cliqSubFg))" + + # M, vartypeslist, lm_r, Σ = solve_RLM_conditional(csmc.cliqSubFg, frontals; finiteDiffCovariance=false, damping_term_min=1e-18) + M, vartypeslist, lm_r, Σ = solve_RLM(csmc.cliqSubFg; finiteDiffCovariance=false, damping_term_min=1e-18) + sigmas = extractMarginalsAP(M, vartypeslist, Σ) + + if true # TODO check for convergence result.g_converged || result.f_converged logCSM( csmc, "$(csmc.cliq.id): subfg optim converged updating variables"; - loglevel = Logging.Info, + loglevel = Logging.Debug, ) - for (v, val) in vardict - logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info) - vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric) - #Update subfg variables - DFG.refMeans(vnd)[1] = val.val - DFG.refCovariances(vnd)[1] = val.cov + for (i, v) in enumerate(vartypeslist) + if v in frontals + # logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug) + vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric) + + S = Matrix(sigmas[i])# FIXME + S = (S + S') / 2 + # @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical" + !all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical") + # @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag" + !all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag") + + + #Update subfg variables + DFG.refMeans(vnd)[1] = lm_r[i] + DFG.refCovariances(vnd)[1] = S + end end + # for (v, val) in vardict + # logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug) + # vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric) + + # #Update subfg variables + # vnd.val[1] = val.val + # vnd.bw .= val.cov + # end else @error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result #propagate error to cleanly exit all cliques @@ -169,7 +341,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer) for fi in cliqFrontalVarIds vnd = getState(getVariable(csmc.cliqSubFg, fi), :parametric) beliefMsg.belief[fi] = TreeBelief(vnd) - logCSM(csmc, "$(csmc.cliq.id): down message $fi : $beliefMsg"; loglevel = Logging.Info) + logCSM(csmc, "$(csmc.cliq.id): down message $fi"; beliefMsg=beliefMsg.belief[fi], loglevel = Logging.Debug) end # pass through the frontal variables that were sent from above diff --git a/IncrementalInference/src/services/TreeMessageUtils.jl b/IncrementalInference/src/services/TreeMessageUtils.jl index 4aceffce..fcd92646 100644 --- a/IncrementalInference/src/services/TreeMessageUtils.jl +++ b/IncrementalInference/src/services/TreeMessageUtils.jl @@ -573,6 +573,28 @@ function addMsgFactors!( return msgfcts end +function addMsgFactors_Parametric!( + subfg::AbstractDFG, + msg::LikelihoodMessage, + ::Type{UpwardPass}; + tags::Vector{Symbol} = Symbol[], + # attemptPriors::Bool = true, +) + # add differential(relative) message factors + + msgfcts = map(msg.jointmsg.relatives) do difflikl + addFactor!( + subfg, + difflikl.variables, + difflikl.likelihood; + graphinit = false, + tags = union(tags, [:__LIKELIHOODMESSAGE__; :__UPWARD_DIFFERENTIAL__]), + ) + end + + return msgfcts +end + function addMsgFactors!( subfg::AbstractDFG, allmsgs::Dict{Int, LikelihoodMessage}, From ccc67eb80a4c00ba80322fac05adbbf5eb5dcb69 Mon Sep 17 00:00:00 2001 From: Johannes Terblanche Date: Thu, 3 Jul 2025 13:23:40 +0200 Subject: [PATCH 2/3] stash on par_tree --- IncrementalInference/src/Factors/GenericFunctions.jl | 9 +++++++++ IncrementalInference/src/services/FactorGradients.jl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/IncrementalInference/src/Factors/GenericFunctions.jl b/IncrementalInference/src/Factors/GenericFunctions.jl index 7cd59527..2df74acf 100644 --- a/IncrementalInference/src/Factors/GenericFunctions.jl +++ b/IncrementalInference/src/Factors/GenericFunctions.jl @@ -123,6 +123,15 @@ function Ad(::typeof(SpecialEuclidean(2)), p) ) end +function Ad(::Motion(2), p) + t = p.x[1] + R = p.x[2] + vcat( + hcat(R, -SA[0 -1; 1 0]*t), + SA[0 0 1] + ) +end + struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize factor::F end diff --git a/IncrementalInference/src/services/FactorGradients.jl b/IncrementalInference/src/services/FactorGradients.jl index e1831530..3ef89d43 100644 --- a/IncrementalInference/src/services/FactorGradients.jl +++ b/IncrementalInference/src/services/FactorGradients.jl @@ -31,7 +31,7 @@ function factorJacobian( M_codom = Euclidean(manifold_dimension(getManifold(fac))) # Jx(M, p) = ManifoldDiff.jacobian(M, M_codom, calcfac, p, backend) - return ManifoldDiff.jacobian(M_dom, M_codom, costf, p0, backend) + return ManifoldDiff.jacobian(M_dom, M_codom, costf, p0, backend), costf(p0) end From 44207fba6ff9d5b8c313f9fa3198aa16911cd38b Mon Sep 17 00:00:00 2001 From: Johannes Terblanche Date: Thu, 3 Jul 2025 13:35:23 +0200 Subject: [PATCH 3/3] unpackDistribution fixes covar --- .../services/SerializingDistributions.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/IncrementalInferenceTypes/src/serialization/services/SerializingDistributions.jl b/IncrementalInferenceTypes/src/serialization/services/SerializingDistributions.jl index 4e6ac01d..82d28913 100644 --- a/IncrementalInferenceTypes/src/serialization/services/SerializingDistributions.jl +++ b/IncrementalInferenceTypes/src/serialization/services/SerializingDistributions.jl @@ -28,3 +28,20 @@ function DFG.unpack(dtr::PackedFullNormal) end DFG.unpack(dtr::PackedRayleigh) = Rayleigh(dtr.sigma) + +# function unpackDistribution(dtr::PackedFullNormal) + +# SM = SymmetricPositiveDefinite(length(dtr.mu)) +# S = reshape(dtr.cov, length(dtr.mu), :) +# ch = check_point(SM, S; atol = 1e-9) +# if !isnothing(ch) +# @warn "IMU Covar check" ch +# S = (S + S') / 2 +# S = S + diagm((diag(S) .== 0)*1e-15) +# ch = check_point(SM, S) +# !isnothing(ch) && @error "IMU Covar check" ch +# end + +# # return MvNormal(dtr.mu, reshape(dtr.cov, length(dtr.mu), :)) +# return MvNormal(dtr.mu, S) +# end \ No newline at end of file