Skip to content

Commit 05dd76a

Browse files
authored
Fix some sampling dispatch and cleanup (#1880)
1 parent c2fdb03 commit 05dd76a

10 files changed

Lines changed: 50 additions & 69 deletions

File tree

IncrementalInference/src/Deprecated.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,27 @@ function _perturbIfNecessary(
3535
end
3636
#
3737

38+
# lets create all the vertices first and then deal with the elimination variables thereafter
39+
function addBayesNetVerts!(dfg::AbstractDFG, elimOrder::Array{Symbol, 1})
40+
#
41+
for pId in elimOrder
42+
vert = DFG.getVariable(dfg, pId)
43+
if getState(vert, :default).BayesNetVertID === nothing ||
44+
getState(vert, :default).BayesNetVertID == :_null # Special serialization case of nothing
45+
@debug "[AddBayesNetVerts] Assigning $pId.data.BayesNetVertID = $pId"
46+
getState(vert, :default).BayesNetVertID = pId
47+
else
48+
@warn "addBayesNetVerts -- Something is wrong, variable '$pId' should not have an existing Bayes net reference to '$(getState(vert, :default).BayesNetVertID)'"
49+
end
50+
end
51+
end
3852

3953
## ================================================================================================
4054
## ================================================================================================
4155

4256
# TODO maybe upstream to DFG
43-
DFG.MeanMaxPPE(solveKey::Symbol, suggested::SVector, max::SVector, mean::SVector) =
44-
DFG.MeanMaxPPE(solveKey, collect(suggested), collect(max), collect(mean))
57+
DFG.MeanMaxPPE(solveKey::Symbol, suggested::StaticArray, max::StaticArray, mean::StaticArray) =
58+
DFG.MeanMaxPPE(solveKey, Vector(suggested), Vector(max), Vector(mean))
4559

4660

4761
## ================================================================================================

IncrementalInference/src/Factors/Circular.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ function getSample(cf::CalcFactor{<:CircularCircular})
1313
return sampleTangent(getManifold(cf), cf.factor.Z, getPointIdentity(Circular))
1414
end
1515

16-
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{CircularCircular})
17-
return Manifolds.RealCircleGroup()
18-
end
16+
#TODO should be deprecated?
17+
# function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{CircularCircular})
18+
# return Manifolds.RealCircleGroup()
19+
# end
1920

2021
IIFTypes.CircularCircular(::UniformScaling) = CircularCircular(Normal())
2122

@@ -40,8 +41,9 @@ function (cf::CalcFactor{<:PriorCircular})(m, p)
4041
return Xc
4142
end
4243

43-
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{PriorCircular})
44-
return Manifolds.RealCircleGroup()
45-
end
44+
#TODO should be deprecated?
45+
# function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{PriorCircular})
46+
# return Manifolds.RealCircleGroup()
47+
# end
4648

4749
# --------------------------------------------

IncrementalInference/src/Factors/EuclidDistance.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ getDimension(::InstanceType{<:EuclidDistance}) = 1
1919
# new and simplified interface for both nonparametric and parametric
2020
(s::CalcFactor{<:EuclidDistance})(z, x1, x2) = z .- norm(x2 .- x1)
2121

22-
function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{EuclidDistance})
23-
return LieGroups.TranslationGroup(1)
24-
end
22+
#TODO should be deprecated?
23+
# function Base.convert(::Type{<:MB.AbstractManifold}, ::InstanceType{EuclidDistance})
24+
# return LieGroups.TranslationGroup(1)
25+
# end
2526

2627
"""
2728
$(TYPEDEF)

IncrementalInference/src/manifolds/services/ManifoldSampling.jl

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,61 +9,34 @@ Notes
99
"""
1010
function sampleTangent end
1111

12-
# Sampling MKD
13-
# function sampleTangent(M::AbstractDecoratorManifold, x::ManifoldKernelDensity, p = mean(x))
14-
# # get legacy matrix of coordinates and selected labels
15-
# #TODO make sure that when `sample` is replaced in MKD, coordinates is a vector
16-
# coords, lbls = sample(x.belief, 1)
17-
# X = hat(x.manifold, p, coords[:])
18-
# return X
19-
# end
20-
2112
function sampleTangent(x::ManifoldKernelDensity, p = mean(x))
22-
Base.depwarn(
23-
"sampleTangent(x::ManifoldKernelDensity, p) should be replaced by sampleTangent(M<:AbstractManifold, x::ManifoldKernelDensity, p)",
24-
:sampleTangent,
25-
)
26-
return sampleTangent(x.manifold, x, p)
13+
error("sampleTangent(x::ManifoldKernelDensity, p) should be replaced by sampleTangent(M<:AbstractManifold, x::ManifoldKernelDensity, p)")
2714
end
2815

2916
# Sampling Distributions
3017
# assumes M is a group and will break for Riemannian, but leaving that enhancement as TODO
3118
function sampleTangent(
3219
M::AbstractManifold,
33-
z::Distribution,
34-
p = getPointIdentity(M),
20+
z,
21+
p,
3522
basis::AbstractBasis = DefaultOrthogonalBasis(),
3623
)
3724
return get_vector(M, p, rand(z), basis)
3825
end
3926

40-
# function sampleTangent(
41-
# M::AbstractDecoratorManifold,
42-
# z::Distribution,
43-
# p = getPointIdentity(M),
44-
# )
45-
# return hat(M, p, SVector{length(z)}(rand(z))) #TODO make sure all Distribution has length,
46-
# # if this errors maybe fall back no next line
47-
# # return convert(typeof(p), hat(M, p, rand(z, 1)[:])) #TODO find something better than (z,1)[:]
48-
# end
49-
50-
function sampleTangent(M::AbstractLieGroup, z::Distribution, p = getPointIdentity(M))
51-
return hat(LieAlgebra(M), SVector{length(z)}(rand(z)), typeof(p))
27+
function sampleTangent(M::AbstractLieGroup, z, p = getPointIdentity(M))
28+
return hat(LieAlgebra(M), SVector{manifold_dimension(M)}(rand(z)), typeof(p))
5229
end
5330

5431
function sampleTangent(M::typeof(LieGroups.CircleGroup()), z::Distribution, p = getPointIdentity(M))
5532
return hat(LieAlgebra(M), rand(z))
5633
end
5734

58-
function sampleTangent(M::AbstractLieGroup, z, p = getPointIdentity(M))
59-
return hat(LieAlgebra(M), rand(z), typeof(p))
60-
end
61-
6235
function sampleTangent(M::AbstractLieGroup, x::ManifoldKernelDensity, p = mean(x))
6336
# get legacy matrix of coordinates and selected labels
6437
#TODO make sure that when `sample` is replaced in MKD, coordinates is a vector
6538
coords, lbls = sample(x.belief, 1)
66-
X = hat(LieAlgebra(x.manifold), coords[:], typeof(p))
39+
X = hat(LieAlgebra(M), coords[:], typeof(p))
6740
return X
6841
end
6942

IncrementalInference/src/manifolds/services/ManifoldsExtentions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123

124124
#TODO test
125125
function DFG.getPointIdentity(
126-
PrG::AbstractLieGroup{𝔽, Op, M},
126+
PrG::LieGroup{𝔽, Op, M},
127127
::Type{T} = Float64,
128128
) where {𝔽, Op <: AbstractProductGroupOperation, M <: ProductManifold, T <: Real}
129129
PrM = PrG.manifold
@@ -162,7 +162,7 @@ end
162162
function DFG.getPointIdentity(
163163
::typeof(SpecialEuclideanGroup(3; variant = :right)),
164164
::Type{T} = Float64,
165-
) where {T}
165+
) where {T <: Real}
166166
N = 3
167167
return ArrayPartition(zeros(SVector{N, T}), SMatrix{N, N, T}(I))
168168
end

IncrementalInference/src/services/BayesNet.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,6 @@ function getEliminationOrder(
7171
return permuteds[p]
7272
end
7373

74-
# lets create all the vertices first and then deal with the elimination variables thereafter
75-
function addBayesNetVerts!(dfg::AbstractDFG, elimOrder::Array{Symbol, 1})
76-
#
77-
for pId in elimOrder
78-
vert = DFG.getVariable(dfg, pId)
79-
if getState(vert, :default).BayesNetVertID === nothing ||
80-
getState(vert, :default).BayesNetVertID == :_null # Special serialization case of nothing
81-
@debug "[AddBayesNetVerts] Assigning $pId.data.BayesNetVertID = $pId"
82-
getState(vert, :default).BayesNetVertID = pId
83-
else
84-
@warn "addBayesNetVerts -- Something is wrong, variable '$pId' should not have an existing Bayes net reference to '$(getState(vert, :default).BayesNetVertID)'"
85-
end
86-
end
87-
end
88-
8974
function addConditional!(dfg::AbstractDFG, vertId::Symbol, Si::Vector{Symbol})
9075
#
9176
bnv = DFG.getVariable(dfg, vertId)

IncrementalInference/src/services/TreeMessageUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ function addLikelihoodsDifferentialCHILD!(
311311
# chain of user factors are of the same type
312312
if isHom
313313
_sft = selectFactorType(tfg, sym1_, sym2_)
314-
sft = _sft() #FIXME empty factor observation constructor
314+
sft = _sft(MvNormal( getDimension(getManifold(_sft)), 1.0)) #FIXME empty factor observation constructor
315315
# only take factors that are homogeneous with the generic relative
316316
if typeof(sft).name == ftyps[1]
317317
# assume default helper function # buildFactorDefault(nfactype)

IncrementalInference/test/manifolds/factordiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151

5252

5353
# ##
54-
# @testset "using RoME; FiniteDiff.jacobian of SpecialEuclidean(2) factor" begin
54+
# @testset "using RoME; FiniteDiff.jacobian of SpecialEuclideanGroup(2) factor" begin
5555
# ##
5656

5757
# fg = LocalDFG(;

IncrementalInference/test/runtests.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ using Test
44
using LieGroups: TranslationGroup
55

66
TEST_GROUP = get(ENV, "IIF_TEST_GROUP", "all")
7-
7+
@testset "IncrementalInference Tests" begin
88
# temporarily moved to start (for debugging)
99
#...
1010
if TEST_GROUP in ["all", "tmp_debug_group"]
11+
@testset "Temporary Debug Group" begin
1112
include("testSpecialOrthogonalMani.jl")
1213
include("testMultiHypo3Door.jl")
1314
include("priorusetest.jl")
1415
end
16+
end
1517

1618
if TEST_GROUP in ["all", "basic_functional_group"]
19+
@testset "Basic Functional Group" begin
1720
# more frequent stochasic failures from numerics
1821
include("testSpecialEuclidean2Mani.jl")
1922
include("testEuclidDistance.jl")
@@ -31,7 +34,7 @@ include("testCliqSolveDbgUtils.jl")
3134
include("basicGraphsOperations.jl")
3235

3336
# regular testing
34-
@test_broken include("testSphereMani.jl")
37+
@test_broken error("testSphereMani.jl broken")#include("testSphereMani.jl")
3538
include("testBasicManifolds.jl")
3639
include("testDERelative.jl")
3740
include("testHeatmapGridDensity.jl")
@@ -86,8 +89,10 @@ include("testSolveOrphanedFG.jl")
8689
include("testSolveSetPPE.jl")
8790
include("testSolveKey.jl")
8891
end
92+
end
8993

9094
if TEST_GROUP in ["all", "test_cases_group"]
95+
@testset "Test Cases Group" begin
9196
include("testnullhypothesis.jl")
9297
include("testVariousNSolveSize.jl")
9398
include("testExplicitMultihypo.jl")
@@ -114,6 +119,7 @@ end
114119
# include("testMultiprocess.jl")
115120
include("testDeadReckoningTether.jl")
116121
end
117-
122+
end
123+
end
118124

119125
#

IncrementalInference/test/testSpecialEuclidean2Mani.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646

4747
##
4848

49-
@testset "Test SpecialEuclidean(2)" begin
49+
@testset "Test SpecialEuclideanGroup(2)" begin
5050
##
5151

5252
M = getManifold(SpecialEuclidean2)
@@ -66,8 +66,8 @@ fg = initfg()
6666

6767
v0 = addVariable!(fg, :x0, SpecialEuclidean2)
6868

69-
# mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal([0.01, 0.01, 0.01]))
70-
# mp = ManifoldPrior(SpecialEuclidean(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal(Diagonal(abs2.([0.01, 0.01, 0.01]))))
69+
# mp = ManifoldPrior(SpecialEuclideanGroup(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal([0.01, 0.01, 0.01]))
70+
# mp = ManifoldPrior(SpecialEuclideanGroup(2), ArrayPartition(@MVector([0.0,0.0]), @MMatrix([1.0 0.0; 0.0 1.0])), MvNormal(Diagonal(abs2.([0.01, 0.01, 0.01]))))
7171
# mp = ManifoldPrior(SE2, ArrayPartition([0.0,0.0], [1.0 0.0; 0.0 1.]), MvNormal(Diagonal(abs2.([0.01, 0.01, 0.01]))))
7272
mp = ManifoldPrior(SE2, ArrayPartition(SA[0.0,0.0], SA[1.0 0.0; 0.0 1.]), MvNormal(Diagonal(abs2.(SA[0.01, 0.01, 0.01]))))
7373
p = addFactor!(fg, [:x0], mp)

0 commit comments

Comments
 (0)