Skip to content

Commit 42e452c

Browse files
committed
ForwardDiff support in solve_RLM and use precision matrices
1 parent 209f926 commit 42e452c

4 files changed

Lines changed: 161 additions & 85 deletions

File tree

IncrementalInference/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
2020
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
2121
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2222
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
23+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2324
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
2425
IncrementalInferenceTypes = "9808408f-4dbc-47e4-913c-6068b950e289"
2526
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@@ -83,6 +84,7 @@ FileIO = "1"
8384
FiniteDiff = "2"
8485
FiniteDifferences = "0.12"
8586
Flux = "0.14, 0.15, 0.16"
87+
ForwardDiff = "1.3.3"
8688
FunctionalStateMachine = "0.2.9, 0.3"
8789
Gadfly = "1"
8890
IncrementalInferenceTypes = "0.1.0"

IncrementalInference/src/manifolds/services/ManifoldsExtentions.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ function Manifolds.get_vector!(M::NPowerManifold, Y, p, c, B::AbstractBasis)
5555
return Y
5656
end
5757

58+
# Allocating get_vector that infers element type from coordinates (Dual-compatible)
59+
function ManifoldsBase.get_vector(M::NPowerManifold, p, c, B::AbstractBasis)
60+
dim = manifold_dimension(M.manifold)
61+
rep_size = Manifolds.representation_size(M.manifold)
62+
v_iter = Ref(0)
63+
return [begin
64+
coords_i = SVector{dim}(view(c, v_iter[]+1:v_iter[]+dim))
65+
v_iter[] += dim
66+
get_vector(M.manifold, Manifolds._read(M, rep_size, p, i), coords_i, B)
67+
end for i in Manifolds.get_iterator(M)]
68+
end
69+
5870
function Manifolds.exp!(M::NPowerManifold, q, p, X)
5971
rep_size = Manifolds.representation_size(M.manifold)
6072
for i in Manifolds.get_iterator(M)
@@ -67,6 +79,16 @@ function Manifolds.exp!(M::NPowerManifold, q, p, X)
6779
return q
6880
end
6981

82+
# Allocating exp that infers element type from tangent vector (Dual-compatible)
83+
function ManifoldsBase.exp(M::NPowerManifold, p, X)
84+
rep_size = Manifolds.representation_size(M.manifold)
85+
return [exp(
86+
M.manifold,
87+
Manifolds._read(M, rep_size, p, i),
88+
Manifolds._read(M, rep_size, X, i),
89+
) for i in Manifolds.get_iterator(M)]
90+
end
91+
7092
function LieGroups.compose!(M::NPowerManifold, x, p, q)
7193
rep_size = representation_size(M.manifold)
7294
for i in Manifolds.get_iterator(M)

IncrementalInference/src/parametric/services/ParametricManopt.jl

Lines changed: 98 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using FiniteDiff
33
using SparseDiffTools
44
using SparseArrays
55

6-
# using ForwardDiff
6+
using ForwardDiff
77
# using Zygote
88

99
##
@@ -111,7 +111,7 @@ end
111111
function calcFactorResVec!(
112112
x::Vector{T},
113113
cfm_part::Vector{<:CalcFactorResidual{FT, N, D}},
114-
p::AbstractArray{T},
114+
p::AbstractArray,
115115
st::Int
116116
) where {T, FT, N, D}
117117
for cfm in cfm_part
@@ -121,7 +121,7 @@ function calcFactorResVec!(
121121
return st
122122
end
123123

124-
function calcFactorResVec_threaded!(x::Vector{T}, cfm_part::Vector{<:CalcFactorResidual}, p::AbstractArray{T}, st::Int) where T
124+
function calcFactorResVec_threaded!(x::Vector{T}, cfm_part::Vector{<:CalcFactorResidual}, p::AbstractArray, st::Int) where T
125125
l = getDimension(cfm_part[1]) # all should be the same
126126
N = length(cfm_part)
127127
chunkies = Iterators.partition(1:N, N ÷ Threads.nthreads())
@@ -135,7 +135,7 @@ function calcFactorResVec_threaded!(x::Vector{T}, cfm_part::Vector{<:CalcFactorR
135135
return st + l*N
136136
end
137137

138-
function (costf::CostFres!{CFT})(M::AbstractManifold, x::Vector{T}, p::AbstractVector{T}) where {CFT,T}
138+
function (costf::CostFres!{CFT})(M::AbstractManifold, x::Vector{T}, p::AbstractVector) where {CFT,T}
139139
st = 1
140140
for cfm_part in costf.costfuns.x
141141
# if length(cfm_part) > Threads.nthreads() * 10
@@ -150,11 +150,11 @@ end
150150
## --------------------------------------------------------------------------------------------------------------
151151
## jacobian of function for Riemannian Levenberg-Marquardt
152152
## --------------------------------------------------------------------------------------------------------------
153-
struct JacF_RLM!{CF, T, JC}
153+
struct JacF_RLM!{CF, TX, TQ, JC}
154154
costF!::CF
155155
X0::Vector{Float64}
156-
X::T
157-
q::T
156+
X::TX
157+
q::TQ
158158
res::Vector{Float64}
159159
Jcache::JC
160160
end
@@ -245,6 +245,48 @@ end
245245
# ManifoldDiff.default_differential_backend()
246246
# )
247247

248+
## --------------------------------------------------------------------------------------------------------------
249+
## ForwardDiff jacobian for Riemannian Levenberg-Marquardt
250+
## --------------------------------------------------------------------------------------------------------------
251+
252+
struct JacF_RLM_ForwardDiff!{CF}
253+
costF!::CF
254+
X0::Vector{Float64}
255+
res::Vector{Float64}
256+
end
257+
258+
function JacF_RLM_ForwardDiff!(M, costF!, p, _fg=nothing;
259+
all_points=p,
260+
basis_domain::AbstractBasis = LieGroups.DefaultLieAlgebraOrthogonalBasis(),
261+
is_sparse=false, # unused, kept for interface compat
262+
)
263+
res = reduce(vcat, map(f -> f(all_points), Vector(costF!.costfuns)))
264+
X0 = zeros(manifold_dimension(M))
265+
return JacF_RLM_ForwardDiff!(costF!, X0, res)
266+
end
267+
268+
function (jacF!::JacF_RLM_ForwardDiff!)(
269+
M::AbstractManifold,
270+
J,
271+
p;
272+
basis_domain::AbstractBasis = DefaultOrthogonalBasis(),
273+
)
274+
X0 = jacF!.X0
275+
fill!(X0, 0)
276+
277+
nres = length(jacF!.res)
278+
function costf(Xc)
279+
X = get_vector(M, p, Xc, basis_domain)
280+
q = exp(M, p, X)
281+
_res = zeros(eltype(Xc), nres)
282+
jacF!.costF!(M, _res, q)
283+
return _res
284+
end
285+
286+
ForwardDiff.jacobian!(J, costf, X0)
287+
return J
288+
end
289+
248290
struct FactorGradient{A <: AbstractMatrix}
249291
manifold::AbstractManifold
250292
JacF!::JacF_RLM!
@@ -278,9 +320,7 @@ function getSparsityPattern(fg, varLabels, factLabels)
278320
return sparse(getindex.(iter,1), getindex.(iter,2), ones(Bool, length(iter)))
279321
end
280322

281-
# TODO only calculate marginal covariances
282-
283-
function covarianceFiniteDiff(M, jacF!::JacF_RLM!, p0)
323+
function precisionFiniteDiff(M, jacF!::JacF_RLM!, p0)
284324
# Jcache
285325
X0 = fill!(deepcopy(jacF!.X0), 0)
286326

@@ -292,17 +332,22 @@ function covarianceFiniteDiff(M, jacF!::JacF_RLM!, p0)
292332
end
293333
end
294334

295-
H = FiniteDiff.finite_difference_hessian(costf, X0)
296-
297-
# inv(H)
298-
Σ = try
299-
Matrix(H) \ Matrix{eltype(H)}(I, size(H)...)
300-
catch ex #TODO only catch correct exception and try with pinv as fallback in certain cases.
301-
@warn "Hessian inverse failed" ex H
302-
pinv(H)
303-
# nothing
304-
end
305-
return Σ
335+
FiniteDiff.finite_difference_hessian(costf, X0)
336+
end
337+
338+
function precisionFiniteDiff(M, jacF!::JacF_RLM_ForwardDiff!, p0)
339+
X0 = fill!(copy(jacF!.X0), 0)
340+
nres = length(jacF!.res)
341+
342+
function costf(Xc)
343+
X = get_vector(M, p0, Xc, DefaultOrthogonalBasis())
344+
q = exp(M, p0, X)
345+
_res = zeros(nres)
346+
jacF!.costF!(M, _res, q)
347+
return 1/2*norm(_res)^2
348+
end
349+
350+
FiniteDiff.finite_difference_hessian(costf, X0)
306351
end
307352

308353
function qr_linear_subsolver!(sk, JJ, grad_f_c)
@@ -316,6 +361,7 @@ function solve_RLM(
316361
faclabels = lsf(fg);
317362
is_sparse = true,
318363
finiteDiffCovariance = false,
364+
jacobian_method::Symbol = :finitediff,
319365
solveKey::Symbol = :parametric,
320366
linear_subsolver! = qr_linear_subsolver!,
321367
kwargs...
@@ -341,13 +387,17 @@ function solve_RLM(
341387
costF! = CostFres!(calcfacs, collect(varlabelsAP))
342388

343389
# jacobian of function for Riemannian Levenberg-Marquardt
344-
jacF! = JacF_RLM!(M, costF!, p0, fg; is_sparse)
390+
if jacobian_method == :forwarddiff
391+
jacF! = JacF_RLM_ForwardDiff!(M, costF!, p0)
392+
else
393+
jacF! = JacF_RLM!(M, costF!, p0, fg; is_sparse)
394+
end
345395

346396
num_components = length(jacF!.res)
347397
initial_residual_values = zeros(num_components)
348398

349399
# initial_jacobian_f not type stable, but function barrier so should be ok.
350-
initial_jacobian_f = is_sparse ?
400+
initial_jacobian_f = (jacF! isa JacF_RLM! && is_sparse) ?
351401
jacF!.Jcache.sparsity :
352402
zeros(num_components, manifold_dimension(M))
353403

@@ -365,23 +415,15 @@ function solve_RLM(
365415
kwargs...
366416
)
367417

368-
if length(initial_residual_values) < 1000
369-
if finiteDiffCovariance
370-
# TODO this seems to be correct, but way to slow
371-
Σ = covarianceFiniteDiff(M, jacF!, lm_r)
372-
else
373-
# TODO make sure J initial_jacobian_f is updated, otherwise recalc jacF!(M, J, lm_r) # lm_r === p0
374-
J = initial_jacobian_f
375-
H = J'J # approx
376-
Σ = H \ Matrix{eltype(H)}(I, size(H)...)
377-
# Σ = pinv(H)
378-
end
418+
if finiteDiffCovariance
419+
Λ = precisionFiniteDiff(M, jacF!, lm_r)
379420
else
380-
@warn "Not estimating a Dense Covariance $(size(initial_jacobian_f))"
381-
Σ = nothing
421+
J = initial_jacobian_f
422+
jacF!(M, J, lm_r) # recompute J at solution point
423+
Λ = Symmetric(J'J) # approx Hessian = precision matrix
382424
end
383425

384-
return M, varlabelsAP, lm_r, Σ
426+
return M, varlabelsAP, lm_r, Λ
385427
end
386428

387429
# nlso = NonlinearLeastSquaresObjective(
@@ -440,6 +482,7 @@ function solve_RLM_conditional(
440482
separators::Vector{Symbol} = setdiff(ls(fg), frontals);
441483
is_sparse=false,
442484
finiteDiffCovariance=true,
485+
jacobian_method::Symbol = :finitediff,
443486
solveKey::Symbol = :parametric,
444487
kwargs...
445488
)
@@ -498,13 +541,17 @@ function solve_RLM_conditional(
498541
costF! = CostFres_cond!(all_points, calcfacs, Vector{Symbol}(collect(all_varlabelsAP)))
499542

500543
# jacobian of function for Riemannian Levenberg-Marquardt
501-
jacF! = JacF_RLM!(M, costF!, p0, fg; all_points, is_sparse)
544+
if jacobian_method == :forwarddiff
545+
jacF! = JacF_RLM_ForwardDiff!(M, costF!, p0; all_points)
546+
else
547+
jacF! = JacF_RLM!(M, costF!, p0, fg; all_points, is_sparse)
548+
end
502549

503550
num_components = length(jacF!.res)
504551

505552
initial_residual_values = zeros(num_components)
506553

507-
initial_jacobian_f = is_sparse ?
554+
initial_jacobian_f = (jacF! isa JacF_RLM! && is_sparse) ?
508555
jacF!.Jcache.sparsity :
509556
zeros(num_components, manifold_dimension(M))
510557

@@ -521,13 +568,13 @@ function solve_RLM_conditional(
521568
)
522569

523570
if finiteDiffCovariance
524-
Σ = covarianceFiniteDiff(M, jacF!, lm_r)
571+
Λ = precisionFiniteDiff(M, jacF!, lm_r)
525572
else
526-
J = initial_jacobian_f
527-
Σ = pinv(J'J)
573+
jacF!(M, initial_jacobian_f, lm_r)
574+
Λ = Symmetric(initial_jacobian_f' * initial_jacobian_f)
528575
end
529-
530-
return M, frontal_varlabelsAP, lm_r, Σ
576+
577+
return M, frontal_varlabelsAP, lm_r, Λ
531578
end
532579

533580
function extractMarginalsAP(M, labelsAP::ArrayPartition{Symbol}, Σ::AbstractArray{<:Real})
@@ -608,12 +655,14 @@ function autoinitParametric!(
608655
)
609656
)
610657
end
611-
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...)
658+
M, vartypeslist, lm_r, Λ = solve_RLM_conditional(dfg, [initme], initfrom; solveKey, kwargs...)
612659

613660
val = lm_r[1]
614661
DFG.refMeans(vnd)[1] = val
615662

616-
!isnothing(Σ) && (DFG.refCovariances(vnd)[1] .= Σ)
663+
if !isnothing(Λ)
664+
DFG.refCovariances(vnd)[1] .= inv(Matrix(Λ))
665+
end
617666

618667
# updateSolverDataParametric!(vnd, val, Σ)
619668

@@ -656,11 +705,11 @@ function DFG.solveGraphParametric!(
656705
error("TODO: not implemented")
657706
end
658707

659-
M, v, r, Σ = solve_RLM(fg, args...; is_sparse, kwargs...)
708+
M, v, r, Λ = solve_RLM(fg, args...; is_sparse, kwargs...)
660709

661-
updateParametricSolution!(fg, M, v, r, Σ)
710+
updateParametricSolution!(fg, M, v, r, Λ)
662711

663-
return M, v, r, Σ
712+
return M, v, r, Λ
664713
end
665714

666715

0 commit comments

Comments
 (0)