Skip to content

Commit d40b7bf

Browse files
Merge pull request #854 from j-fu/fix-abstractsparsecsc
Use nonzeros(A) instead of A.nzval in pattern check
2 parents 5671d6e + f84592d commit d40b7bf

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

ext/LinearSolveSparseArraysExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ function SciMLBase.solve!(
201201
cacheval = LinearSolve.@get_cacheval(cache, :UMFPACKFactorization)
202202
if alg.reuse_symbolic
203203
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
204-
if length(cacheval.nzval) != length(A.nzval) || alg.check_pattern && pattern_changed(cacheval, A)
204+
if length(cacheval.nzval) != length(nonzeros(A)) || alg.check_pattern && pattern_changed(cacheval, A)
205205
fact = lu(
206206
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
207207
nonzeros(A)),
@@ -331,7 +331,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization;
331331
if cache.isfresh
332332
cacheval = LinearSolve.@get_cacheval(cache, :KLUFactorization)
333333
if alg.reuse_symbolic
334-
if length(cacheval.nzval) != length(A.nzval) || alg.check_pattern && pattern_changed(cacheval, A)
334+
if length(cacheval.nzval) != length(nonzeros(A)) || alg.check_pattern && pattern_changed(cacheval, A)
335335
fact = KLU.klu(
336336
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A),
337337
nonzeros(A)),

test/basictests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,25 @@ end
782782
reinit!(cache; A = B1, b = b1)
783783
u = solve!(cache)
784784
@test norm(u - u0, Inf) < 1.0e-8
785+
786+
pr = LinearProblem(B, b)
787+
solver = UMFPACKFactorization()
788+
cache = init(pr, solver)
789+
u = solve!(cache)
790+
@test norm(u - u0, Inf) < 1.0e-8
791+
reinit!(cache; A = B1, b = b1)
792+
u = solve!(cache)
793+
@test norm(u - u0, Inf) < 1.0e-8
794+
795+
pr = LinearProblem(B, b)
796+
solver = KLUFactorization()
797+
cache = init(pr, solver)
798+
u = solve!(cache)
799+
@test norm(u - u0, Inf) < 1.0e-8
800+
reinit!(cache; A = B1, b = b1)
801+
u = solve!(cache)
802+
@test norm(u - u0, Inf) < 1.0e-8
803+
785804
end
786805

787806
@testset "ParallelSolves" begin

0 commit comments

Comments
 (0)