@@ -2,7 +2,41 @@ module ChainRulesCoreSparseArraysExt
22
33using ChainRulesCore
44using ChainRulesCore: project_type, _projection_mismatch
5- using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals
5+ using LinearAlgebra: Hermitian, Symmetric, tril, triu
6+ using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros
7+
8+ const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}}
9+ const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}}
10+ const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}}
11+
12+ const SparseProjectToData{T, I} = NamedTuple{
13+ (:element , :axes , :rowval , :nzranges , :colptr ),
14+ Tuple{
15+ ProjectTo{T, NamedTuple{(), Tuple{}}},
16+ Tuple{Base. OneTo{Int64}, Base. OneTo{Int64}},
17+ Vector{I},
18+ Vector{UnitRange{Int64}},
19+ Vector{I},
20+ },
21+ }
22+
23+ const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}}
24+
25+ const HermSparseProjectTo{T, I} = ProjectTo{
26+ Hermitian,
27+ NamedTuple{
28+ (:uplo , :parent ),
29+ Tuple{Symbol, SparseProjectTo{T, I}},
30+ },
31+ }
32+
33+ const SymSparseProjectTo{T, I} = ProjectTo{
34+ Symmetric,
35+ NamedTuple{
36+ (:uplo , :parent ),
37+ Tuple{Symbol, SparseProjectTo{T, I}},
38+ },
39+ }
640
741ChainRulesCore. is_inplaceable_destination (:: SparseVector ) = true
842ChainRulesCore. is_inplaceable_destination (:: SparseMatrixCSC ) = true
@@ -100,4 +134,152 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
100134 end
101135end
102136
137+ # ####
138+ # #### Hermitian/Symmetric sparse projection
139+ # ####
140+
141+ function project! (A:: SparseMatrixCSC{T, I} , B:: SparseMatrixCSC{<:Any, J} , uplo:: Char ) where {T, I, J}
142+ @assert size (A) == size (B)
143+
144+ @inbounds for j in axes (A, 2 )
145+ p = getcolptr (A)[j]
146+ pstop = getcolptr (A)[j + 1 ]
147+ q = getcolptr (B)[j]
148+ qstop = getcolptr (B)[j + 1 ]
149+
150+ while p < pstop
151+ i = rowvals (A)[p]
152+
153+ if (uplo == ' L' && i >= j) || (uplo == ' U' && i <= j)
154+ while q < qstop && rowvals (B)[q] < i
155+ q += one (J)
156+ end
157+
158+ if q < qstop && rowvals (B)[q] == i
159+ nonzeros (A)[p] = nonzeros (B)[q]
160+ else
161+ nonzeros (A)[p] = zero (T)
162+ end
163+ end
164+
165+ p += one (I)
166+ end
167+ end
168+
169+ return A
170+ end
171+
172+ function project! (A:: HermOrSymSparse , B:: HermOrSymSparse )
173+ if A. uplo == B. uplo
174+ project! (parent (A), parent (B), A. uplo)
175+ elseif A. uplo == ' L'
176+ project! (parent (A), tril (B), A. uplo)
177+ else
178+ project! (parent (A), triu (B), A. uplo)
179+ end
180+
181+ return A
182+ end
183+
184+ function sparse_from_project (P:: SparseProjectTo{T, I} ) where {T, I}
185+ m, n = map (length, P. axes)
186+ return SparseMatrixCSC (m, n, P. colptr, P. rowval, zeros (T, length (P. rowval)))
187+ end
188+
189+ function sparse_from_project (P:: HermSparseProjectTo )
190+ return Hermitian (sparse_from_project (P. parent), P. uplo)
191+ end
192+
193+ function sparse_from_project (P:: SymSparseProjectTo )
194+ return Symmetric (sparse_from_project (P. parent), P. uplo)
195+ end
196+
197+ function checkpatternsym (n, Acolptr:: Vector{IA} , Bcolptr:: Vector{IB} , Arowval:: AbstractVector , Browval:: AbstractVector , uplo:: Char ) where {IA, IB}
198+ for j in 1 : n
199+ pa = Acolptr[j]
200+ pb = Bcolptr[j]
201+ pastop = Acolptr[j + 1 ]
202+ pbstop = Bcolptr[j + 1 ]
203+
204+ while pa < pastop && pb < pbstop
205+ ia = Arowval[pa]
206+ ib = Browval[pb]
207+
208+ if (uplo == ' L' && ia < j) || (uplo == ' U' && ia > j)
209+ pa += one (IA)
210+ elseif (uplo == ' L' && ib < j) || (uplo == ' U' && ib > j)
211+ pb += one (IB)
212+ elseif ia == ib
213+ pa += one (IA)
214+ pb += one (IB)
215+ else
216+ return false
217+ end
218+ end
219+
220+ while pa < pastop
221+ ia = Arowval[pa]
222+
223+ if (uplo == ' L' && ia >= j) || (uplo == ' U' && ia <= j)
224+ return false
225+ end
226+
227+ pa += one (IA)
228+ end
229+
230+ while pb < pbstop
231+ ib = Browval[pb]
232+
233+ if (uplo == ' L' && ib >= j) || (uplo == ' U' && ib <= j)
234+ return false
235+ end
236+
237+ pb += one (IB)
238+ end
239+ end
240+
241+ return true
242+ end
243+
244+ function checkpatternsym (P, dX)
245+ return false
246+ end
247+
248+ function checkpatternsym (P:: Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}} , dX:: HermOrSymSparse{T, I} ) where {T, I}
249+ dXP = parent (dX)
250+ return Symbol (dX. uplo) == P. uplo && checkpatternsym (size (dXP, 2 ), P. parent. colptr, dXP. colptr, P. parent. rowval, dXP. rowval, dX. uplo)
251+ end
252+
253+ function (P:: HermSparseProjectTo{T, I} )(dX:: HermSparse ) where {T, I}
254+ if checkpatternsym (P, dX)
255+ return dX
256+ else
257+ return project! (sparse_from_project (P), dX)
258+ end
259+ end
260+
261+ function (P:: SymSparseProjectTo{T, I} )(dX:: SymSparse ) where {T, I}
262+ if checkpatternsym (P, dX)
263+ return dX
264+ else
265+ return project! (sparse_from_project (P), dX)
266+ end
267+ end
268+
269+ function (P:: HermSparseProjectTo{T, I} )(dX:: SymSparse{T, I} ) where {T <: Real , I}
270+ if checkpatternsym (P, dX)
271+ return Hermitian (parent (dX), P. uplo)
272+ else
273+ return project! (sparse_from_project (P), dX)
274+ end
275+ end
276+
277+ function (P:: SymSparseProjectTo{T, I} )(dX:: HermSparse{T, I} ) where {T <: Real , I}
278+ if checkpatternsym (P, dX)
279+ return Symmetric (parent (dX), P. uplo)
280+ else
281+ return project! (sparse_from_project (P), dX)
282+ end
283+ end
284+
103285end # module
0 commit comments