Skip to content

Commit 7730590

Browse files
committed
ProjectTo for symmetric sparse matrices.
1 parent dad0d24 commit 7730590

3 files changed

Lines changed: 242 additions & 2 deletions

File tree

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ ChainRulesCoreSparseArraysExt = "SparseArrays"
2222
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2323
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2424
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
25+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2526
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2627
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]
30-
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]
31+
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "Random", "SparseArrays", "StaticArrays"]
3132

3233
[weakdeps]
3334
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

ext/ChainRulesCoreSparseArraysExt.jl

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,41 @@ module ChainRulesCoreSparseArraysExt
22

33
using ChainRulesCore
44
using ChainRulesCore: project_type, _projection_mismatch
5-
using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals
5+
using LinearAlgebra: Hermitian, Symmetric, tril, triu
6+
using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros
7+
8+
const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}}
9+
const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}}
10+
const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}}
11+
12+
const SparseProjectToData{T, I} = NamedTuple{
13+
(:element, :axes, :rowval, :nzranges, :colptr),
14+
Tuple{
15+
ProjectTo{T, NamedTuple{(), Tuple{}}},
16+
Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}},
17+
Vector{I},
18+
Vector{UnitRange{Int64}},
19+
Vector{I},
20+
},
21+
}
22+
23+
const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}}
24+
25+
const HermSparseProjectTo{T, I} = ProjectTo{
26+
Hermitian,
27+
NamedTuple{
28+
(:uplo, :parent),
29+
Tuple{Symbol, SparseProjectTo{T, I}},
30+
},
31+
}
32+
33+
const SymSparseProjectTo{T, I} = ProjectTo{
34+
Symmetric,
35+
NamedTuple{
36+
(:uplo, :parent),
37+
Tuple{Symbol, SparseProjectTo{T, I}},
38+
},
39+
}
640

741
ChainRulesCore.is_inplaceable_destination(::SparseVector) = true
842
ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true
@@ -100,4 +134,152 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
100134
end
101135
end
102136

137+
#####
138+
##### Hermitian/Symmetric sparse projection
139+
#####
140+
141+
function project!(A::SparseMatrixCSC{T, I}, B::SparseMatrixCSC{<:Any, J}, uplo::Char) where {T, I, J}
142+
@assert size(A) == size(B)
143+
144+
@inbounds for j in axes(A, 2)
145+
p = getcolptr(A)[j]
146+
pstop = getcolptr(A)[j + 1]
147+
q = getcolptr(B)[j]
148+
qstop = getcolptr(B)[j + 1]
149+
150+
while p < pstop
151+
i = rowvals(A)[p]
152+
153+
if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j)
154+
while q < qstop && rowvals(B)[q] < i
155+
q += one(J)
156+
end
157+
158+
if q < qstop && rowvals(B)[q] == i
159+
nonzeros(A)[p] = nonzeros(B)[q]
160+
else
161+
nonzeros(A)[p] = zero(T)
162+
end
163+
end
164+
165+
p += one(I)
166+
end
167+
end
168+
169+
return A
170+
end
171+
172+
function project!(A::HermOrSymSparse, B::HermOrSymSparse)
173+
if A.uplo == B.uplo
174+
project!(parent(A), parent(B), A.uplo)
175+
elseif A.uplo == 'L'
176+
project!(parent(A), tril(B), A.uplo)
177+
else
178+
project!(parent(A), triu(B), A.uplo)
179+
end
180+
181+
return A
182+
end
183+
184+
function sparse_from_project(P::SparseProjectTo{T, I}) where {T, I}
185+
m, n = map(length, P.axes)
186+
return SparseMatrixCSC(m, n, P.colptr, P.rowval, zeros(T, length(P.rowval)))
187+
end
188+
189+
function sparse_from_project(P::HermSparseProjectTo)
190+
return Hermitian(sparse_from_project(P.parent), P.uplo)
191+
end
192+
193+
function sparse_from_project(P::SymSparseProjectTo)
194+
return Symmetric(sparse_from_project(P.parent), P.uplo)
195+
end
196+
197+
function checkpatternsym(n, Acolptr::Vector{IA}, Bcolptr::Vector{IB}, Arowval::AbstractVector, Browval::AbstractVector, uplo::Char) where {IA, IB}
198+
for j in 1:n
199+
pa = Acolptr[j]
200+
pb = Bcolptr[j]
201+
pastop = Acolptr[j + 1]
202+
pbstop = Bcolptr[j + 1]
203+
204+
while pa < pastop && pb < pbstop
205+
ia = Arowval[pa]
206+
ib = Browval[pb]
207+
208+
if (uplo == 'L' && ia < j) || (uplo == 'U' && ia > j)
209+
pa += one(IA)
210+
elseif (uplo == 'L' && ib < j) || (uplo == 'U' && ib > j)
211+
pb += one(IB)
212+
elseif ia == ib
213+
pa += one(IA)
214+
pb += one(IB)
215+
else
216+
return false
217+
end
218+
end
219+
220+
while pa < pastop
221+
ia = Arowval[pa]
222+
223+
if (uplo == 'L' && ia >= j) || (uplo == 'U' && ia <= j)
224+
return false
225+
end
226+
227+
pa += one(IA)
228+
end
229+
230+
while pb < pbstop
231+
ib = Browval[pb]
232+
233+
if (uplo == 'L' && ib >= j) || (uplo == 'U' && ib <= j)
234+
return false
235+
end
236+
237+
pb += one(IB)
238+
end
239+
end
240+
241+
return true
242+
end
243+
244+
function checkpatternsym(P, dX)
245+
return false
246+
end
247+
248+
function checkpatternsym(P::Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}}, dX::HermOrSymSparse{T, I}) where {T, I}
249+
dXP = parent(dX)
250+
return Symbol(dX.uplo) == P.uplo && checkpatternsym(size(dXP, 2), P.parent.colptr, dXP.colptr, P.parent.rowval, dXP.rowval, dX.uplo)
251+
end
252+
253+
function (P::HermSparseProjectTo{T, I})(dX::HermSparse) where {T, I}
254+
if checkpatternsym(P, dX)
255+
return dX
256+
else
257+
return project!(sparse_from_project(P), dX)
258+
end
259+
end
260+
261+
function (P::SymSparseProjectTo{T, I})(dX::SymSparse) where {T, I}
262+
if checkpatternsym(P, dX)
263+
return dX
264+
else
265+
return project!(sparse_from_project(P), dX)
266+
end
267+
end
268+
269+
function (P::HermSparseProjectTo{T, I})(dX::SymSparse{T, I}) where {T <: Real, I}
270+
if checkpatternsym(P, dX)
271+
return Hermitian(parent(dX), P.uplo)
272+
else
273+
return project!(sparse_from_project(P), dX)
274+
end
275+
end
276+
277+
function (P::SymSparseProjectTo{T, I})(dX::HermSparse{T, I}) where {T <: Real, I}
278+
if checkpatternsym(P, dX)
279+
return Symmetric(parent(dX), P.uplo)
280+
else
281+
return project!(sparse_from_project(P), dX)
282+
end
283+
end
284+
103285
end # module

test/projection.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ChainRulesCore, Test
22
using LinearAlgebra, SparseArrays
33
using OffsetArrays, StaticArrays, BenchmarkTools
4+
using Random: rand!
45

56
# Like ForwardDiff.jl's Dual
67
struct Dual{T<:Real} <: Real
@@ -355,6 +356,62 @@ struct NoSuperType end
355356
@test_throws DimensionMismatch pm(ones(Int, 5, 20))
356357
end
357358

359+
@testset "SparseArrays: Hermitian/Symmetric" begin
360+
n = 100
361+
362+
function rand_sparse(SymHerm, T, n, uplo; density=0.3)
363+
A = sprand(T, n, n, density)
364+
if uplo == :U
365+
return SymHerm(triu(A), uplo)
366+
else
367+
return SymHerm(tril(A), uplo)
368+
end
369+
end
370+
371+
function rand_tangent(A, uplo=Symbol(A.uplo))
372+
dA = similar(A)
373+
rand!(nonzeros(parent(dA)))
374+
return typeof(A).name.wrapper(parent(dA), uplo)
375+
end
376+
377+
function nzmatch(A, B)
378+
I, J, _ = findnz(parent(A))
379+
return all(A[i, j] == B[i, j] for (i, j) in zip(I, J))
380+
end
381+
382+
@testset "$(SymHerm){$T}, uplo=:$uplo" for
383+
SymHerm in (Symmetric, Hermitian),
384+
T in (Float64, ComplexF64),
385+
uplo in (:U, :L)
386+
387+
A = rand_sparse(SymHerm, T, n, uplo)
388+
P = ProjectTo(A)
389+
390+
# Same pattern
391+
dA = rand_tangent(A)
392+
@test P(dA) == dA
393+
394+
# Different uplo
395+
other = uplo == :U ? :L : :U
396+
dA2 = rand_tangent(A, other)
397+
@test P(dA2) isa SymHerm{T, <:SparseMatrixCSC}
398+
@test nzmatch(P(dA2), dA2)
399+
400+
# Different pattern
401+
B = rand_sparse(SymHerm, T, n, uplo; density=0.5)
402+
@test P(B) isa SymHerm{T, <:SparseMatrixCSC}
403+
@test nzmatch(P(B), B)
404+
end
405+
406+
@testset "Cross-type (real)" begin
407+
AH = rand_sparse(Hermitian, Float64, n, :U)
408+
AS = rand_sparse(Symmetric, Float64, n, :U)
409+
410+
@test ProjectTo(AH)(rand_tangent(AS)) isa Hermitian{Float64, <:SparseMatrixCSC}
411+
@test ProjectTo(AS)(rand_tangent(AH)) isa Symmetric{Float64, <:SparseMatrixCSC}
412+
end
413+
end
414+
358415
#####
359416
##### `OffsetArrays`
360417
#####

0 commit comments

Comments
 (0)