Skip to content
13 changes: 7 additions & 6 deletions src/SparseVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ using LinearAlgebra
using PrecompileTools

include("sparsearray.jl")
include("slice.jl")
include("broadcast.jl")
include("dictionaries.jl")
include("indexedarray.jl")
include("tables.jl")

export SparseArray
export IndexedVarArray
export SparseArraySlice
export slice
export insertvar!
export unsafe_insertvar!
export SafeInsert, UnsafeInsert
export set_cache_cutoff!

@setup_workload begin
# Putting some things in `setup` can reduce the size of the
Expand All @@ -29,19 +34,15 @@ export SafeInsert, UnsafeInsert
# all calls in this block will be precompiled, regardless of whether
# they belong to your package or not (on Julia 1.8 and higher)

@variable(
m,
x[r = rs, i = is, st = sts, sy = sys];
container = IndexedVarArray
)
@variable(m, x[r=rs, i=is, st=sts, sy=sys]; container = IndexedVarArray)
for r in rs, i in is, st in sts, sy in sys
insertvar!(x, r, i, st, sy)
unsafe_insertvar!(x, r, i, st, sy)
end
x[:, 1, :, :]
x[10, :, :, :]
x[1, :, :, :a]
@variable(m, y[i = rs, j = rs, k = rs]; container = IndexedVarArray)
@variable(m, y[i=rs, j=rs, k=rs]; container = IndexedVarArray)
for i in rs, j in rs, k in rs
insertvar!(y, i, j, k)
end
Expand Down
130 changes: 130 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# ------------------------------------------------------------------------------
# Broadcasting over AbstractSparseArray
# Follows the pattern of JuMP.Containers.SparseAxisArray.
# The result of any broadcast is always a plain SparseArray.
# ------------------------------------------------------------------------------

"""
SparseBroadcastStyle{N,K} <: Broadcast.BroadcastStyle

Broadcasting style for all `AbstractSparseArray` subtypes. `N` is the key
dimensionality and `K` is the key tuple type. All broadcast results are
materialised as `SparseArray`.
"""
struct SparseBroadcastStyle{N,K} <: Broadcast.BroadcastStyle end

function Base.BroadcastStyle(::Type{SA}) where {SA<:AbstractSparseArray}
return SparseBroadcastStyle{ndims(SA),_keytype(SA)}()
end

# Disallow mixing with other array types.
function Base.BroadcastStyle(::SparseBroadcastStyle, ::Base.BroadcastStyle)
return throw(
ArgumentError(
"Cannot broadcast a SparseArray with another array of a different type",
),
)
end

# Scalar (0-d) broadcasting is allowed.
function Base.BroadcastStyle(
style::SparseBroadcastStyle,
::Base.Broadcast.DefaultArrayStyle{0},
)
return style
end

# Fix ambiguity with Unknown.
function Base.BroadcastStyle(::SparseBroadcastStyle, ::Base.Broadcast.Unknown)
return throw(
ArgumentError(
"Cannot broadcast a SparseArray with an unknown broadcast style",
),
)
end

# Bypass the default instantiate which calls axes().
function Base.Broadcast.instantiate(
bc::Base.Broadcast.Broadcasted{<:SparseBroadcastStyle},
)
return bc
end

# ── Internal helpers ──────────────────────────────────────────────────────────

_sparse_getindex(x::AbstractSparseArray, key) = x[key]
_sparse_getindex(x::Any, ::Any) = x
_sparse_getindex(x::Ref, ::Any) = x[]

function _sparse_getindex(
bc::Base.Broadcast.Broadcasted{<:SparseBroadcastStyle},
key,
)
return bc.f(_sparse_get_args(bc.args, key)...)
end

function _sparse_get_args(args::Tuple, key)
return (
_sparse_getindex(first(args), key),
_sparse_get_args(Base.tail(args), key)...,
)
end
_sparse_get_args(::Tuple{}, ::Any) = ()

function _sparse_check_same_keys(ref_keys, x::AbstractSparseArray, args...)
if length(ref_keys) != length(x) || any(k -> !haskey(x, k), ref_keys)
throw(
ArgumentError(
"Cannot broadcast SparseArrays with different indices",
),
)
end
return _sparse_check_same_keys(ref_keys, args...)
end

function _sparse_check_same_keys(ref_keys, ::Any, args...)
return _sparse_check_same_keys(ref_keys, args...)
end
_sparse_check_same_keys(::Any) = nothing

function _sparse_indices(
bc::Base.Broadcast.Broadcasted{<:SparseBroadcastStyle},
rest...,
)
return _sparse_indices(bc.args..., rest...)
end

function _sparse_indices(x::AbstractSparseArray, rest...)
ks = collect(keys(x))
_sparse_check_same_keys(ks, rest...)
return ks
end

_sparse_indices(::Any, rest...) = _sparse_indices(rest...)

# ── Materialise ───────────────────────────────────────────────────────────────

function Base.copy(
bc::Base.Broadcast.Broadcasted{SparseBroadcastStyle{N,K}},
) where {N,K}
indices = _sparse_indices(bc)
isempty(indices) && return SparseArray(Dictionary{K,Any}())
vals = [_sparse_getindex(bc, k) for k in indices]
return SparseArray(Dictionary(indices, vals))
end

function Base.Broadcast.broadcast_preserving_zero_d(
f,
A::AbstractSparseArray,
As...,
)
return broadcast(f, A, As...)
end
function Base.Broadcast.broadcast_preserving_zero_d(
f,
x,
A::AbstractSparseArray,
As...,
)
return broadcast(f, x, A, As...)
end
26 changes: 8 additions & 18 deletions src/dictionaries.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
make_filter_fun(c, pos)

Return function to use for filtering depending on the type and value of `c`
Return function to use for filtering depending on the type and value of `c`
to apply at position `pos`
"""
make_filter_fun(c, pos) = x -> x[pos] == c
Expand Down Expand Up @@ -38,7 +38,7 @@ function indices_fun(some_tuple)
end

"""
_select_rowwise(a, pattern)
_select_rowwise(a, pattern)

Filter iterable data a by tuple `pattern` by row (slow)
"""
Expand All @@ -47,7 +47,7 @@ function _select_rowwise(a, pattern)
end

"""
_select_colwise(a, pattern)
_select_colwise(a, pattern)

Filter iterable data a by tuple `pattern` by column (recursively)
"""
Expand All @@ -59,7 +59,7 @@ end
_select_gen(a, pattern)

Filter iterable data `a` by tuple `pattern` by row, using generated function for speed.
See more straight-forward implementations `_select_rowwise` and `_select_colwise` for reference.
See more straight-forward implementations `_select_rowwise` and `_select_colwise` for reference.
"""
function _select_gen(a, pattern)
return filter(x -> _select_generated(pattern, x), a)
Expand All @@ -68,7 +68,7 @@ end
"""
_select_gen_perm(a, pattern, perm)
Filter iterable `data` byt tuple `pattern` by row using generated function that permutes the sequence of
evaluation by the permutation tuple `perm` for improved control as this can give performance advantages,
evaluation by the permutation tuple `perm` for improved control as this can give performance advantages,
depending on the uniqueness of the search pattern and cost of function evaluation.

## Example
Expand Down Expand Up @@ -226,33 +226,23 @@ Works on types because it is used in generated function
"""
isfixed(t) = true
isfixed(::Type{T} where {T<:Function}) = false
isfixed(::Type{T} where {T<:UnitRange}) = false
iscolon(t) = false
iscolon(::Type{T} where {T<:Colon}) = true
isfixed(::Type{T} where {T<:AbstractRange}) = false

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Colon() is a Function so there is no problem

Comment on lines 228 to +229

@generated function _getindex(
sa::AbstractSparseArray{T,N},
tpl::Tuple,
) where {T,N}
lookup = true
slice = true
for t in fieldtypes(tpl)
if !isfixed(t)
lookup = false
if !iscolon(t)
slice = false
end
end
end

if lookup
return :(get(_data(sa), tpl, zero(T)))
elseif !slice
return :(retval = select(_data(sa), tpl);
length(retval) > 0 ? retval : zero(T))
else # Return selection or zero if empty to avoid reduction of empty iterate
return :(retval = _select_var(sa, tpl);
length(retval) > 0 ? retval : zero(T))
else
return :(_make_slice(sa, tpl))
end
end

Expand Down
Loading
Loading