Skip to content

Commit 60731c2

Browse files
committed
wip parametric tree solve
1 parent d81db6f commit 60731c2

4 files changed

Lines changed: 282 additions & 12 deletions

File tree

IncrementalInference/src/CliqueStateMachine/services/CliqueStateMachine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ function updateFromSubgraph_StateMachine(csmc::CliqStateMachineContainer)
956956
logCSM(
957957
csmc,
958958
"CSM-5 Clique $(csmc.cliq.id) finished, solveKey=$(csmc.solveKey)";
959-
loglevel = Logging.Info,
959+
loglevel = Logging.Debug,
960960
)
961961
return IncrementalInference.exitStateMachine
962962
end

IncrementalInference/src/Factors/GenericFunctions.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,82 @@ function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
9494
return measurement_residual(cf.factor.M, X, p, q)
9595
end
9696

97+
## ======================================================================================
98+
## adjoint factor - adjoint action applied to the measurement
99+
## ======================================================================================
100+
function Ad(::Union{typeof(SpecialEuclidean(2)), typeof(SpecialEuclidean(3))}, p, X)
101+
t = p.x[1]
102+
R = p.x[2]
103+
v = X.x[1]
104+
Ω = X.x[2]
105+
ArrayPartition(-R*Ω*R'*t + R*v, R*Ω*R')
106+
end
107+
108+
function Ad(::typeof(SpecialEuclidean(3)), p)
109+
t = p.x[1]
110+
R = p.x[2]
111+
vcat(
112+
hcat(R, skew(t)*R),
113+
hcat(zero(SMatrix{3,3,Float64}), R)
114+
)
115+
end
116+
117+
function Ad(::typeof(SpecialEuclidean(2)), p)
118+
t = p.x[1]
119+
R = p.x[2]
120+
vcat(
121+
hcat(R, -SA[0 -1; 1 0]*t),
122+
SA[0 0 1]
123+
)
124+
end
125+
126+
struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize
127+
factor::F
128+
end
129+
130+
function (cf::CalcFactor{<:AdFactor})(Xϵ, p, q)
131+
# M = getManifold(cf.factor)
132+
# p,q ∈ M
133+
# Xϵ ∈ TϵM
134+
# ϵ = identity_element(M)
135+
# transform measurement from TϵM to TpM (global to local coordinates)
136+
# Adₚ⁻¹ = AdjointMatrix(M, p)⁻¹ = AdjointMatrix(M, p⁻¹)
137+
# Xp = Adₚ⁻¹ * Xϵᵛ
138+
# ad = Ad(M, inv(M, p))
139+
# Xp = Ad(M, inv(M, p), Xϵ)
140+
# Xp = adjoint_action(M, inv(M, p), Xϵ)
141+
#TODO is vector transport supposed to be the same?
142+
# Xp = vector_transport_to(M, ϵ, Xϵ, p)
143+
144+
# Transform measurement covariance
145+
# ᵉΣₚ = Adₚ ᵖΣₚ Adₚᵀ
146+
#TODO test if transforming sqrt_iΣ is the same as Σ
147+
# Σ = ad * inv(cf.sqrt_iΣ^2) * ad'
148+
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), sqrt(inv(Σ)))
149+
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), ad * cf.sqrt_iΣ * ad')
150+
Xp =
151+
152+
child_cf = CalcFactorResidual(
153+
cf.faclbl,
154+
cf.factor.factor,
155+
cf.varOrder,
156+
cf.varOrderIdxs,
157+
cf.meas,
158+
cf.sqrt_iΣ,
159+
cf.cache,
160+
)
161+
return child_cf(Xp, p, q)
162+
end
163+
164+
getMeasurementParametric(f::AdFactor) = getMeasurementParametric(f.factor)
165+
166+
getManifold(f::AdFactor) = getManifold(f.factor)
167+
function getSample(cf::CalcFactor{<:AdFactor})
168+
M = getManifold(cf)
169+
return sampleTangent(M, cf.factor.factor.Z)
170+
end
171+
172+
97173
## ======================================================================================
98174
## adjoint factor - adjoint action applied to the measurement
99175
## ======================================================================================

IncrementalInference/src/parametric/services/ParametricCSMFunctions.jl

Lines changed: 183 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Notes
66
- Parametric state machine function nr. 3
77
"""
8-
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
8+
function solveUp_ParametricStateMachine_Old(csmc::CliqStateMachineContainer)
99
infocsm(csmc, "Par-3, Solving Up")
1010

1111
setCliqueDrawColor!(csmc.cliq, "red")
@@ -96,6 +96,145 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
9696
return waitForDown_StateMachine
9797
end
9898

99+
# solve relatives ignoring any priors keeping `from` at ϵ
100+
# if clique has priors : solve to get a prior on `from`
101+
# send messages as factors or just the beliefs? for now factors
102+
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
103+
infocsm(csmc, "Par-3, Solving Up")
104+
105+
setCliqueDrawColor!(csmc.cliq, "red")
106+
# csmc.drawtree ? drawTree(csmc.tree, show=false, filepath=joinpath(getSolverParams(csmc.dfg).logpath,"bt.pdf")) : nothing
107+
108+
msgfcts = Symbol[]
109+
110+
for (idx, upmsg) in getMessageBuffer(csmc.cliq).upRx #get cached messages taken from children saved in this clique
111+
child_factors = addMsgFactors_Parametric!(csmc.cliqSubFg, upmsg, UpwardPass)
112+
append!(msgfcts, getLabel.(child_factors)) # addMsgFactors_Parametric!
113+
end
114+
logCSM(csmc, "length mgsfcts=$(length(msgfcts))")
115+
infocsm(csmc, "length mgsfcts=$(length(msgfcts))")
116+
117+
# store the cliqSubFg for later debugging
118+
_dbgCSMSaveSubFG(csmc, "fg_beforeupsolve")
119+
120+
subfg = csmc.cliqSubFg
121+
122+
frontals = getCliqFrontalVarIds(csmc.cliq)
123+
separators = getCliqSeparatorVarIds(csmc.cliq)
124+
125+
# if its a root do full solve
126+
if length(getParent(csmc.tree, csmc.cliq)) == 0
127+
# M, vartypeslist, lm_r, Σ = solve_RLM(subfg; is_sparse=false, finiteDiffCovariance=true)
128+
autoinitParametric!(subfg)
129+
M, vartypeslist, lm_r, Σ = solveGraphParametric!(subfg; is_sparse=false, finiteDiffCovariance=true, damping_term_min=1e-18)
130+
131+
else
132+
133+
# select first seperator as constant reference at the identity element
134+
isempty(separators) && @warn "empty separators solving cliq $(csmc.cliq.id.value)" ls(subfg) lsf(subfg)
135+
from = first(separators)
136+
from_v = getVariable(subfg, from)
137+
getSolverData(from_v, :parametric).val[1] = getPointIdentity(getVariableType(from_v))
138+
139+
#TODO handle priors
140+
# Variables that are free to move
141+
free_vars = [frontals; separators[2:end]]
142+
# Solve for the free variables
143+
144+
@assert !isempty(lsf(subfg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(subfg)) lsf=$(lsf(subfg))"
145+
146+
# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from];)
147+
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from]; finiteDiffCovariance=false, damping_term_min=1e-18)
148+
149+
end
150+
151+
# FIXME check solve convergence
152+
if !true
153+
@error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve" result
154+
# propagate error to cleanly exit all cliques
155+
putErrorUp(csmc)
156+
if length(getParent(csmc.tree, csmc.cliq)) == 0
157+
putErrorDown(csmc)
158+
return IncrementalInference.exitStateMachine
159+
end
160+
161+
return waitForDown_StateMachine
162+
end
163+
164+
logCSM(csmc, "$(csmc.cliq.id): subfg solve converged sending messages")
165+
166+
# Pack results in massage factors
167+
168+
sigmas = extractMarginalsAP(M, vartypeslist, Σ)
169+
170+
# FIXME fix MsgRelativeType
171+
relative_message_factors = MsgRelativeType();
172+
for (i, to) in enumerate(vartypeslist)
173+
if to in separators
174+
#assume full dim factor
175+
factype = selectFactorType(subfg, from, to)
176+
# make S symetrical
177+
# S = sigmas[i] # FIXME for some reason SMatrix is not invertable even though it is!!!!!!!!
178+
S = Matrix(sigmas[i])# FIXME
179+
S = (S + S') / 2
180+
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
181+
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
182+
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
183+
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")
184+
185+
186+
M_to = getManifold(getVariableType(subfg, to))
187+
ϵ = getPointIdentity(M_to)
188+
μ = vee(M_to, ϵ, log(M_to, ϵ, lm_r[i]))
189+
190+
message_factor = AdFactor(factype(MvNormal(μ, S)))
191+
192+
193+
# logCSM(csmc, "$(csmc.cliq.id): Z=$(getMeasurementParametric(message_factor))"; loglevel = Logging.Warn)
194+
195+
push!(relative_message_factors, (variables=[from, to], likelihood=message_factor))
196+
end
197+
end
198+
199+
# Done with solve delete factors
200+
#TODO confirm, maybe don't delete mesage factors on subgraph, maybe delete if its priors, but not conditionals
201+
# deleteMsgFactors!(csmc.cliqSubFg)
202+
203+
# store the cliqSubFg for later debugging
204+
_dbgCSMSaveSubFG(csmc, "fg_afterupsolve")
205+
206+
# cliqueLikelihood = calculateMarginalCliqueLikelihood(vardict, Σ, varIds, cliqSeparatorVarIds)
207+
208+
#Fill in CliqueLikelihood
209+
beliefMsg = LikelihoodMessage(;
210+
sender = (; id = csmc.cliq.id.value, step = csmc._csm_iter),
211+
status = UPSOLVED,
212+
variableOrder = separators,
213+
# cliqueLikelihood,
214+
jointmsg = _MsgJointLikelihood(;relatives=relative_message_factors),
215+
msgType = ParametricMessage(),
216+
)
217+
218+
# @assert length(separators) <= 2 "TODO length(separators) = $(length(separators)) > 2 in clique $(csmc.cliq.id.value)"
219+
@assert isempty(lsfPriors(csmc.cliqSubFg)) || csmc.cliq.id.value == 1 "TODO priors in clique $(csmc.cliq.id.value)"
220+
# if length(lsfPriors(csmc.cliqSubFg)) > 0 || length(separators) > 2
221+
# for si in cliqSeparatorVarIds
222+
# vnd = getSolverData(getVariable(csmc.cliqSubFg, si), :parametric)
223+
# beliefMsg.belief[si] = TreeBelief(deepcopy(vnd))
224+
# end
225+
# end
226+
227+
for e in getEdgesParent(csmc.tree, csmc.cliq)
228+
logCSM(csmc, "$(csmc.cliq.id): put! on edge $(e)")
229+
getMessageBuffer(csmc.cliq).upTx = deepcopy(beliefMsg)
230+
putBeliefMessageUp!(csmc.tree, e, beliefMsg)
231+
end
232+
233+
return waitForDown_StateMachine
234+
end
235+
236+
global g_n = nothing
237+
99238
"""
100239
$SIGNATURES
101240
@@ -120,6 +259,14 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
120259
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
121260
DFG.refMeans(vnd)[1] = belief.val[1] #FIXME 🦨 shares data structure in belief
122261
DFG.refCovariances(vnd)[1] = belief.bw
262+
p = belief.val[1]
263+
264+
S = belief.bw
265+
S = (S + S') / 2
266+
# vnd.bw .= S
267+
268+
nd = MvNormal(getCoordinates(Main.Pose2, p), S)
269+
addFactor!(csmc.cliqSubFg, [msym], Main.PriorPose2(nd))
123270
end
124271
end
125272
end
@@ -132,23 +279,48 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
132279
#only down solve if its not a root
133280
if length(getParent(csmc.tree, csmc.cliq)) != 0
134281
frontals = getCliqFrontalVarIds(csmc.cliq)
135-
vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
282+
# vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
136283
#TEMP testing difference
137284
# vardict, result = solveGraphParametric(csmc.cliqSubFg)
138285
# Pack all results in variables
139-
if Optim.g_converged(result) || Optim.f_converged(result)
286+
@assert !isempty(lsf(csmc.cliqSubFg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(csmc.cliqSubFg)) lsf=$(lsf(csmc.cliqSubFg))"
287+
288+
# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(csmc.cliqSubFg, frontals; finiteDiffCovariance=false, damping_term_min=1e-18)
289+
M, vartypeslist, lm_r, Σ = solve_RLM(csmc.cliqSubFg; finiteDiffCovariance=false, damping_term_min=1e-18)
290+
sigmas = extractMarginalsAP(M, vartypeslist, Σ)
291+
292+
if true # TODO check for convergence result.g_converged || result.f_converged
140293
logCSM(
141294
csmc,
142295
"$(csmc.cliq.id): subfg optim converged updating variables";
143-
loglevel = Logging.Info,
296+
loglevel = Logging.Debug,
144297
)
145-
for (v, val) in vardict
146-
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
147-
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
148-
#Update subfg variables
149-
DFG.refMeans(vnd)[1] = val.val
150-
DFG.refCovariances(vnd)[1] = val.cov
298+
for (i, v) in enumerate(vartypeslist)
299+
if v in frontals
300+
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
301+
vnd = getState(getVariable(csmc.cliqSubFg, v), :parametric)
302+
303+
S = Matrix(sigmas[i])# FIXME
304+
S = (S + S') / 2
305+
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
306+
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
307+
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
308+
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")
309+
310+
311+
#Update subfg variables
312+
DFG.refMeans(vnd)[1] = lm_r[i]
313+
DFG.refCovariances(vnd)[1] = S
314+
end
151315
end
316+
# for (v, val) in vardict
317+
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
318+
# vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
319+
320+
# #Update subfg variables
321+
# vnd.val[1] = val.val
322+
# vnd.bw .= val.cov
323+
# end
152324
else
153325
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result
154326
#propagate error to cleanly exit all cliques
@@ -169,7 +341,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
169341
for fi in cliqFrontalVarIds
170342
vnd = getState(getVariable(csmc.cliqSubFg, fi), :parametric)
171343
beliefMsg.belief[fi] = TreeBelief(vnd)
172-
logCSM(csmc, "$(csmc.cliq.id): down message $fi : $beliefMsg"; loglevel = Logging.Info)
344+
logCSM(csmc, "$(csmc.cliq.id): down message $fi"; beliefMsg=beliefMsg.belief[fi], loglevel = Logging.Debug)
173345
end
174346

175347
# pass through the frontal variables that were sent from above

IncrementalInference/src/services/TreeMessageUtils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,28 @@ function addMsgFactors!(
573573
return msgfcts
574574
end
575575

576+
function addMsgFactors_Parametric!(
577+
subfg::AbstractDFG,
578+
msg::LikelihoodMessage,
579+
::Type{UpwardPass};
580+
tags::Vector{Symbol} = Symbol[],
581+
# attemptPriors::Bool = true,
582+
)
583+
# add differential(relative) message factors
584+
585+
msgfcts = map(msg.jointmsg.relatives) do difflikl
586+
addFactor!(
587+
subfg,
588+
difflikl.variables,
589+
difflikl.likelihood;
590+
graphinit = false,
591+
tags = union(tags, [:__LIKELIHOODMESSAGE__; :__UPWARD_DIFFERENTIAL__]),
592+
)
593+
end
594+
595+
return msgfcts
596+
end
597+
576598
function addMsgFactors!(
577599
subfg::AbstractDFG,
578600
allmsgs::Dict{Int, LikelihoodMessage},

0 commit comments

Comments
 (0)