Skip to content

Commit 80d7c67

Browse files
committed
Add some matrix functions for symm/herm
1 parent 2e983fe commit 80d7c67

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

lib/cusolver/linalg.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,30 @@ function LinearAlgebra.eigen(A::CuMatrix{T}) where {T<:BlasComplex}
140140
ishermitian(A) ? Eigen(heevd!('V', 'U', A2)...) : error("GPU eigensolver supports only Hermitian or Symmetric matrices.")
141141
end
142142

143+
# matrix functions
144+
for func in (:(Base.exp), :(Base.cos), :(Base.sin), :(Base.tan), :(Base.cosh), :(Base.sinh), :(Base.tanh), :(Base.atan), :(Base.asinh), :(Base.atanh), :(Base.cbrt))
145+
@eval begin
146+
function ($func)(A::Symmetric{T, <:StridedCuMatrix}) where {T<:BlasReal}
147+
F = eigen(A)
148+
return Symmetric((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
149+
end
150+
function ($func)(A::Hermitian{T, <:StridedCuMatrix}) where {T<:BlasReal}
151+
F = eigen(A)
152+
return Hermitian((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
153+
end
154+
function ($func)(A::Hermitian{<:Complex, <:StridedCuMatrix})
155+
F = eigen(A)
156+
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
157+
@static if VERSION >= v"1.11"
158+
d_ixs = diagind(retmat, IndexStyle(retmat))
159+
else
160+
d_ixs = diagind(retmat)
161+
end
162+
@. retmat[d_ixs] = real(retmat[d_ixs])
163+
return Hermitian(retmat)
164+
end
165+
end
166+
end
143167

144168
# factorizations
145169

test/libraries/cusolver/dense.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,3 +945,19 @@ end
945945
@inferred d_A \ d_b
946946
end
947947
end
948+
949+
@testset "Hermitian/Symmetric matrix functions, elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
950+
A = rand(elty, n, n)
951+
Ah = A * A' # make posdef for atan, asinh, atanh
952+
d_Ah = CuArray(Ah)
953+
@testset for func in (exp, cos, sin, tan, cosh, sinh, tanh, atan, asinh)
954+
@test Array(func(d_Ah)) func(Ah)
955+
end
956+
@static if VERSION >= v"1.11.0" # not supported on 1.10 or for Complex
957+
if elty <: Real
958+
@testset for func in (cbrt,) # have to dispatch explicitly
959+
@test Array(parent(func(Hermitian(d_Ah)))) func(Ah)
960+
end
961+
end
962+
end
963+
end

0 commit comments

Comments
 (0)