@@ -98,7 +98,11 @@ function _solveLambdaNumeric(
9898 r = if islen1
9999 Optim. optimize ((x) -> (residual .= objResX (x); sum (residual .^ 2 )), u0, Optim. BFGS ())
100100 else
101- Optim. optimize ((x) -> (residual .= objResX (x); sum (residual .^ 2 )), u0)
101+ Optim. optimize ((x) -> (residual .= objResX (x); sum (residual .^ 2 )), u0, Optim. Options (;iterations= 1000 ))
102+ end
103+
104+ if ! Optim. converged (r)
105+ @warn " Optim did not converge:" r maxlog= 10
102106 end
103107
104108 #
@@ -113,6 +117,20 @@ function _solveLambdaNumeric(
113117 variableType:: InferenceVariable ,
114118 islen1:: Bool = false ,
115119) where {N_, F <: AbstractManifoldMinimize , S, T}
120+
121+ return _solveCCWNumeric_test_SA (fcttype, objResX, residual, u0, variableType, islen1)
122+ # return _solveLambdaNumeric_test_optim_manifold(fcttype, objResX, residual, u0, variableType, islen1)
123+
124+ end
125+
126+ function _solveLambdaNumeric_original (
127+ fcttype:: Union{F, <:Mixture{N_, F, S, T}} ,
128+ objResX:: Function ,
129+ residual:: AbstractVector{<:Real} ,
130+ u0,# ::AbstractVector{<:Real},
131+ variableType:: InferenceVariable ,
132+ islen1:: Bool = false ,
133+ ) where {N_, F <: AbstractManifoldMinimize , S, T}
116134 #
117135 M = getManifold (variableType) # fcttype.M
118136 # the variable is a manifold point, we are working on the tangent plane in optim for now.
@@ -153,6 +171,93 @@ function _solveLambdaNumeric(
153171 return exp (M, ϵ, hat (M, ϵ, r. minimizer))
154172end
155173
174+ # 1.355700 seconds (11.78 M allocations: 557.677 MiB, 6.96% gc time)
175+ function _solveCCWNumeric_test_SA (
176+ fcttype:: Union{F, <:Mixture{N_, F, S, T}} ,
177+ objResX:: Function ,
178+ residual:: AbstractVector{<:Real} ,
179+ u0,# ::AbstractVector{<:Real},
180+ variableType:: InferenceVariable ,
181+ islen1:: Bool = false ,
182+ ) where {N_, F <: AbstractManifoldMinimize , S, T}
183+ #
184+ M = getManifold (variableType) # fcttype.M
185+ # the variable is a manifold point, we are working on the tangent plane in optim for now.
186+ #
187+ # TODO this is not general to all manifolds, should work for lie groups.
188+ # ϵ = identity_element(M, u0)
189+ ϵ = getPointIdentity (variableType)
190+
191+ X0c = zero (MVector{getDimension (M),Float64})
192+ X0c .= vee (M, u0, log (M, ϵ, u0))
193+
194+ # TODO check performance
195+ function cost (Xc)
196+ X = hat (M, ϵ, Xc)
197+ p = exp (M, ϵ, X)
198+ residual = objResX (p)
199+ return sum (residual .^ 2 )
200+ end
201+
202+ alg = islen1 ? Optim. BFGS () : Optim. NelderMead ()
203+
204+ r = Optim. optimize (cost, X0c, alg)
205+ if ! Optim. converged (r)
206+ # TODO find good way for a solve to store diagnostics about number of failed converges etc.
207+ @warn " Optim did not converge (maxlog=10):" r maxlog= 10
208+ end
209+ return exp (M, ϵ, hat (M, ϵ, r. minimizer))
210+ end
211+
212+ # sloooowwww and does not always converge, unusable slow with gradient
213+ # NelderMead 5.513693 seconds (38.60 M allocations: 1.613 GiB, 6.62% gc time)
214+ function _solveLambdaNumeric_test_optim_manifold (
215+ fcttype:: Union{F, <:Mixture{N_, F, S, T}} ,
216+ objResX:: Function ,
217+ residual:: AbstractVector{<:Real} ,
218+ u0,# ::AbstractVector{<:Real},
219+ variableType:: InferenceVariable ,
220+ islen1:: Bool = false ,
221+ ) where {N_, F <: AbstractManifoldMinimize , S, T}
222+ #
223+ M = getManifold (variableType) # fcttype.M
224+ # the variable is a manifold point, we are working on the tangent plane in optim for now.
225+ #
226+ # TODO this is not general to all manifolds, should work for lie groups.
227+ ϵ = getPointIdentity (variableType)
228+
229+ function cost (p)
230+ residual = objResX (p)
231+ return sum (residual .^ 2 )
232+ end
233+
234+ alg = islen1 ? Optim. BFGS (;manifold= ManifoldWrapper (M)) : Optim. NelderMead (;manifold= ManifoldWrapper (M))
235+ # alg = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
236+ # alg = Optim.BFGS(; manifold=ManifoldWrapper(M))
237+
238+ # r_backend = ManifoldDiff.TangentDiffBackend(
239+ # ManifoldDiff.FiniteDifferencesBackend()
240+ # )
241+
242+ # ## finitediff gradient
243+ # function costgrad_FD!(X,p)
244+ # copyto!(X, ManifoldDiff.gradient(M, cost, p, r_backend))
245+ # X
246+ # end
247+
248+ u0_m = allocate (M, u0)
249+ u0_m .= u0
250+ # r = Optim.optimize(cost, costgrad_FD!, u0_m, alg)
251+ r = Optim. optimize (cost, u0_m, alg)
252+
253+ if ! Optim. converged (r)
254+ @warn " Optim did not converge:" r maxlog= 10
255+ end
256+
257+ return r. minimizer
258+ # return exp(M, ϵ, hat(M, ϵ, r.minimizer))
259+ end
260+
156261# TODO Consolidate with _solveLambdaNumeric, see #1374
157262function _solveLambdaNumericMeas (
158263 fcttype:: Union{F, <:Mixture{N_, F, S, T}} ,
@@ -317,15 +422,25 @@ function _solveCCWNumeric!(
317422 # islen1 = length(cpt_.X[:, smpid]) == 1 || ccwl.partial
318423
319424 # build the pre-objective function for this sample's hypothesis selection
320- unrollHypo!, target = _buildCalcFactorLambdaSample (ccwl, smpid; _slack = _slack)
425+ unrollHypo!, target = _buildCalcFactorLambdaSample (
426+ ccwl,
427+ smpid,
428+ view (ccwl. varValsAll[ccwl. varidx[]], smpid);
429+ _slack = _slack
430+ )
321431
322432 # broadcast updates original view memory location
323433 # # using CalcFactor legacy path inside (::CalcFactor)
324- _hypoObj = (x) -> (target .= x; unrollHypo! ())
434+
435+ # _hypoObj = (x) -> (target[] = x; unrollHypo!())
436+ function _hypoObj (x)
437+ target[] = x
438+ return unrollHypo! ()
439+ end
325440
326441 # TODO small off-manifold perturbation is a numerical workaround only, make on-manifold requires RoME.jl #244
327442 # use all element dimensions : ==> 1:ccwl.xDim
328- target .+ = _perturbIfNecessary (getFactorType (ccwl), length (target), perturb)
443+ # target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
329444
330445 sfidx = ccwl. varidx[]
331446 # do the parameter search over defined decision variables using Minimization
@@ -345,7 +460,7 @@ function _solveCCWNumeric!(
345460 end
346461
347462 # insert result back at the correct variable element location
348- ccwl. varValsAll[sfidx][smpid][ccwl. partialDims] . = retval
463+ copyto! ( ccwl. varValsAll[sfidx][smpid][ccwl. partialDims], retval)
349464
350465 return nothing
351466end
0 commit comments