Skip to content

Commit 1894b15

Browse files
authored
Refactor Mixture as direct Density and standardize sampling on FluxModelsDistributions to resemble <: Sampleable (#1882)
* Refactor Mixture as direct Density and standardize sampling on FluxModelsDistributions to resemble <: Sambleable * rm getFactorMechanics --------- Co-authored-by: Johannes Terblanche <Affie@users.noreply.github.com>
1 parent 439796a commit 1894b15

18 files changed

Lines changed: 112 additions & 275 deletions

IncrementalInference/ext/IncrInfrFluxFactorsExt.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function Random.rand(nfb::FluxModelsDistribution, N::Integer = 1)
3535
numModels = length(nfb.models)
3636
allPreds = 1:numModels |> collect
3737
# TODO -- compensate when there arent enough prediction models
38-
if !(N isa Nothing) && numModels < N
38+
if numModels < N
3939
reps = (N ÷ numModels) + 1
4040
allPreds = repeat(allPreds, reps)
4141
resize!(allPreds, N)
@@ -44,20 +44,16 @@ function Random.rand(nfb::FluxModelsDistribution, N::Integer = 1)
4444
# can suppress shuffle for NN training purposes
4545
selPred = 1 < numModels && nfb.shuffle[] ? rand(allPreds, N) : view(allPreds, 1:N)
4646

47-
# dev function, TODO simplify to direct call
48-
_sample() = map(pred -> (nfb.models[pred])(nfb.data), selPred)
49-
50-
return _sample()
51-
# return [_sample() for _ in 1:N]
47+
# FIXME not type stable, maybe use Univariate and Multivariate from Distributions.jl
48+
samples = map(pred -> (nfb.models[pred])(nfb.data), selPred)
49+
dim = length(samples[1]) # dim 1 is Univariate
50+
if N == 1
51+
return dim == 1 ? samples[1][1] : samples[1]
52+
else
53+
return dim == 1 ? reduce(vcat, samples) : reduce(hcat, samples)
54+
end
5255
end
5356

54-
sampleTangent(M::AbstractManifold, fmd::FluxModelsDistribution, p = 0) = rand(fmd, 1)[1]
55-
sampleTangent(M::AbstractLieGroup, fmd::FluxModelsDistribution, p = 0) = rand(fmd, 1)[1]
56-
57-
samplePoint(M::AbstractManifold, fmd::FluxModelsDistribution, p = 0) = rand(fmd, 1)[1]
58-
function samplePoint(M::AbstractLieGroup, fmd::FluxModelsDistribution, p = 0)
59-
return rand(fmd, 1)[1]
60-
end
6157

6258
function FluxModelsDistribution(
6359
inDim::NTuple{ID, Int},
@@ -138,7 +134,7 @@ Related
138134
Mixture, FluxModelsDistribution
139135
"""
140136
function MixtureFluxModels(
141-
F_::AbstractObservation,
137+
F_::Type{<:AbstractObservation},
142138
nnModels::Vector{P},
143139
inDim::NTuple{ID, Int},
144140
data::D,
@@ -174,10 +170,6 @@ function MixtureFluxModels(
174170
return Mixture(F_, ntup, diversity)
175171
end
176172

177-
function MixtureFluxModels(::Type{F}, w...; kw...) where {F <: AbstractObservation}
178-
return MixtureFluxModels(F(LinearAlgebra.I), w...; kw...)
179-
end
180-
181173
#
182174

183175
include("FluxModelsSerialization.jl")

IncrementalInference/src/Deprecated.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,32 @@
55
#TODO this looks like dead code, should be removed
66
# TODO deprecate testshuffle
77
function _checkErrorCCWNumerics(
8-
ccwl::Union{CommonConvWrapper{F}, CommonConvWrapper{Mixture{N_, F, S, T}}},
8+
ccwl::CommonConvWrapper{F},
99
testshuffle::Bool = false,
10-
) where {N_, F <: AbstractRelativeMinimize, S, T}
10+
) where {F <: AbstractRelativeMinimize}
1111
return nothing
1212
end
1313
function _checkErrorCCWNumerics(
14-
ccwl::Union{CommonConvWrapper{F}, CommonConvWrapper{Mixture{N_, F, S, T}}},
14+
ccwl::CommonConvWrapper{F},
1515
testshuffle::Bool = false,
16-
) where {N_, F <: AbstractManifoldMinimize, S, T}
16+
) where {F <: AbstractManifoldMinimize}
1717
return nothing
1818
end
1919

2020
#TODO this looks like dead code, should be removed
2121
function _perturbIfNecessary(
22-
fcttype::Union{F, <:Mixture{N_, F, S, T}},
22+
fcttype::AbstractRelativeMinimize,
2323
len::Int = 1,
2424
perturbation::Real = 1e-10,
25-
) where {N_, F <: AbstractRelativeMinimize, S, T}
25+
)
2626
return 0
2727
end
2828

2929
function _perturbIfNecessary(
30-
fcttype::Union{F, <:Mixture{N_, F, S, T}},
30+
fcttype::AbstractManifoldMinimize,
3131
len::Int = 1,
3232
perturbation::Real = 1e-10,
33-
) where {N_, F <: AbstractManifoldMinimize, S, T}
33+
)
3434
return 0
3535
end
3636
#
@@ -151,6 +151,10 @@ function resetData!(vdata::DFG.FunctionNodeData)
151151
error("resetData!(vdata::FunctionNodeData) is deprecated, use resetData!(state::FactorState) instead")
152152
end
153153

154+
function sampleTangent(x::ManifoldKernelDensity, p = mean(x))
155+
error("sampleTangent(x::ManifoldKernelDensity, p) should be replaced by sampleTangent(M<:AbstractManifold, x::ManifoldKernelDensity, p)")
156+
end
157+
154158
## ================================================================================================
155159
## ================================================================================================
156160

IncrementalInference/src/Factors/LinearRelative.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ LinearRelative(nm::MvNormal) = LinearRelative{length(nm.μ), typeof(nm)}(nm)
3232
function LinearRelative(nm::Union{<:BallTreeDensity, <:ManifoldKernelDensity})
3333
return LinearRelative{Ndim(nm), typeof(nm)}(nm)
3434
end
35+
LinearRelative(z) = LinearRelative{length(z)}(z)
3536

3637
getManifold(::InstanceType{LinearRelative{N}}) where {N} = getManifold(ContinuousEuclid{N})
3738

Lines changed: 46 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
2-
_defaultNamesMixtures(N::Int) = ((Symbol[Symbol("c$i") for i = 1:N])...,)
3-
41
"""
52
$(TYPEDEF)
63
@@ -12,195 +9,77 @@ Notes
129
- `N` is the number of components used to make the mixture, so two bumps from two Normal components means `N=2`.
1310
1411
DevNotes
15-
- FIXME swap API order so Mixture of distibutions works like a distribtion, see Caesar.jl #808
16-
- Should not have field mechanics.
1712
- TODO on sampling see #1099 and #1094 and #1069
1813
19-
2014
Example
21-
```juila
15+
```julia
2216
# prior factor
23-
msp = Mixture(Prior,
24-
[Normal(0,0.1), Uniform(-pi/1,pi/2)],
25-
[0.5;0.5])
17+
msp = Prior(Mixture([Normal(0,0.1), Uniform(-pi/1,pi/2)], [0.5;0.5])
2618
2719
addFactor!(fg, [:head], msp, tags=[:MAGNETOMETER;])
2820
2921
# Or relative
30-
mlr = Mixture(LinearRelative,
22+
mlr = LinearRelative(Mixture(
3123
(correlator=AliasingScalarSampler(...), naive=Normal(0.5,5), lucky=Uniform(0,10)),
32-
[0.5;0.4;0.1])
24+
[0.5;0.4;0.1]
25+
)
3326
3427
addFactor!(fg, [:x0;:x1], mlr)
3528
```
3629
"""
37-
struct Mixture{N, F <: AbstractObservation, S, T <: Tuple} <: AbstractObservation
38-
""" factor mechanics """
39-
mechanics::F
40-
components::NamedTuple{S, T}
41-
diversity::Distributions.Categorical
42-
""" dimension of factor, so range measurement would be dims=1 """
43-
dims::Int
44-
labels::Vector{Int}
30+
# we can do <: Sampleable{VF, VS}, but then all components must implement Sampleable
31+
struct Mixture{C <: NamedTuple, P <: Categorical}
32+
components::C
33+
prior::P
4534
end
4635

36+
Base.length(d::Mixture) = length(d.components[1])
37+
#FIXME next not working
38+
# Base.rand(rng::AbstractRNG, d::Mixture) = rand(rng, d.components[rand(rng, d.prior)])
39+
Base.rand(d::Mixture) = rand(d.components[rand(d.prior)])
40+
getDimension(p::Mixture) = length(p)
41+
42+
Mixture(z::NamedTuple) = Mixture(z, Categorical(length(z)))
43+
Mixture(z::NamedTuple, c::Union{<:Tuple, <:AbstractVector}) = Mixture(z, Categorical(c))
44+
4745
function Mixture(
48-
f::Type{F},
49-
z::NamedTuple{S, T},
50-
c::Distributions.DiscreteNonParametric,
51-
) where {F <: AbstractObservation, S, T}
52-
return Mixture{length(z), F, S, T}(
53-
f(LinearAlgebra.I),
54-
z,
55-
c,
56-
size(rand(z[1], 1), 1),
57-
zeros(Int, 0),
58-
)
59-
end
60-
function Mixture(
61-
f::F,
62-
z::NamedTuple{S, T},
63-
c::Distributions.DiscreteNonParametric,
64-
) where {F <: AbstractObservation, S, T}
65-
return Mixture{length(z), F, S, T}(f, z, c, size(rand(z[1], 1), 1), zeros(Int, 0))
66-
end
67-
function Mixture(
68-
f::Union{F, Type{F}},
69-
z::NamedTuple{S, T},
70-
c::AbstractVector{<:Real},
71-
) where {F <: AbstractObservation, S, T}
72-
return Mixture(f, z, Categorical([c...]))
73-
end
74-
function Mixture(
75-
f::Union{F, Type{F}},
76-
z::NamedTuple{S, T},
77-
c::NTuple{N, <:Real},
78-
) where {N, F <: AbstractObservation, S, T}
79-
return Mixture(f, z, [c...])
80-
end
81-
function Mixture(
82-
f::Union{F, Type{F}},
83-
z::Tuple,
84-
c::Union{
85-
<:Distributions.DiscreteNonParametric,
86-
<:AbstractVector{<:Real},
87-
<:NTuple{N, <:Real},
88-
},
89-
) where {F <: AbstractObservation, N}
90-
return Mixture(f, NamedTuple{_defaultNamesMixtures(length(z))}(z), c)
91-
end
92-
function Mixture(
93-
f::Union{F, Type{F}},
94-
z::AbstractVector{<:SamplableBelief},
95-
c::Union{
96-
<:Distributions.DiscreteNonParametric,
97-
<:AbstractVector{<:Real},
98-
<:NTuple{N, <:Real},
99-
},
100-
) where {F <: AbstractObservation, N}
101-
return Mixture(f, (z...,), c)
46+
z::Union{<:Tuple, <:AbstractVector},
47+
args...,
48+
)
49+
cnames = tuple((Symbol("c", i) for i = 1:length(z))...)
50+
return Mixture(NamedTuple{cnames}(z), args...)
10251
end
10352

104-
function Base.resize!(mp::Mixture, s::Int)
105-
return resize!(mp.labels, s)
53+
function Mixture(f::Type{<:AbstractObservation}, args...)
54+
return f(Mixture(args...))
10655
end
10756

108-
_lengthOrNothing(val) = length(val)
109-
_lengthOrNothing(val::Nothing) = 0
110-
111-
getManifold(m::Mixture) = getManifold(m.mechanics)
112-
113-
# TODO make in-place memory version
114-
function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
115-
#
116-
# TODO consolidate #927, case if mechanics has a special sampler
117-
# TODO slight bit of waste in computation, but easiest way to ensure special tricks in s.mechanics::F are included
118-
## example case is old FluxModelsPose2Pose2 requiring velocity
119-
# FIXME better consolidation of when to pass down .mechanics, also see #1099 and #1094 and #1069
120-
121-
cf_ = CalcFactorNormSq(
122-
cf.factor.mechanics,
123-
0,
124-
cf._legacyParams,
125-
cf._allowThreads,
126-
cf.cache,
127-
cf.fullvariables,
128-
cf.solvefor,
129-
cf.manifold,
130-
cf.measurement,
131-
nothing,
132-
)
133-
smpls = [getSample(cf_) for _ = 1:N]
134-
# smpls = Array{Float64,2}(undef,s.dims,N)
135-
#out memory should be right size first
136-
length(cf.factor.labels) != N ? resize!(cf.factor.labels, N) : nothing
137-
cf.factor.labels .= rand(cf.factor.diversity, N)
138-
M = cf.manifold
139-
140-
# mixture needs to be refactored so let's make it worse :-)
141-
if cf.factor.mechanics isa AbstractPriorObservation
142-
samplef = samplePoint
143-
elseif cf.factor.mechanics isa AbstractRelativeObservation
144-
samplef = sampleTangent
145-
end
146-
147-
for i = 1:N
148-
mixComponent = cf.factor.components[cf.factor.labels[i]]
149-
# measurements relate to the factor's manifold (either tangent vector or manifold point)
150-
setPointsMani!(smpls, samplef(M, mixComponent), i)
151-
end
152-
153-
# TODO only does first element of meas::Tuple at this stage, see #1099
154-
return smpls
57+
Base.@kwdef struct PackedMixture{C} <: PackedBelief
58+
_type::String = "IncrementalInference.PackedMixture"
59+
components::C
60+
prior::PackedCategorical
15561
end
15662

157-
function DistributedFactorGraphs.isPrior(::Type{Mixture{N, F, S, T}}) where {N, F, S, T}
158-
return F <: AbstractPriorObservation
63+
# FIXME 🦨 use JSON properly (in JSONv1 upgrade)
64+
function PackedMixture(_type::String, comp_obj::JSON3.Object, prior_obj::JSON3.Object)
65+
PackedMixture(
66+
_type,
67+
NamedTuple(map(c -> c[1]=>convert(PackedBelief, c[2]), collect(pairs(comp_obj)))),
68+
PackedCategorical(; prior_obj...),
69+
)
15970
end
16071

161-
"""
162-
$(TYPEDEF)
16372

164-
Serialization type for `Mixture`.
165-
"""
166-
Base.@kwdef mutable struct PackedMixture <: AbstractPackedObservation
167-
N::Int
168-
# store the packed type for later unpacking
169-
F_::String
170-
S::Vector{String}
171-
components::Vector{PackedBelief}
172-
diversity::PackedBelief
173-
end
174-
175-
function convert(::Type{<:PackedMixture}, obj::Mixture{N, F, S, T}) where {N, F, S, T}
176-
allcomp = PackedBelief[]
177-
for val in obj.components
178-
dtr_ = convert(PackedBelief, val)
179-
# FIXME ON FIRE, likely to be difficult for non-standard "Samplable" types -- e.g. Flux models in RoME
180-
push!(allcomp, dtr_)
181-
end
182-
if hasmethod(pack, (typeof(obj.mechanics),))
183-
pm = pack(obj.mechanics)
184-
else
185-
@warn("No pack method for mechanics type $(typeof(obj.mechanics)), using deprecated convert instead.")
186-
pm = convert(DFG.convertPackedType(obj.mechanics), obj.mechanics)
187-
end
188-
sT = string(typeof(pm))
189-
dvst = convert(PackedBelief, obj.diversity)
190-
return PackedMixture(N, sT, string.(collect(S)), allcomp, dvst)
73+
function DFG.packDistribution(m::Mixture)
74+
PackedMixture(;
75+
components = map(packDistribution, m.components),
76+
prior = PackedCategorical(; p = m.prior.p)
77+
)
19178
end
19279

193-
function convert(::Type{<:Mixture}, obj::PackedMixture)
194-
N = obj.N
195-
F1 = getfield(Main, Symbol(obj.F_))
196-
S = (Symbol.(obj.S)...,)
197-
F2 = DFG.convertStructType(F1)
198-
199-
components = map(c -> convert(SamplableBelief, c), obj.components)
200-
diversity = convert(SamplableBelief, obj.diversity)
201-
# tupcomp = (components...,)
202-
ntup = NamedTuple{S}(components) # ,typeof(tupcomp)
203-
return Mixture(F2, ntup, diversity)
80+
function DFG.unpackDistribution(pm::PackedMixture)
81+
Mixture(
82+
map(unpackDistribution, pm.components),
83+
unpackDistribution(pm.prior)
84+
)
20485
end
205-
206-
#

IncrementalInference/src/Factors/PartialPrior.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct PartialPrior{T <: SamplableBelief, P <: Tuple} <: AbstractPriorObservatio
1515
end
1616

1717
# TODO, standardize, but shows error on testPartialNH.jl
18+
#FIXME should call samplePoint(M, cf.factor.Z)
1819
getSample(cf::CalcFactor{<:PartialPrior}) = samplePoint(cf.factor.Z) # remove in favor of ManifoldSampling.jl
1920
# getManifold(pp::PartialPrior) = TranslationGroup(length(pp.partial)) # uncomment
2021

IncrementalInference/src/IncrementalInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ include("Factors/GenericMarginal.jl")
169169
include("entities/AliasScalarSampling.jl")
170170
include("entities/ExtDensities.jl") # used in BeliefTypes.jl::SamplableBeliefs
171171
include("entities/ExtFactors.jl")
172+
include("Factors/Mixture.jl")
172173
include("entities/BeliefTypes.jl")
173174

174175
include("services/HypoRecipe.jl")
@@ -218,7 +219,6 @@ include("Variables/DefaultVariables.jl")
218219

219220
# included factors, see RoME.jl for more examples
220221
include("Factors/GenericFunctions.jl")
221-
include("Factors/Mixture.jl")
222222
include("Factors/DefaultPrior.jl")
223223
include("Factors/LinearRelative.jl")
224224
include("Factors/EuclidDistance.jl")

0 commit comments

Comments
 (0)