Skip to content

Commit 283d865

Browse files
authored
Ensure that the output of X_A_Xt etc. is symmetric (#191)
1 parent 187f741 commit 283d865

File tree

7 files changed

+70
-71
lines changed

7 files changed

+70
-71
lines changed

src/chol.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,26 @@ end
9393

9494
# tri products
9595

96-
function X_A_Xt(A::Cholesky, X::AbstractMatrix)
96+
function X_A_Xt(A::Cholesky, X::AbstractMatrix{<:Real})
9797
@check_argdims size(A, 1) == size(X, 2)
9898
Z = X * chol_lower(A)
99-
return Z * transpose(Z)
99+
return Symmetric(Z * transpose(Z))
100100
end
101101

102-
function Xt_A_X(A::Cholesky, X::AbstractMatrix)
102+
function Xt_A_X(A::Cholesky, X::AbstractMatrix{<:Real})
103103
@check_argdims size(A, 1) == size(X, 1)
104104
Z = chol_upper(A) * X
105-
return transpose(Z) * Z
105+
return Symmetric(transpose(Z) * Z)
106106
end
107107

108-
function X_invA_Xt(A::Cholesky, X::AbstractMatrix)
108+
function X_invA_Xt(A::Cholesky, X::AbstractMatrix{<:Real})
109109
@check_argdims size(A, 1) == size(X, 2)
110110
Z = X / chol_upper(A)
111-
return Z * transpose(Z)
111+
return Symmetric(Z * transpose(Z))
112112
end
113113

114-
function Xt_invA_X(A::Cholesky, X::AbstractMatrix)
114+
function Xt_invA_X(A::Cholesky, X::AbstractMatrix{<:Real})
115115
@check_argdims size(A, 1) == size(X, 1)
116116
Z = chol_lower(A) \ X
117-
return transpose(Z) * Z
117+
return Symmetric(transpose(Z) * Z)
118118
end

src/pdiagmat.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,28 +166,28 @@ end
166166

167167
### tri products
168168

169-
function X_A_Xt(a::PDiagMat, x::AbstractMatrix)
169+
function X_A_Xt(a::PDiagMat, x::AbstractMatrix{<:Real})
170170
@check_argdims a.dim == size(x, 2)
171171
z = a.diag .* transpose(x)
172-
return x * z
172+
return Symmetric(x * z)
173173
end
174174

175-
function Xt_A_X(a::PDiagMat, x::AbstractMatrix)
175+
function Xt_A_X(a::PDiagMat, x::AbstractMatrix{<:Real})
176176
@check_argdims a.dim == size(x, 1)
177177
z = a.diag .* x
178-
return transpose(x) * z
178+
return Symmetric(transpose(x) * z)
179179
end
180180

181-
function X_invA_Xt(a::PDiagMat, x::AbstractMatrix)
181+
function X_invA_Xt(a::PDiagMat, x::AbstractMatrix{<:Real})
182182
@check_argdims a.dim == size(x, 2)
183183
z = transpose(x) ./ a.diag
184-
return x * z
184+
return Symmetric(x * z)
185185
end
186186

187-
function Xt_invA_X(a::PDiagMat, x::AbstractMatrix)
187+
function Xt_invA_X(a::PDiagMat, x::AbstractMatrix{<:Real})
188188
@check_argdims a.dim == size(x, 1)
189189
z = x ./ a.diag
190-
return transpose(x) * z
190+
return Symmetric(transpose(x) * z)
191191
end
192192

193193
### Specializations for `Array` arguments with reduced allocations

src/pdmat.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,28 +160,28 @@ end
160160

161161
### tri products
162162

163-
function X_A_Xt(a::PDMat, x::AbstractMatrix)
163+
function X_A_Xt(a::PDMat, x::AbstractMatrix{<:Real})
164164
@check_argdims a.dim == size(x, 2)
165165
z = x * chol_lower(a.chol)
166-
return z * transpose(z)
166+
return Symmetric(z * transpose(z))
167167
end
168168

169-
function Xt_A_X(a::PDMat, x::AbstractMatrix)
169+
function Xt_A_X(a::PDMat, x::AbstractMatrix{<:Real})
170170
@check_argdims a.dim == size(x, 1)
171171
z = chol_upper(a.chol) * x
172-
return transpose(z) * z
172+
return Symmetric(transpose(z) * z)
173173
end
174174

175-
function X_invA_Xt(a::PDMat, x::AbstractMatrix)
175+
function X_invA_Xt(a::PDMat, x::AbstractMatrix{<:Real})
176176
@check_argdims a.dim == size(x, 2)
177177
z = x / chol_upper(a.chol)
178-
return z * transpose(z)
178+
return Symmetric(z * transpose(z))
179179
end
180180

181-
function Xt_invA_X(a::PDMat, x::AbstractMatrix)
181+
function Xt_invA_X(a::PDMat, x::AbstractMatrix{<:Real})
182182
@check_argdims a.dim == size(x, 1)
183183
z = chol_lower(a.chol) \ x
184-
return transpose(z) * z
184+
return Symmetric(transpose(z) * z)
185185
end
186186

187187
### Specializations for `Array` arguments with reduced allocations

src/pdsparsemat.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -156,33 +156,28 @@ end
156156

157157
### tri products
158158

159-
function X_A_Xt(a::PDSparseMat, x::AbstractMatrix)
160-
# `*` is not defined for `PtL` factor components,
161-
# so we can't use `x * chol_lower(a.chol)`
162-
C = a.chol
163-
PtL = sparse(C.L)[C.p, :]
164-
z = x * PtL
165-
z * transpose(z)
159+
function X_A_Xt(a::PDSparseMat, x::AbstractMatrix{<:Real})
160+
@check_argdims a.dim == size(x, 2)
161+
z = a.mat * transpose(x)
162+
return Symmetric(x * z)
166163
end
167164

168165

169-
function Xt_A_X(a::PDSparseMat, x::AbstractMatrix)
170-
# `*` is not defined for `UP` factor components,
171-
# so we can't use `chol_upper(a.chol) * x`
172-
# Moreover, `sparse` is only defined for `L` factor components
173-
C = a.chol
174-
UP = transpose(sparse(C.L))[:, C.p]
175-
z = UP * x
176-
transpose(z) * z
166+
function Xt_A_X(a::PDSparseMat, x::AbstractMatrix{<:Real})
167+
@check_argdims a.dim == size(x, 1)
168+
z = a.mat * x
169+
return Symmetric(transpose(x) * z)
177170
end
178171

179172

180-
function X_invA_Xt(a::PDSparseMat, x::AbstractMatrix)
173+
function X_invA_Xt(a::PDSparseMat, x::AbstractMatrix{<:Real})
174+
@check_argdims a.dim == size(x, 2)
181175
z = a.chol \ collect(transpose(x))
182-
x * z
176+
return Symmetric(x * z)
183177
end
184178

185-
function Xt_invA_X(a::PDSparseMat, x::AbstractMatrix)
179+
function Xt_invA_X(a::PDSparseMat, x::AbstractMatrix{<:Real})
180+
@check_argdims a.dim == size(x, 1)
186181
z = a.chol \ x
187-
transpose(x) * z
182+
return Symmetric(transpose(x) * z)
188183
end

src/scalmat.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,43 +147,43 @@ end
147147

148148
### tri products
149149

150-
function X_A_Xt(a::ScalMat, x::AbstractMatrix)
150+
function X_A_Xt(a::ScalMat, x::AbstractMatrix{<:Real})
151151
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
152-
a.value * (x * transpose(x))
152+
return Symmetric(a.value * (x * transpose(x)))
153153
end
154154

155-
function Xt_A_X(a::ScalMat, x::AbstractMatrix)
155+
function Xt_A_X(a::ScalMat, x::AbstractMatrix{<:Real})
156156
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
157-
a.value * (transpose(x) * x)
157+
return Symmetric(a.value * (transpose(x) * x))
158158
end
159159

160-
function X_invA_Xt(a::ScalMat, x::AbstractMatrix)
160+
function X_invA_Xt(a::ScalMat, x::AbstractMatrix{<:Real})
161161
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
162-
(x * transpose(x)) / a.value
162+
return Symmetric((x * transpose(x)) / a.value)
163163
end
164164

165-
function Xt_invA_X(a::ScalMat, x::AbstractMatrix)
165+
function Xt_invA_X(a::ScalMat, x::AbstractMatrix{<:Real})
166166
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
167-
(transpose(x) * x) / a.value
167+
return Symmetric((transpose(x) * x) / a.value)
168168
end
169169

170170
# Specializations for `x::Matrix` with reduced allocations
171-
function X_A_Xt(a::ScalMat, x::Matrix)
171+
function X_A_Xt(a::ScalMat, x::Matrix{<:Real})
172172
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
173-
lmul!(a.value, x * transpose(x))
173+
return Symmetric(lmul!(a.value, x * transpose(x)))
174174
end
175175

176-
function Xt_A_X(a::ScalMat, x::Matrix)
176+
function Xt_A_X(a::ScalMat, x::Matrix{<:Real})
177177
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
178-
lmul!(a.value, transpose(x) * x)
178+
return Symmetric(lmul!(a.value, transpose(x) * x))
179179
end
180180

181-
function X_invA_Xt(a::ScalMat, x::Matrix)
181+
function X_invA_Xt(a::ScalMat, x::Matrix{<:Real})
182182
@check_argdims LinearAlgebra.checksquare(a) == size(x, 2)
183-
_rdiv!(x * transpose(x), a.value)
183+
return Symmetric(_rdiv!(x * transpose(x), a.value))
184184
end
185185

186-
function Xt_invA_X(a::ScalMat, x::Matrix)
186+
function Xt_invA_X(a::ScalMat, x::Matrix{<:Real})
187187
@check_argdims LinearAlgebra.checksquare(a) == size(x, 1)
188-
_rdiv!(transpose(x) * x, a.value)
188+
return Symmetric(_rdiv!(transpose(x) * x, a.value))
189189
end

test/specialarrays.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,16 @@ using StaticArrays
7070
@test invquad(A, Y) isa SVector{10, Float64}
7171
@test invquad(A, Y) diag(Matrix(Y)' * (Matrix(A) \ Matrix(Y)))
7272

73-
@test X_A_Xt(A, X) isa SMatrix{10, 10, Float64}
73+
@test X_A_Xt(A, X) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}}
7474
@test X_A_Xt(A, X) Matrix(X) * Matrix(A) * Matrix(X)'
7575

76-
@test X_invA_Xt(A, X) isa SMatrix{10, 10, Float64}
76+
@test X_invA_Xt(A, X) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}}
7777
@test X_invA_Xt(A, X) Matrix(X) * (Matrix(A) \ Matrix(X)')
7878

79-
@test Xt_A_X(A, Y) isa SMatrix{10, 10, Float64}
79+
@test Xt_A_X(A, Y) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}}
8080
@test Xt_A_X(A, Y) Matrix(Y)' * Matrix(A) * Matrix(Y)
8181

82-
@test Xt_invA_X(A, Y) isa SMatrix{10, 10, Float64}
82+
@test Xt_invA_X(A, Y) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}}
8383
@test Xt_invA_X(A, Y) Matrix(Y)' * (Matrix(A) \ Matrix(Y))
8484
end
8585
end

test/testutils.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,23 +297,27 @@ function pdtest_triprod(C, Cmat::Matrix, Imat::Matrix, X::Matrix, verbose::Int)
297297
Xt = copy(transpose(X))
298298

299299
_pdt(verbose, "X_A_Xt")
300-
# default tolerance in isapprox is different on 0.4. rtol argument can be deleted
301-
# ≈ form used when 0.4 is no longer supported
302-
lhs, rhs = X_A_Xt(C, Xt), Xt * Cmat * X
303-
@test isapprox(lhs, rhs, rtol=sqrt(max(eps(real(float(eltype(lhs)))), eps(real(float(eltype(rhs)))))))
300+
M = X_A_Xt(C, Xt)
301+
@test M Xt * Cmat * X
302+
@test issymmetric(M)
304303
@test_throws DimensionMismatch X_A_Xt(C, rand(n, d + 1))
305304

306305
_pdt(verbose, "Xt_A_X")
307-
lhs, rhs = Xt_A_X(C, X), Xt * Cmat * X
308-
@test isapprox(lhs, rhs, rtol=sqrt(max(eps(real(float(eltype(lhs)))), eps(real(float(eltype(rhs)))))))
306+
M = Xt_A_X(C, X)
307+
@test M Xt * Cmat * X
308+
@test issymmetric(M)
309309
@test_throws DimensionMismatch Xt_A_X(C, rand(d + 1, n))
310310

311311
_pdt(verbose, "X_invA_Xt")
312-
@test X_invA_Xt(C, Xt) Xt * Imat * X
312+
M = X_invA_Xt(C, Xt)
313+
@test M Xt * Imat * X
314+
@test issymmetric(M)
313315
@test_throws DimensionMismatch X_invA_Xt(C, rand(n, d + 1))
314316

315317
_pdt(verbose, "Xt_invA_X")
316-
@test Xt_invA_X(C, X) Xt * Imat * X
318+
M = Xt_invA_X(C, X)
319+
@test M Xt * Imat * X
320+
@test issymmetric(M)
317321
@test_throws DimensionMismatch Xt_invA_X(C, rand(d + 1, n))
318322
end
319323

0 commit comments

Comments
 (0)