diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index d278b5c5..371b0481 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -22,4 +22,9 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) return GenericSchur.eigvals!(A) end +function MatrixAlgebraKit.default_exponential_algorithm(E::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + eig_alg = MatrixAlgebraKit.default_eig_algorithm(E; kwargs...) + return MatrixFunctionViaEig(eig_alg) +end + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index fd97497b..81d9d486 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -30,6 +30,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! +export exponential, exponential!, exponentiali, exponentiali! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, @@ -37,6 +38,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton +export MatrixFunctionViaLA, MatrixFunctionViaEig, MatrixFunctionViaEigh export DiagonalAlgorithm export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, @@ -81,9 +83,12 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") + include("interface/projections.jl") include("interface/decompositions.jl") include("interface/truncation.jl") +include("interface/matrixfunctions.jl") + include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") @@ -93,6 +98,7 @@ include("interface/gen_eig.jl") include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("interface/exponential.jl") include("common/gauge.jl") # needs to be defined after the functions are @@ -107,6 +113,7 @@ include("implementations/gen_eig.jl") include("implementations/schur.jl") include("implementations/polar.jl") include("implementations/orthnull.jl") +include("implementations/exponential.jl") include("pullbacks/qr.jl") include("pullbacks/lq.jl") diff --git a/src/common/view.jl b/src/common/view.jl index e03bfb88..b9b4c8cc 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -20,6 +20,26 @@ See also [`diagview`](@ref). diagonal(v::AbstractVector) = Diagonal(v) +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, returning +a diagonal result. + +See also [`map_diagonal!`](@ref). +""" +map_diagonal(f, src, srcs...) = diagonal(f.(diagview(src), map(diagview, srcs)...)) + +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, +into the diagonal elements of destination `dst`. + +See also [`map_diagonal`](@ref). +""" +map_diagonal!(f, dst, src, srcs...) = (diagview(dst) .= f.(diagview(src), map(diagview, srcs)...); dst) + # triangularind function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl new file mode 100644 index 00000000..430ff213 --- /dev/null +++ b/src/implementations/exponential.jl @@ -0,0 +1,136 @@ +# Inputs +# ------ +function copy_input(::typeof(exponential), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exponential), A::Diagonal) = copy(A) + +function copy_input(::typeof(exponentiali), τ::Number, A::AbstractMatrix) + return τ, copy!(similar(A, complex(eltype(A))), A) +end + +copy_input(::typeof(exponentiali), τ::Number, A::Diagonal) = τ, copy(A) + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + return @check_size(expA, (m, m)) +end + +# Outputs +# ------- +initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) = A +initialize_output(::typeof(exponentiali!), τ::Number, A::AbstractMatrix, ::AbstractAlgorithm) = + complex(A) + +# Implementation +# -------------- +function exponential!(A, expA, alg::MatrixFunctionViaLA) + check_input(exponential!, A, expA, alg) + return LinearAlgebra.exp!(A) +end + +function exponential!(A, expA, alg::MatrixFunctionViaEigh) + check_input(exponential!, A, expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal!(x -> exp(x / 2), D, D) + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, V') +end + +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) + check_input(exponential!, A, expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(exp, D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + if eltype(A) <: Real + expA .= real.(VexpD * iV) + else + mul!(expA, VexpD, iV) + end + return expA +end + +function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaLA) + check_input(exponentiali!, A, expA, alg) + expA .= A .* (im * τ) + return LinearAlgebra.exp!(expA) +end + +function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + check_input(exponentiali!, A, expA, alg) + D, V = eigh_full!(A, alg.eigh_alg) + expD = map_diagonal(x -> exp(x * im * τ), D) + if eltype(A) <: Real + VexpD = V * expD + return expA .= real.(VexpD * V') + else + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, V') + end +end + +function exponentiali!(τ::Number, A, expA, alg::MatrixFunctionViaEig) + check_input(exponentiali!, A, expA, alg) + D, V = eig_full!(A, alg.eig_alg) + expD = map_diagonal!(x -> exp(x * im * τ), D, D) + iV = inv(V) + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, iV) +end + +# Diagonal logic +# -------------- +function exponential!(A, expA, alg::DiagonalAlgorithm) + check_input(exponential!, A, expA, alg) + return map_diagonal!(exp, expA, A) +end + +function exponentiali!(τ::Number, A, expA, alg::DiagonalAlgorithm) + check_input(exponentiali!, A, expA, alg) + return map_diagonal!(x -> exp(x * im * τ), expA, A) +end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl new file mode 100644 index 00000000..ffd725d0 --- /dev/null +++ b/src/interface/exponential.jl @@ -0,0 +1,60 @@ +# Exponential functions +# -------------- + +""" + exponential(A; kwargs...) -> expA + exponential(A, alg::AbstractAlgorithm) -> expA + exponential!(A, [expA]; kwargs...) -> expA + exponential!(A, [expA], alg::AbstractAlgorithm) -> expA + +Compute the exponential of the square matrix `A`, + +!!! note + The bang method `exponential!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `expA` as output. + +See also [`exponentiali(!)`](@ref exponentiali). +""" +@functiondef exponential + +""" + exponentiali(τ, A; kwargs...) -> expiτA + exponentiali(τ, A, alg::AbstractAlgorithm) -> expiτA + exponentiali!(τ, A, [expiτA]; kwargs...) -> expiτA + exponentiali!(τ, A, [expiτA], alg::AbstractAlgorithm) -> expiτA + +Compute the exponential of `i*τ*A`, where `i` is the imaginary unit, `τ` is a scalar, and `A` is a square matrix. +This allows the user to use the hermitian eigendecomposition when `A` is hermitian, even when `i*τ*A` is not. + +!!! note + The bang method `exponentiali!` optionally accepts the output structure and + possibly destroys the input matrix `A`. + Always use the return value of the function as it may not always be + possible to use the provided `expiτA` as output. + +See also [`exponential(!)`](@ref exponential). +""" +@functiondef n_args = 2 exponentiali + +# Algorithm selection +# ------------------- +default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) +function default_exponential_algorithm(T::Type; kwargs...) + return MatrixFunctionViaLA(; kwargs...) +end +function default_exponential_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} + return DiagonalAlgorithm(; kwargs...) +end + +for f in (:exponential!,) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_exponential_algorithm(A; kwargs...) + end +end + +for f in (:exponentiali!,) + @eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B} + return default_exponential_algorithm(B; kwargs...) + end +end diff --git a/src/interface/matrixfunctions.jl b/src/interface/matrixfunctions.jl new file mode 100644 index 00000000..ce24652d --- /dev/null +++ b/src/interface/matrixfunctions.jl @@ -0,0 +1,39 @@ +# ================================ +# EXPONENTIAL ALGORITHMS +# ================================ +""" + MatrixFunctionViaLA() + +Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. +""" +@algdef MatrixFunctionViaLA + +""" + MatrixFunctionViaEigh() + +Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. +The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm + eigh_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEigh) + print(io, "MatrixFunctionViaEigh(") + _show_alg(io, alg.eigh_alg) + return print(io, ")") +end + +""" + MatrixFunctionViaEig() + +Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. +The `eig_alg` specifies which eigendecomposition implementation to use. +""" +struct MatrixFunctionViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm + eig_alg::A +end +function Base.show(io::IO, alg::MatrixFunctionViaEig) + print(io, "MatrixFunctionViaEig(") + _show_alg(io, alg.eig_alg) + return print(io, ")") +end diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl deleted file mode 100644 index 8b137891..00000000 --- a/src/matrixfunctions.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/exponential.jl b/test/exponential.jl new file mode 100644 index 00000000..d817510d --- /dev/null +++ b/test/exponential.jl @@ -0,0 +1,88 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra +using LinearAlgebra: exp + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = LinearAlgebra.normalize!(randn(rng, T, m, m)) + Ac = copy(A) + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expA2 = @constinferred exponential(A, alg) + @test expA ≈ expA2 + @test A == Ac + end + + @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponentiali! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + Ac = copy(A) + + Aimτ = A * im * τ + expAimτ = LinearAlgebra.exp(Aimτ) + + expAimτ2 = @constinferred exponentiali(τ, A) + @test expAimτ ≈ expAimτ2 + @test A == Ac + + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expAimτ2 = @constinferred exponentiali(τ, A, alg) + @test expAimτ ≈ expAimτ2 + @test A == Ac + end + + @test_throws DomainError exponentiali(τ, A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + +@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 54 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac +end + +@testset "exponentiali! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + m = 54 + + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + Aimτ = A * im * τ + expAimτ = LinearAlgebra.exp(Aimτ) + + expAimτ2 = @constinferred exponentiali(τ, A) + @test expAimτ ≈ expAimτ2 + @test A == Ac +end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl new file mode 100644 index 00000000..c8a09508 --- /dev/null +++ b/test/genericlinearalgebra/exponential.jl @@ -0,0 +1,47 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A); alg) + expA2 = @constinferred exponential(A; alg) + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eigh_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end + +using GenericSchur +@testset "exponentiali! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + τ = randn(rng, T) + + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expiτA = @constinferred exponentiali!(τ, copy(A); alg) + expiτA2 = @constinferred exponentiali(τ, A; alg) + @test expiτA2 ≈ expiτA + + Dexp, Vexp = @constinferred eig_full(expiτA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) + end +end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl new file mode 100644 index 00000000..9bbfa5d9 --- /dev/null +++ b/test/genericschur/exponential.jl @@ -0,0 +1,46 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exponential! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + expA_LA = @constinferred exponential(A) + @testset "algorithm $alg" for alg in algs + expA = @constinferred exponential!(copy(A)) + expA2 = @constinferred exponential(A; alg = alg) + @test expA ≈ expA_LA + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eig_full(expA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) + end +end + +@testset "exponentiali! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expiτA = @constinferred exponentiali!(τ, copy(A)) + expiτA2 = @constinferred exponentiali(τ, A; alg) + @test expiτA2 ≈ expiτA + + Dexp, Vexp = @constinferred eig_full(expiτA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4b69a3dc..e17700d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,9 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end + @safetestset "Exponential" begin + include("exponential.jl") + end @safetestset "Mooncake" begin include("mooncake.jl") end @@ -112,18 +115,28 @@ if AMDGPU.functional() end using GenericLinearAlgebra -@safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") -end -@safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") +if !is_buildkite + @safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") + end + @safetestset "Exponential" begin + include("genericlinearalgebra/exponential.jl") + end end using GenericSchur -@safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") +if !is_buildkite + @safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") + end + @safetestset "Exponential" begin + include("genericschur/exponential.jl") + end end