Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 0 additions & 108 deletions test/cuda/eig.jl

This file was deleted.

156 changes: 37 additions & 119 deletions test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,127 +4,45 @@ using TestExtras
using StableRNGs
using LinearAlgebra: Diagonal
using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm
using CUDA, AMDGPU

BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (Float16, BigFloat, Complex{BigFloat})

@testset "eig_full! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for alg in (LAPACK_Simple(), LAPACK_Expert(), :LAPACK_Simple, LAPACK_Simple)
A = randn(rng, T, m, m)
Tc = complex(T)

D, V = @constinferred eig_full(A; alg = ($alg))
@test eltype(D) == eltype(V) == Tc
@test A * V ≈ V * D

alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg)

Ac = similar(A)
D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′)
@test D2 === D
@test V2 === V
@test A * V ≈ V * D

Dc = @constinferred eig_vals(A, alg′)
@test eltype(Dc) == Tc
@test D ≈ Diagonal(Dc)
GenericFloats = (BigFloat, Complex{BigFloat})

@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 54
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if T ∈ BLASFloats
if CUDA.functional()
TestSuite.test_eig(CuMatrix{T}, (m, m); test_trunc = false)
TestSuite.test_eig_algs(CuMatrix{T}, (m, m), (CUSOLVER_Simple(),); test_trunc = false)
TestSuite.test_eig(Diagonal{T, CuVector{T}}, m; test_trunc = false)
TestSuite.test_eig_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false)
end
#= not yet supported
if AMDGPU.functional()
TestSuite.test_eig(ROCMatrix{T}, (m, m); test_blocksize = false)
TestSuite.test_eig_algs(ROCMatrix{T}, (m, m), (ROCSOLVER_Simple(),))
TestSuite.test_eig(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
TestSuite.test_eig_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),))
end=#
end
end

@testset "eig_trunc! for T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 54
for alg in (LAPACK_Simple(), LAPACK_Expert())
A = randn(rng, T, m, m)
A *= A' # TODO: deal with eigenvalue ordering etc
# eigenvalues are sorted by ascending real component...
D₀ = sort!(eig_vals(A); by = abs, rev = true)
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
r = length(D₀) - rmin
atol = sqrt(eps(real(T)))

D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
@test length(diagview(D1)) == r
@test A * V1 ≈ V1 * D1
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 + sqrt(eps(real(T)))
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D4, V4 = @constinferred eig_trunc_no_error(A; alg, trunc)
@test length(diagview(D4)) == r
@test A * V4 ≈ V4 * D4
# trunctol keeps order, truncrank might not
# test for same subspace
@test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
@test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1
@test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3
@test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1
if !is_buildkite
TestSuite.test_eig(T, (m, m))
if T ∈ BLASFloats
LAPACK_EIG_ALGS = (LAPACK_Simple(), LAPACK_Expert())
TestSuite.test_eig_algs(T, (m, m), LAPACK_EIG_ALGS)
elseif T ∈ GenericFloats
GS_EIG_ALGS = (GS_QRIteration(),)
TestSuite.test_eig_algs(T, (m, m), GS_EIG_ALGS)
end
AT = Diagonal{T, Vector{T}}
TestSuite.test_eig(AT, m)
TestSuite.test_eig_algs(AT, m, (DiagonalAlgorithm(),))
end
end

@testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats
rng = StableRNG(123)
m = 4
atol = sqrt(eps(real(T)))
V = randn(rng, T, m, m)
D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01])
A = V * D * inv(V)
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2]
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))

alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol

alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
D4, V4 = @constinferred eig_trunc_no_error(A; alg)
@test diagview(D4) ≈ diagview(D)[1:2]
end

@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
rng = StableRNG(123)
m = 54
Ad = randn(rng, T, m)
A = Diagonal(Ad)
atol = sqrt(eps(real(T)))

D, V = @constinferred eig_full(A)
@test D isa Diagonal{T} && size(D) == size(A)
@test V isa Diagonal{T} && size(V) == size(A)
@test A * V ≈ V * D

D2 = @constinferred eig_vals(A)
@test D2 isa AbstractVector{T} && length(D2) == m
@test diagview(D) ≈ D2

A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol

A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D3, V3 = @constinferred eig_trunc_no_error(A3; alg)
@test diagview(D3) ≈ diagview(A3)[1:2]
end
10 changes: 3 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ if !is_buildkite
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("eigh.jl")
end
@safetestset "General Eigenvalue Decomposition" begin
include("eig.jl")
end
@safetestset "Generalized Eigenvalue Decomposition" begin
include("gen_eig.jl")
end
Expand Down Expand Up @@ -51,7 +48,6 @@ if !is_buildkite
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end

end

@safetestset "QR / LQ Decomposition" begin
Expand All @@ -67,15 +63,15 @@ end
@safetestset "Schur Decomposition" begin
include("schur.jl")
end
@safetestset "General Eigenvalue Decomposition" begin
include("eig.jl")
end

using CUDA
if CUDA.functional()
@safetestset "CUDA SVD" begin
include("cuda/svd.jl")
end
@safetestset "CUDA General Eigenvalue Decomposition" begin
include("cuda/eig.jl")
end
@safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
include("cuda/eigh.jl")
end
Expand Down
1 change: 1 addition & 0 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,6 @@ include("lq.jl")
include("polar.jl")
include("projections.jl")
include("schur.jl")
include("eig.jl")

end
Loading