Skip to content

Commit 172aafc

Browse files
committed
Optimized implementation for structured matrices V2
1 parent d318be5 commit 172aafc

12 files changed

Lines changed: 416 additions & 39 deletions

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313

1414
[weakdeps]
15+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
16+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
17+
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1518
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1619
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
1720
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
@@ -21,6 +24,8 @@ MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
2124
cuSPARSE = "b26da814-b3bc-49ef-b0ee-c816305aa060"
2225

2326
[extensions]
27+
SparseMatrixColoringsBandedMatricesExt = "BandedMatrices"
28+
SparseMatrixColoringsBlockBandedMatricesExt = ["BlockArrays", "BlockBandedMatrices"]
2429
SparseMatrixColoringsCUDAExt = ["CUDA", "cuSPARSE"]
2530
SparseMatrixColoringsCliqueTreesExt = "CliqueTrees"
2631
SparseMatrixColoringsColorsExt = "Colors"
@@ -29,6 +34,9 @@ SparseMatrixColoringsJuMPExt = ["JuMP", "MathOptInterface"]
2934

3035
[compat]
3136
ADTypes = "1.2.1"
37+
BandedMatrices = "1.9.4"
38+
BlockArrays = "1.6.3"
39+
BlockBandedMatrices = "0.13.1"
3240
CUDA = "6.0.0"
3341
CliqueTrees = "1"
3442
Colors = "0.12.11, 0.13"

docs/make.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ using Documenter
22
using DocumenterInterLinks
33
using SparseMatrixColorings
44

5-
links = InterLinks("ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/")
5+
links = InterLinks(
6+
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
7+
"BandedMatrices" => "https://julialinearalgebra.github.io/BandedMatrices.jl/stable/",
8+
)
69

710
cp(joinpath(@__DIR__, "..", "README.md"), joinpath(@__DIR__, "src", "index.md"); force=true)
811

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ ColoringProblem
2525
GreedyColoringAlgorithm
2626
ConstantColoringAlgorithm
2727
OptimalColoringAlgorithm
28+
StructuredColoringAlgorithm
2829
```
2930

3031
## Result analysis

docs/src/dev.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,9 @@ SparseMatrixColorings.what_fig_61
8080
SparseMatrixColorings.efficient_fig_1
8181
SparseMatrixColorings.efficient_fig_4
8282
```
83+
84+
## Misc
85+
86+
```@docs
87+
SparseMatrixColorings.cycle_range
88+
```
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module SparseMatrixColoringsBandedMatricesExt
2+
3+
using BandedMatrices: BandedMatrix, bandrange, bandwidths, colrange, rowrange
4+
using SparseMatrixColorings:
5+
BipartiteGraph,
6+
ColoringProblem,
7+
ColumnColoringResult,
8+
StructuredColoringAlgorithm,
9+
RowColoringResult,
10+
column_colors,
11+
cycle_range,
12+
row_colors
13+
import SparseMatrixColorings as SMC
14+
15+
#=
16+
This code is partly taken from ArrayInterface.jl and FiniteDiff.jl
17+
https://github.com/JuliaArrays/ArrayInterface.jl
18+
https://github.com/JuliaDiff/FiniteDiff.jl
19+
=#
20+
21+
function SMC.coloring(
22+
A::BandedMatrix,
23+
::ColoringProblem{:nonsymmetric,:column},
24+
::StructuredColoringAlgorithm;
25+
kwargs...,
26+
)
27+
width = length(bandrange(A))
28+
color = cycle_range(width, size(A, 2))
29+
bg = BipartiteGraph(A)
30+
return ColumnColoringResult(A, bg, color)
31+
end
32+
33+
function SMC.coloring(
34+
A::BandedMatrix,
35+
::ColoringProblem{:nonsymmetric,:row},
36+
::StructuredColoringAlgorithm;
37+
kwargs...,
38+
)
39+
width = length(bandrange(A))
40+
color = cycle_range(width, size(A, 1))
41+
bg = BipartiteGraph(A)
42+
return RowColoringResult(A, bg, color)
43+
end
44+
45+
end
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
module SparseMatrixColoringsBlockBandedMatricesExt
2+
3+
using BlockArrays: blockaxes, blockfirsts, blocklasts, blocksize, blocklengths
4+
using BlockBandedMatrices:
5+
BandedBlockBandedMatrix,
6+
BlockBandedMatrix,
7+
blockbandrange,
8+
blockbandwidths,
9+
blocklengths,
10+
blocksize,
11+
subblockbandwidths
12+
using SparseMatrixColorings:
13+
BipartiteGraph,
14+
ColoringProblem,
15+
ColumnColoringResult,
16+
StructuredColoringAlgorithm,
17+
RowColoringResult,
18+
column_colors,
19+
cycle_range,
20+
row_colors
21+
import SparseMatrixColorings as SMC
22+
23+
#=
24+
This code is partly taken from ArrayInterface.jl and FiniteDiff.jl
25+
https://github.com/JuliaArrays/ArrayInterface.jl
26+
https://github.com/JuliaDiff/FiniteDiff.jl
27+
=#
28+
29+
function subblockbandrange(A::BandedBlockBandedMatrix)
30+
u, l = subblockbandwidths(A)
31+
return (-l):u
32+
end
33+
34+
function blockbanded_coloring(
35+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix}, dim::Integer
36+
)
37+
# consider blocks of columns or rows (let's call them vertices) depending on `dim`
38+
nb_blocks = blocksize(A, dim)
39+
nb_in_block = blocklengths(axes(A, dim))
40+
first_in_block = blockfirsts(axes(A, dim))
41+
last_in_block = blocklasts(axes(A, dim))
42+
color = zeros(Int, size(A, dim))
43+
44+
# give a macroscopic color to each block, so that 2 blocks with the same macro color are orthogonal
45+
# same idea as for BandedMatrices
46+
nb_macrocolors = length(blockbandrange(A))
47+
macrocolor = cycle_range(nb_macrocolors, nb_blocks)
48+
49+
width = if A isa BandedBlockBandedMatrix
50+
# vertices within a block are colored cleverly using bands
51+
length(subblockbandrange(A))
52+
else
53+
# vertices within a block are colored naively with distinct micro colors (~ infinite band width)
54+
typemax(Int)
55+
end
56+
57+
# for each macroscopic color, count how many microscopic colors will be needed
58+
nb_colors_in_macrocolor = zeros(Int, nb_macrocolors)
59+
for mc in 1:nb_macrocolors
60+
largest_nb_in_macrocolor = maximum(nb_in_block[mc:nb_macrocolors:nb_blocks]; init=0)
61+
nb_colors_in_macrocolor[mc] = min(width, largest_nb_in_macrocolor)
62+
end
63+
color_shift_in_macrocolor = vcat(0, cumsum(nb_colors_in_macrocolor)[1:(end - 1)])
64+
65+
# assign a microscopic color to each column as a function of its macroscopic color and its position within the block
66+
for b in 1:nb_blocks
67+
block_color = cycle_range(width, nb_in_block[b])
68+
shift = color_shift_in_macrocolor[macrocolor[b]]
69+
color[first_in_block[b]:last_in_block[b]] .= shift .+ block_color
70+
end
71+
72+
return color
73+
end
74+
75+
function SMC.coloring(
76+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix},
77+
::ColoringProblem{:nonsymmetric,:column},
78+
::StructuredColoringAlgorithm;
79+
kwargs...,
80+
)
81+
color = blockbanded_coloring(A, 2)
82+
bg = BipartiteGraph(A)
83+
return ColumnColoringResult(A, bg, color)
84+
end
85+
86+
function SMC.coloring(
87+
A::Union{BlockBandedMatrix,BandedBlockBandedMatrix},
88+
::ColoringProblem{:nonsymmetric,:row},
89+
::StructuredColoringAlgorithm;
90+
kwargs...,
91+
)
92+
color = blockbanded_coloring(A, 1)
93+
bg = BipartiteGraph(A)
94+
return RowColoringResult(A, bg, color)
95+
end
96+
97+
end

ext/SparseMatrixColoringsCUDAExt.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
11
module SparseMatrixColoringsCUDAExt
2-
2+
using LinearAlgebra
33
import SparseMatrixColorings as SMC
44
using SparseArrays: SparseMatrixCSC, rowvals, nnz, nzrange
5-
using CUDA: CuVector, CuMatrix
5+
using CUDA: CuArray, CuVector, CuMatrix
66
using cuSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR
77

8+
## Basic support for GPU sparsity pattern stuff
9+
10+
SMC.SparsityPatternCSC(A::CuSparseMatrixCSC) = SMC.SparsityPatternCSC(first(A.dims), last(A.dims), A.colPtr, A.rowVal)
11+
12+
for R in (:Diagonal, :Bidiagonal, :Tridiagonal)
13+
@eval function SMC.BipartiteGraph(A::$R{T, <:CuArray}; symmetric_pattern::Bool=false) where {T}
14+
return SMC.BipartiteGraph(CuSparseMatrixCSC(A); symmetric_pattern)
15+
end
16+
end
17+
18+
function SMC.BipartiteGraph(A::CuSparseMatrixCSC; symmetric_pattern::Bool=false)
19+
S2 = SMC.SparsityPatternCSC(A)
20+
if symmetric_pattern
21+
checksquare(A) # proxy for checking full symmetry
22+
S1 = S2
23+
else
24+
S1 = transpose(S2) # rows to columns
25+
end
26+
return SMC.BipartiteGraph(S1, S2)
27+
end
28+
829
## CSC Result
930

1031
function SMC.ColumnColoringResult(

src/SparseMatrixColorings.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ using Base.Iterators: Iterators
1414
using DocStringExtensions: README, EXPORTS, SIGNATURES, TYPEDEF, TYPEDFIELDS
1515
using LinearAlgebra:
1616
Adjoint,
17+
Bidiagonal,
1718
Diagonal,
1819
Hermitian,
1920
LowerTriangular,
2021
Symmetric,
2122
Transpose,
23+
Tridiagonal,
2224
UpperTriangular,
2325
adjoint,
2426
checksquare,
@@ -54,6 +56,7 @@ include("interface.jl")
5456
include("constant.jl")
5557
include("adtypes.jl")
5658
include("decompression.jl")
59+
include("structured.jl")
5760
include("check.jl")
5861
include("examples.jl")
5962
include("show_colors.jl")
@@ -65,7 +68,7 @@ export NaturalOrder, RandomOrder, LargestFirst
6568
export DynamicDegreeBasedOrder, SmallestLast, IncidenceDegree, DynamicLargestFirst
6669
export PerfectEliminationOrder
6770
export ColoringProblem, GreedyColoringAlgorithm, AbstractColoringResult
68-
export ConstantColoringAlgorithm
71+
export ConstantColoringAlgorithm, StructuredColoringAlgorithm
6972
export OptimalColoringAlgorithm
7073
export coloring, fast_coloring
7174
export column_colors, row_colors, ncolors

src/graph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ Copied from `SparseMatrixCSC`:
1717
struct SparsityPatternCSC{Ti<:Integer} <: AbstractMatrix{Bool}
1818
m::Int
1919
n::Int
20-
colptr::Vector{Ti}
21-
rowval::Vector{Ti}
20+
colptr::AbstractVector{Ti}
21+
rowval::AbstractVector{Ti}
2222
end
2323

2424
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)

0 commit comments

Comments
 (0)