Skip to content

Commit 0ef483f

Browse files
committed
move StaticArrays to extensions
1 parent b41a859 commit 0ef483f

4 files changed

Lines changed: 26 additions & 8 deletions

File tree

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1313
[weakdeps]
1414
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1515
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
16+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617

1718
[extensions]
1819
FiniteDiffBandedMatricesExt = "BandedMatrices"
1920
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
21+
FiniteDiffStaticArraysExt = "StaticArrays"
2022

2123
[compat]
2224
ArrayInterface = "7"
@@ -33,4 +35,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3335
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3436

3537
[targets]
36-
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets"]
38+
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "StaticArrays"]

ext/FiniteDiffStaticArraysExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module FiniteDiffStaticArraysExt
2+
3+
if isdefined(Base, :get_extension)
4+
using FiniteDiff: FiniteDiff, ArrayInterface
5+
using StaticArrays
6+
else
7+
using ..FiniteDiff: FiniteDiff, ArrayInterface
8+
using ..StaticArrays
9+
end
10+
FiniteDiff._mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
11+
FiniteDiff.setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
12+
13+
end #module

src/FiniteDiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
module FiniteDiff
22

3-
using LinearAlgebra, SparseArrays, StaticArrays, ArrayInterface, Requires
3+
using LinearAlgebra, SparseArrays, ArrayInterface, Requires
44

55
import Base: resize!
66

77
_vec(x) = vec(x)
88
_vec(x::Number) = x
99

1010
_mat(x::AbstractMatrix) = x
11-
_mat(x::StaticVector) = reshape(x, (axes(x, 1), SOneTo(1)))
1211
_mat(x::AbstractVector) = reshape(x, (axes(x, 1), Base.OneTo(1)))
1312

1413
# Setindex overloads without piracy
1514
setindex(x...) = Base.setindex(x...)
16-
setindex(x::StaticArray, v, i::Int...) = StaticArrays.setindex(x, v, i...)
1715

1816
function setindex(x::AbstractArray, v, i...)
1917
_x = Base.copymutable(x)
@@ -39,6 +37,8 @@ include("jacobians.jl")
3937
include("hessians.jl")
4038

4139
if !isdefined(Base,:get_extension)
40+
using StaticArrays
41+
include("../ext/FiniteDiffStaticArraysExt.jl")
4242
using Requires
4343
function __init__()
4444
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin

src/hessians.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@ struct HessianCache{T,fdtype,inplace}
55
xmm::T
66
end
77

8+
_hessian_inplace(::Type{T}) where T = Val(ArrayInterface.ismutable(T))
9+
_hessian_inplace(x) = _hessian_inplace(typeof(x))
10+
811
function HessianCache(xpp,xpm,xmp,xmm,
912
fdtype=Val(:hcentral),
10-
inplace = x isa StaticArray ? Val(false) : Val(true))
13+
inplace = _hessian_inplace(x))
1114
fdtype isa Type && (fdtype = fdtype())
1215
inplace isa Type && (inplace = inplace())
1316
HessianCache{typeof(xpp),fdtype,inplace}(xpp,xpm,xmp,xmm)
1417
end
1518

1619
function HessianCache(x, fdtype=Val(:hcentral),
17-
inplace = x isa StaticArray ? Val(false) : Val(true))
20+
inplace = _hessian_inplace(x))
1821
cx = copy(x)
1922
fdtype isa Type && (fdtype = fdtype())
2023
inplace isa Type && (inplace = inplace())
@@ -23,7 +26,7 @@ end
2326

2427
function finite_difference_hessian(f, x,
2528
fdtype = Val(:hcentral),
26-
inplace = x isa StaticArray ? Val(false) : Val(true);
29+
inplace = _hessian_inplace(x);
2730
relstep = default_relstep(fdtype, eltype(x)),
2831
absstep = relstep)
2932

@@ -45,7 +48,7 @@ end
4548
function finite_difference_hessian!(H,f,
4649
x,
4750
fdtype = Val(:hcentral),
48-
inplace = x isa StaticArray ? Val(false) : Val(true);
51+
inplace = _hessian_inplace(x);
4952
relstep=default_relstep(fdtype, eltype(x)),
5053
absstep=relstep)
5154

0 commit comments

Comments
 (0)