Skip to content

Commit 2d633e2

Browse files
authored
Fix and test broken quad! and invquad! (#190)
* Fix and test broken `quad!` and `invquad!` * Fix spurious test failures * Better fixes for older Julia versions * Address PR review
1 parent c50b8a0 commit 2d633e2

File tree

7 files changed

+67
-34
lines changed

7 files changed

+67
-34
lines changed

src/PDMats.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ module PDMats
3939

4040
# source files
4141

42-
include("chol.jl")
4342
include("utils.jl")
43+
include("chol.jl")
4444

4545
include("pdmat.jl")
4646

src/chol.jl

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,45 +46,73 @@ for T in (:AbstractVector, :AbstractMatrix)
4646
end
4747

4848
# quad
49-
quad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_upper(A) * x)
49+
function quad(A::Cholesky, x::AbstractVector)
50+
@check_argdims size(A, 1) == length(x)
51+
return sum(abs2, chol_upper(A) * x)
52+
end
5053
function quad(A::Cholesky, X::AbstractMatrix)
54+
@check_argdims size(A, 1) == size(X, 1)
5155
Z = chol_upper(A) * X
5256
return vec(sum(abs2, Z; dims=1))
5357
end
5458
function quad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix)
55-
Z = chol_upper(A) * X
56-
return map!(Base.Fix1(sum, abs2), r, eachcol(Z))
59+
@check_argdims eachindex(r) == axes(X, 2)
60+
@check_argdims size(A, 1) == size(X, 1)
61+
aU = chol_upper(A)
62+
z = similar(r, size(A, 1)) # buffer to save allocations
63+
@inbounds for i in axes(X, 2)
64+
copyto!(z, view(X, :, i))
65+
lmul!(aU, z)
66+
r[i] = sum(abs2, z)
67+
end
68+
return r
5769
end
5870

5971
# invquad
60-
invquad(A::Cholesky, x::AbstractVector) = sum(abs2, chol_lower(A) \ x)
61-
function invquad(A::Cholesky, X::AbstractMatrix)
72+
function invquad(A::Cholesky, x::AbstractVector)
73+
@check_argdims size(A, 1) == size(x, 1)
74+
return sum(abs2, chol_lower(A) \ x)
75+
end
76+
function invquad(A::Cholesky, X::AbstractMatrix)
77+
@check_argdims size(A, 1) == size(X, 1)
6278
Z = chol_lower(A) \ X
6379
return vec(sum(abs2, Z; dims=1))
6480
end
6581
function invquad!(r::AbstractArray, A::Cholesky, X::AbstractMatrix)
66-
Z = chol_lower(A) * X
67-
return map!(Base.Fix1(sum, abs2), r, eachcol(Z))
82+
@check_argdims eachindex(r) == axes(X, 2)
83+
@check_argdims size(A, 1) == size(X, 1)
84+
aL = chol_lower(A)
85+
z = similar(r, size(A, 1)) # buffer to save allocations
86+
@inbounds for i in axes(X, 2)
87+
copyto!(z, view(X, :, i))
88+
ldiv!(aL, z)
89+
r[i] = sum(abs2, z)
90+
end
91+
return r
6892
end
6993

7094
# tri products
7195

7296
function X_A_Xt(A::Cholesky, X::AbstractMatrix)
97+
@check_argdims size(A, 1) == size(X, 2)
7398
Z = X * chol_lower(A)
7499
return Z * transpose(Z)
75100
end
76101

77102
function Xt_A_X(A::Cholesky, X::AbstractMatrix)
103+
@check_argdims size(A, 1) == size(X, 1)
78104
Z = chol_upper(A) * X
79105
return transpose(Z) * Z
80106
end
81107

82108
function X_invA_Xt(A::Cholesky, X::AbstractMatrix)
109+
@check_argdims size(A, 1) == size(X, 2)
83110
Z = X / chol_upper(A)
84111
return Z * transpose(Z)
85112
end
86113

87114
function Xt_invA_X(A::Cholesky, X::AbstractMatrix)
115+
@check_argdims size(A, 1) == size(X, 1)
88116
Z = chol_lower(A) \ X
89117
return transpose(Z) * Z
90118
end

src/pdmat.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ function quad(a::PDMat, x::AbstractVecOrMat)
123123
end
124124

125125
function quad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
126-
@check_argdims axes(r) == axes(x, 2)
126+
@check_argdims eachindex(r) == axes(x, 2)
127127
@check_argdims a.dim == size(x, 1)
128128
aU = chol_upper(cholesky(a))
129129
z = similar(r, a.dim) # buffer to save allocations
@@ -146,7 +146,7 @@ function invquad(a::PDMat, x::AbstractVecOrMat)
146146
end
147147

148148
function invquad!(r::AbstractArray, a::PDMat, x::AbstractMatrix)
149-
@check_argdims axes(r) == axes(x, 2)
149+
@check_argdims eachindex(r) == axes(x, 2)
150150
@check_argdims a.dim == size(x, 1)
151151
aL = chol_lower(cholesky(a))
152152
z = similar(r, a.dim) # buffer to save allocations

src/pdsparsemat.jl

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,14 @@ function quad(a::PDSparseMat, x::AbstractVecOrMat)
122122
end
123123

124124
function quad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
125-
@check_argdims axes(r) == axes(x, 2)
126-
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
127-
if VERSION < v"1.4.0-DEV.92"
128-
z = similar(r, a.dim) # buffer to save allocations
129-
@inbounds for i in axes(x, 2)
130-
xi = view(x, :, i)
131-
copyto!(z, xi)
132-
lmul!(a.mat, z)
133-
r[i] = dot(xi, z)
134-
end
135-
else
136-
@inbounds for i in axes(x, 2)
137-
xi = view(x, :, i)
125+
@check_argdims eachindex(r) == axes(x, 2)
126+
@inbounds for i in axes(x, 2)
127+
xi = view(x, :, i)
128+
# https://github.com/JuliaLang/julia/commit/2425ae760fb5151c5c7dd0554e87c5fc9e24de73
129+
if VERSION < v"1.4.0-DEV.92"
130+
# Can't use `lmul!` with buffer due to missing support in SparseArrays
131+
r[i] = dot(xi, a.mat * xi)
132+
else
138133
r[i] = dot(xi, a.mat, xi)
139134
end
140135
end
@@ -148,14 +143,12 @@ function invquad(a::PDSparseMat, x::AbstractVecOrMat)
148143
end
149144

150145
function invquad!(r::AbstractArray, a::PDSparseMat, x::AbstractMatrix)
151-
@check_argdims axes(r) == axes(x, 2)
146+
@check_argdims eachindex(r) == axes(x, 2)
152147
@check_argdims a.dim == size(x, 1)
153-
z = similar(r, a.dim) # buffer to save allocations
148+
# Can't use `ldiv!` with buffer due to missing support in SparseArrays
154149
@inbounds for i in axes(x, 2)
155150
xi = view(x, :, i)
156-
copyto!(z, xi)
157-
ldiv!(a.chol, z)
158-
r[i] = dot(xi, z)
151+
r[i] = dot(xi, a.chol \ xi)
159152
end
160153
return r
161154
end

src/scalmat.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ end
115115
function quad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
116116
@check_argdims eachindex(r) == axes(x, 2)
117117
@check_argdims a.dim == size(x, 1)
118-
return map!(Base.Fix1(quad, a), r, eachcol(x))
118+
@inbounds for i in axes(x, 2)
119+
r[i] = quad(a, view(x, :, i))
120+
end
121+
return r
119122
end
120123

121124
function invquad(a::ScalMat, x::AbstractVecOrMat)
@@ -135,7 +138,10 @@ end
135138
function invquad!(r::AbstractArray, a::ScalMat, x::AbstractMatrix)
136139
@check_argdims eachindex(r) == axes(x, 2)
137140
@check_argdims a.dim == size(x, 1)
138-
return map!(Base.Fix1(invquad, a), r, eachcol(x))
141+
@inbounds for i in axes(x, 2)
142+
r[i] = invquad(a, view(x, :, i))
143+
end
144+
return r
139145
end
140146

141147

test/specialarrays.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ using StaticArrays
4747
@test A \ Y Matrix(A) \ Matrix(Y)
4848

4949
@test whiten(A, x) isa SVector{4, Float64}
50-
@test whiten(A, x) cholesky(Matrix(A)).L \ Vector(x)
50+
@test whiten(A, x) cholesky(Symmetric(Matrix(A))).L \ Vector(x)
5151

5252
@test whiten(A, Y) isa SMatrix{4, 10, Float64}
53-
@test whiten(A, Y) cholesky(Matrix(A)).L \ Matrix(Y)
53+
@test whiten(A, Y) cholesky(Symmetric(Matrix(A))).L \ Matrix(Y)
5454

5555
@test unwhiten(A, x) isa SVector{4, Float64}
56-
@test unwhiten(A, x) cholesky(Matrix(A)).L * Vector(x)
56+
@test unwhiten(A, x) cholesky(Symmetric(Matrix(A))).L * Vector(x)
5757

5858
@test unwhiten(A, Y) isa SMatrix{4, 10, Float64}
59-
@test unwhiten(A, Y) cholesky(Matrix(A)).L * Matrix(Y)
59+
@test unwhiten(A, Y) cholesky(Symmetric(Matrix(A))).L * Matrix(Y)
6060

6161
@test quad(A, x) isa Float64
6262
@test quad(A, x) Vector(x)' * Matrix(A) * Vector(x)

test/testutils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ function pdtest_quad(C, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
272272
@test quad(C, view(X,:,i)) r_quad[i]
273273
end
274274
@test quad(C, X) r_quad
275+
r = similar(r_quad)
276+
@test quad!(r, C, X) === r
277+
@test r r_quad
275278

276279
_pdt(verbose, "invquad")
277280
r_invquad = zeros(eltype(C),n)
@@ -282,6 +285,9 @@ function pdtest_quad(C, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
282285
@test invquad(C, view(X,:,i)) r_invquad[i]
283286
end
284287
@test invquad(C, X) r_invquad
288+
r = similar(r_invquad)
289+
@test invquad!(r, C, X) === r
290+
@test r r_invquad
285291
end
286292

287293

0 commit comments

Comments
 (0)