Skip to content

Commit 82d1604

Browse files
committed
move more more to Extension
1 parent 85fff3f commit 82d1604

2 files changed

Lines changed: 12 additions & 3 deletions

File tree

ext/FiniteDiffStaticArraysExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ else
99
end
1010
FiniteDiff._mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
1111
FiniteDiff.setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
12+
FiniteDiff.__Symmetric(x::SMatrix) = Symmetric(SArray(H))
1213

1314
end #module

src/hessians.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,16 @@ struct HessianCache{T,fdtype,inplace}
55
xmm::T
66
end
77

8+
#used to dispatch on StaticArrays
89
_hessian_inplace(::Type{T}) where T = Val(ArrayInterface.ismutable(T))
910
_hessian_inplace(x) = _hessian_inplace(typeof(x))
11+
__Symmetric(x) = Symmetric(x)
12+
13+
function mutable_zeromatrix(x)
14+
A = ArrayInterface.zeromatrix(x)
15+
ArrayInterface.ismutable(A) ? A : Base.copymutable(A)
16+
end
17+
1018

1119
function HessianCache(xpp,xpm,xmp,xmm,
1220
fdtype=Val(:hcentral),
@@ -39,10 +47,10 @@ function finite_difference_hessian(
3947
cache::HessianCache{T,fdtype,inplace};
4048
relstep=default_relstep(fdtype, eltype(x)),
4149
absstep=relstep) where {T,fdtype,inplace}
42-
_H = false .* x .* x'
43-
_H isa SMatrix ? H = MArray(_H) : H = _H
50+
H = mutable_zeromatrix(x)
4451
finite_difference_hessian!(H, f, x, cache; relstep=relstep, absstep=absstep)
45-
Symmetric(_H isa SMatrix ? SArray(H) : H)
52+
__Symmetric(H)
53+
Symmetric(H isa SMatrix ? SArray(H) : H)
4654
end
4755

4856
function finite_difference_hessian!(H,f,

0 commit comments

Comments
 (0)