Skip to content

Commit 50ce223

Browse files
committed
keep solve, refac CalcFactor
1 parent 4281b25 commit 50ce223

1 file changed

Lines changed: 18 additions & 7 deletions

File tree

IncrementalInference/src/services/CalcFactor.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ function _createCCW(
376376
_blockRecursion::Bool = false,
377377
attemptGradients::Bool = true,
378378
userCache::CT = nothing,
379+
keepCalcFactor::Bool = false,
379380
) where {T <: AbstractFactor, CT}
380381
#
381382
if length(Xi) !== 0
@@ -394,7 +395,7 @@ function _createCCW(
394395
fullvariables = tuple(Xi...) # convert(Vector{VariableCompute}, Xi)
395396
# create a temporary CalcFactor object for extracting the first sample
396397

397-
_cf = CalcFactorNormSq(
398+
_cf_nt = CalcFactorNormSq(
398399
usrfnc,
399400
1,
400401
_varValsAll,
@@ -408,12 +409,12 @@ function _createCCW(
408409
)
409410

410411
# get a measurement sample
411-
meas_single = sampleFactor(_cf, 1)[1]
412+
meas_single = sampleFactor(_cf_nt, 1)[1]
412413
elT = typeof(meas_single)
413414
#TODO preallocate measurement?
414415
measurement = Vector{elT}()
415416

416-
#FIXME chicken and egg problem for getting measurement type, so creating twice.
417+
# NOTE chicken and egg problem for getting measurement type, so creating twice.
417418
_cf = CalcFactorNormSq(
418419
usrfnc,
419420
1,
@@ -427,6 +428,11 @@ function _createCCW(
427428
nothing,
428429
)
429430

431+
keepCalcFactor_ = if keepCalcFactor
432+
Channel{CalcFactor}(1024)
433+
else
434+
nothing
435+
end
430436

431437
# partialDims are sensitive to both which solvefor variable index and whether the factor is partial
432438
partial = hasfield(T, :partial) # FIXME, use isPartial function instead
@@ -441,6 +447,7 @@ function _createCCW(
441447

442448
# as per struct CommonConvWrapper
443449
_gradients = if attemptGradients
450+
# FIXME update to proper AD tools
444451
attemptGradientPrep(
445452
varTypes,
446453
usrfnc,
@@ -477,6 +484,7 @@ function _createCCW(
477484
),
478485
measurement,
479486
_gradients,
487+
keepCalcFactor = keepCalcFactor_
480488
)
481489
end
482490

@@ -485,14 +493,15 @@ function updateMeasurement!(
485493
N::Int=1;
486494
measurement::AbstractVector = Vector{Tuple{}}(),
487495
needFreshMeasurements::Bool=true,
488-
_allowThreads::Bool = true
496+
_allowThreads::Bool = true,
497+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
489498
)
490499
# FIXME do not divert Mixture for sampling
491500

492501
# option to disable fresh samples or user provided
493502
if needFreshMeasurements
494503
# TODO this is only one thread, make this a for loop for multithreaded sampling
495-
sampleFactor!(ccwl, N; _allowThreads)
504+
sampleFactor!(ccwl, N; _allowThreads, keepCalcFactor)
496505
elseif 0 < length(measurement)
497506
resize!(ccwl.measurement, length(measurement))
498507
ccwl.measurement[:] = measurement
@@ -520,6 +529,7 @@ function _beforeSolveCCW!(
520529
measurement = Vector{Tuple{}}(),
521530
needFreshMeasurements::Bool = true,
522531
solveKey::Symbol = :default,
532+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
523533
) where {F <: AbstractFactor} # F might be Mixture
524534
#
525535
if length(variables) !== 0
@@ -570,7 +580,7 @@ function _beforeSolveCCW!(
570580
_setCCWDecisionDimsConv!(ccwl, xDim)
571581

572582
# FIXME do not divert Mixture for sampling
573-
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true)
583+
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true, keepCalcFactor)
574584

575585
# used in ccw functor for AbstractRelativeMinimize
576586
resize!(ccwl.res, _getZDim(ccwl))
@@ -591,6 +601,7 @@ function _beforeSolveCCW!(
591601
measurement = Vector{Tuple{}}(),
592602
needFreshMeasurements::Bool = true,
593603
solveKey::Symbol = :default,
604+
keepCalcFactor::Union{Nothing, <:Channel} = nothing,
594605
) where {F <: AbstractFactor} # F might be Mixture
595606
# FIXME, NEEDS TO BE CLEANED UP AND WORK ON MANIFOLDS PROPER
596607

@@ -606,7 +617,7 @@ function _beforeSolveCCW!(
606617

607618
# FIXME do not divert Mixture for sampling
608619
# update ccwl.measurement values
609-
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true)
620+
updateMeasurement!(ccwl, maxlen; needFreshMeasurements, measurement, _allowThreads=true, keepCalcFactor)
610621

611622
return maxlen
612623
end

0 commit comments

Comments
 (0)