diff --git a/Project.toml b/Project.toml index 0e9a26cc..37e062dc 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -18,6 +19,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" +MatrixAlgebraKitEnzymeExt = "Enzyme" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" MatrixAlgebraKitMooncakeExt = "Mooncake" @@ -30,6 +32,8 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" +Enzyme = "0.13.109" +EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.4.183" @@ -46,6 +50,8 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -56,4 +62,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] +test = ["Aqua", "JET", "SafeTestsets", 
"Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl new file mode 100644 index 00000000..755a80bd --- /dev/null +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -0,0 +1,371 @@ +module MatrixAlgebraKitEnzymeExt + +using MatrixAlgebraKit +using MatrixAlgebraKit: copy_input +using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!, truncate +using MatrixAlgebraKit: qr_pullback!, lq_pullback! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback! +using MatrixAlgebraKit: svd_pullback! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules +using LinearAlgebra + +@inline EnzymeRules.inactive_type(v::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = true + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(copy_input)}, + ::Type{RT}, + f::Annotation, + A::Annotation + ) where {RT} + A_copy = func.val(f.val, A.val) # the rule's primal is exactly what copy_input returned; no second copy of A + primal = EnzymeRules.needs_primal(config) ? A_copy : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
zero(A.dval) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, shadow) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(copy_input)}, + dret::Type{RT}, + cache, + f::Annotation, + A::Annotation + ) where {RT} + copy_shadow = cache + if !isa(A, Const) && !isnothing(copy_shadow) + A.dval .+= copy_shadow + end + return (nothing, nothing) +end + +# two-argument factorizations like LQ, QR, EIG +for (f, pb) in ( + (qr_full!, qr_pullback!), + (lq_full!, lq_pullback!), + (qr_compact!, qr_pullback!), + (lq_compact!, lq_pullback!), + (eig_full!, eig_pullback!), + (eigh_full!, eigh_pullback!), + (left_polar!, left_polar_pullback!), + (right_polar!, right_polar_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{Tuple{TA, TB}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA, TB} + cache_arg = nothing + # form cache if needed + cache_A = !(typeof(arg) <: Const) ? copy(A.val) : nothing + func.val(A.val, arg.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + arg::Annotation{Tuple{TA, TB}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA, TB} + cache_A, cache_arg = cache + argval = arg.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂arg = isa(arg, Const) ? 
nothing : arg.dval + if !isa(A, Const) && !isa(arg, Const) + $pb(A.dval, Aval, argval, ∂arg) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +for (f, pb) in ( + (qr_null!, qr_null_pullback!), + (lq_null!, lq_null_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A = copy(A.val) + func.val(A.val, arg.val, alg.val) + primal = EnzymeRules.needs_primal(config) ? arg.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, cache_A) + end + + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A = cache + Aval = isnothing(cache_A) ? A.val : cache_A + if !isa(A, Const) && !isa(arg, Const) + $pb(A.dval, Aval, arg.val, arg.dval) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +for f in (:svd_compact!, :svd_full!) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + # form caches if needed: A before the call (it may be overwritten), the factors only after they are computed + cache_A = !(typeof(A) <: Const) ? copy(A.val) : nothing + func.val(A.val, USVᴴ.val, alg.val) + cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy.(USVᴴ.val) : nothing + primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
USVᴴ.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + dret::Type{RT}, + cache, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_USVᴴ = cache + Aval = isnothing(cache_A) ? A.val : cache_A + USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val + U, S, Vᴴ = USVᴴval + ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval + if !isa(A, Const) && !isa(USVᴴ, Const) + minmn = min(size(A.val)...) + if $(f === :svd_compact!) # compact; `f` is a Symbol in this loop, so compare against the Symbol + svd_pullback!(A.dval, Aval, USVᴴval, ∂USVᴴ) + else # full + vU = view(U, :, 1:minmn) + vS = Diagonal(diagview(S)[1:minmn]) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(∂USVᴴ[1], :, 1:minmn) + vdS = Diagonal(diagview(∂USVᴴ[2])[1:minmn]) + vdVᴴ = view(∂USVᴴ[3], 1:minmn, :) + svd_pullback!(A.dval, Aval, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + end + !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error!)}, + ::Type{RT}, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + # form cache if needed + cache_A = copy(A.val) + svd_compact!(A.val, USVᴴ.val, alg.val.alg) + cache_USVᴴ = copy.(USVᴴ.val) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc) + primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing + shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const) + dU, dS, dVᴴ = USVᴴ.dval + # This creates new output shadow matrices, we do this slicing + # to ensure they have the correct eltype and dimensions. + # These new shadow matrices are "filled in" with the accumulated + # results from earlier in reverse-mode AD after this function exits + # and before `reverse` is called. 
+ dStrunc = Diagonal(diagview(dS)[ind]) + dUtrunc = dU[:, ind] + dVᴴtrunc = dVᴴ[ind, :] + (dUtrunc, dStrunc, dVᴴtrunc) + else + (nothing, nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? shadow_USVᴴ : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_trunc_no_error!)}, + dret::Type{RT}, + cache, + A::Annotation, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache + U, S, Vᴴ = cache_USVᴴ + dU, dS, dVᴴ = shadow_USVᴴ + Aval = isnothing(cache_A) ? A.val : cache_A + if !isa(A, Const) && !isa(USVᴴ, Const) + svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind) + end + !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) + return (nothing, nothing, nothing) +end + +for (f, trunc_f, full_f, pb) in ( + (:eigh_trunc_no_error!, :eigh_trunc!, :eigh_full!, :eigh_pullback!), + (:eig_trunc_no_error!, :eig_trunc!, :eig_full!, :eig_pullback!), + ) + @eval function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + A::Annotation, + DV::Annotation{Tuple{TD, TV}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TD, TV} + # form cache if needed + cache_A = copy(A.val) + $full_f(A.val, DV.val, alg.val.alg) + cache_DV = copy.(DV.val) + DV′, ind = truncate($trunc_f, DV.val, alg.val.trunc) + primal = EnzymeRules.needs_primal(config) ? DV′ : nothing + shadow_DV = if !isa(A, Const) && !isa(DV, Const) + dD, dV = DV.dval + dDtrunc = Diagonal(diagview(dD)[ind]) + dVtrunc = dV[:, ind] + (dDtrunc, dVtrunc) + else + (nothing, nothing) + end + shadow = EnzymeRules.needs_shadow(config) ? 
shadow_DV : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind)) + end + @eval function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f)}, + ::Type{RT}, + cache, + A::Annotation, + DV::Annotation{Tuple{TD, TV}}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TD, TV} + cache_A, cache_DV, cache_dDVtrunc, ind = cache + Aval = cache_A + D, V = cache_DV + dD, dV = cache_dDVtrunc + if !isa(A, Const) && !isa(DV, Const) + $pb(A.dval, Aval, (D, V), (dD, dV), ind) + end + !isa(DV, Const) && make_zero!(DV.dval) + return (nothing, nothing, nothing) + end +end + +for (f!, f_full!, pb!) in ( + (eig_vals!, eig_full!, eig_pullback!), + (eigh_vals!, eigh_full!, eigh_pullback!), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation, + D::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, # same annotation as the reverse rule below + ) where {RT} + cache_A = nothing + cache_D = nothing + nD, V = MatrixAlgebraKit.initialize_output($f_full!, A.val, alg.val) + nD, V = $f_full!(A.val, (nD, V), alg.val) + copy!(D.val, diagview(nD)) + primal = EnzymeRules.needs_primal(config) ? D.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D, V)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation, + D::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + + cache_A, cache_D, V = cache + Dval = !isnothing(cache_D) ? cache_D : D.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂D = isa(D, Const) ? 
nothing : D.dval + if !isa(A, Const) && !isa(D, Const) + $pb!(A.dval, Aval, (Diagonal(Dval), V), (Diagonal(∂D), nothing)) + end + !isa(D, Const) && make_zero!(D.dval) + return (nothing, nothing, nothing) + end + end +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + A::Annotation, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, # same annotation as the reverse rule below + ) where {RT} + cache_S = nothing + cache_A = copy(A.val) + U, nS, Vᴴ = svd_compact!(A.val, alg.val) + copy!(S.val, diagview(nS)) + primal = EnzymeRules.needs_primal(config) ? S.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? S.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_S, U, Vᴴ)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + cache, + A::Annotation, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + + cache_A, cache_S, U, Vᴴ = cache + Sval = !isnothing(cache_S) ? cache_S : S.val + Aval = !isnothing(cache_A) ? cache_A : A.val + ∂S = isa(S, Const) ? nothing : S.dval + if !isa(A, Const) && !isa(S, Const) + svd_pullback!(A.dval, Aval, (U, Diagonal(Sval), Vᴴ), (nothing, Diagonal(∂S), nothing)) + end + !isa(S, Const) && make_zero!(S.dval) + return (nothing, nothing, nothing) +end + +end diff --git a/src/common/safemethods.jl b/src/common/safemethods.jl index 62f23a4e..43b06513 100644 --- a/src/common/safemethods.jl +++ b/src/common/safemethods.jl @@ -13,8 +13,9 @@ sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s)) # Inverse """ - function inv_safe(a::Number, tol = defaulttol(a)) + inv_safe(a::Number, tol = defaulttol(a)) Compute the inverse of a number `a`, but return zero if `a` is smaller than `tol`. """ inv_safe(a::Number, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a) +@noinline inv_safe(a::ComplexF32, tol = defaulttol(a)) = abs(a) < tol ? 
zero(a) : inv(a) diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 00000000..9f9072dd --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,496 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using ChainRulesCore +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +ETs = (Float32, ComplexF64) +include("ad_utils.jl") +function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) + ΔA = randn(rng, eltype(A), size(A)...) + A_ΔA() = Duplicated(copy(A), copy(ΔA)) + args_Δargs() = Duplicated(copy.(args), copy.(Δargs)) + copy_activities = isnothing(alg) ? (Const(f), A_ΔA()) : (Const(f), A_ΔA(), Const(alg)) + inplace_activities = isnothing(alg) ? (Const(f!), A_ΔA(), args_Δargs()) : (Const(f!), A_ΔA(), args_Δargs(), Const(alg)) + + mode = EnzymeTestUtils.set_runtime_activity(ReverseSplitWithPrimal, false) + c_act = Const(EnzymeTestUtils.call_with_kwargs) + forward_copy, reverse_copy = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, copy_activities)... + ) + forward_inplace, reverse_inplace = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, inplace_activities)... + ) + copy_tape, copy_y_ad, copy_shadow_result = forward_copy(c_act, Const(()), copy_activities...) + inplace_tape, inplace_y_ad, inplace_shadow_result = forward_inplace(c_act, Const(()), inplace_activities...) 
+ if !(copy_shadow_result === nothing) + EnzymeTestUtils.map_fields_recursive(copyto!, copy_shadow_result, copy.(ȳ)) + end + if !(inplace_shadow_result === nothing) + EnzymeTestUtils.map_fields_recursive(copyto!, inplace_shadow_result, copy.(ȳ)) + end + dx_copy_ad = only(reverse_copy(c_act, Const(()), copy_activities..., copy_tape)) + dx_inplace_ad = only(reverse_inplace(c_act, Const(()), inplace_activities..., inplace_tape)) + # check all returned derivatives between copy & inplace + for (i, (copy_act_i, inplace_act_i)) in enumerate(zip(copy_activities[2:end], inplace_activities[2:end])) + if copy_act_i isa Duplicated && inplace_act_i isa Duplicated + msg_deriv = "shadow derivative for argument $(i - 1) should match between copy and inplace" + EnzymeTestUtils.test_approx(copy_act_i.dval, inplace_act_i.dval, msg_deriv) + end + end + return +end + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + A = randn(rng, T, m, n) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + @testset for alg in ( + LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive = true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "qr_compact" begin + ΔQ = randn(rng, T, m, minmn) + ΔR = randn(rng, T, minmn, n) + Q, R = qr_compact(A, alg) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(qr_compact, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔQ, ΔR), fdm = fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (ΔQ, ΔR), alg) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + N = zeros(T, m, max(0, m - minmn)) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + test_reverse(qr_null, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔN) + test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(qr_full, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔQ, ΔR), fdm = fdm) + test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) + end + @testset "qr_compact - rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + Q, R = qr_compact(Ard, alg) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn(rng, T, minmn, n) + view(ΔR, (r + 1):minmn, :) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(qr_compact, RT, (Ard, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔQ, ΔR), fdm = fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) + end + end + end + end +end + +@timedtestset "LQ AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + minmn = min(m, n) + A = randn(rng, T, m, n) + @testset for alg in ( + LAPACK_HouseholderLQ(), + LAPACK_HouseholderLQ(; positive = true), + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "lq_compact" begin + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + L, Q = lq_compact(A, alg) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg) + end + @testset "lq_null" begin + L, Q = lq_compact(A, alg) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + Nᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + test_reverse(lq_null, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔNᴴ) + test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) + end + @testset "lq_full" begin + L, Q = lq_full(A, alg) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + ΔL = randn(rng, T, m, n) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_full, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) + end + @testset "lq_compact -- rank-deficient A" begin + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + L, Q = lq_compact(Ard, alg) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + test_reverse(lq_compact, RT, (Ard, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔL, ΔQ), fdm = fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) + end + end + end + end +end + +@timedtestset "EIG AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = make_eig_matrix(rng, T, m) + D, V = eig_full(A) + Ddiag = diagview(D) + ΔV = randn(rng, complex(T), m, m) + ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔD = randn(rng, complex(T), m, m) + ΔD2 = Diagonal(randn(rng, complex(T), m)) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset for alg in ( + LAPACK_Simple(), + #LAPACK_Expert(), # expensive on CI + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) + test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + end + @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(Ddiag[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = 
fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function copy_eigh_full(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_full(A, alg; kwargs...) +end + +function copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function copy_eigh_full!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV, alg; kwargs...) +end + +function copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function copy_eigh_vals(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals(A, alg; kwargs...) +end + +function copy_eigh_vals!(A, D, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D, alg; kwargs...) +end + +function copy_eigh_trunc_no_error(A; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A; kwargs...) +end + +function copy_eigh_trunc_no_error!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV; kwargs...) +end + +function copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) 
+end + +@timedtestset "EIGH AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + atol = rtol = m * m * precision(T) + A = make_eigh_matrix(rng, T, m) + #A = (A + A') / 2 + D, V = eigh_full(A) + D2 = Diagonal(D) + ΔV = randn(rng, T, m, m) + ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔD = randn(rng, real(T), m, m) + ΔD2 = Diagonal(randn(rng, real(T), m)) + fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset for alg in ( + LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), + #LAPACK_Bisection(), + #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + test_reverse(copy_eigh_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) + test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) + end + @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) + 
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + end + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) + ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) + Dtrunc = Diagonal(diagview(D)[ind]) + Vtrunc = V[:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) + ΔVtrunc = ΔV[:, ind] + test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + end + end +end + +@timedtestset "SVD AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset for alg in ( + LAPACK_QRIteration(), + #LAPACK_DivideAndConquer(), # expensive on CI + ) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset "svd_compact" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + test_reverse(svd_compact, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔU, ΔS, ΔVᴴ), fdm = fdm) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) + end + @testset "svd_full" begin + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔUfull = zeros(T, m, m) + ΔSfull = zeros(real(T), m, n) + ΔVᴴfull = zeros(T, n, n) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS) + test_reverse(svd_full, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = (ΔUfull, ΔSfull, ΔVᴴfull), fdm = fdm) + test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) + end + @testset "svd_vals" begin + S = svd_vals(A) + ΔS = randn(rng, real(T), minmn) + test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm) + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) + end + end + @testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:minmn + U, S, Vᴴ = svd_compact(A) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = 
Diagonal(randn(rng, real(T), minmn))
+                    ΔVᴴ = randn(rng, T, minmn, n)
+                    ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                    truncalg = TruncatedAlgorithm(alg, truncrank(r))
+                    ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+                    Strunc = Diagonal(diagview(S)[ind])
+                    Utrunc = U[:, ind]
+                    Vᴴtrunc = Vᴴ[ind, :]
+                    ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+                    ΔUtrunc = ΔU[:, ind]
+                    ΔVᴴtrunc = ΔVᴴ[ind, :]
+                    test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
+                    test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
+                end
+                U, S, Vᴴ = svd_compact(A)
+                ΔU = randn(rng, T, m, minmn)
+                ΔS = randn(rng, real(T), minmn, minmn)
+                ΔS2 = Diagonal(randn(rng, real(T), minmn))
+                ΔVᴴ = randn(rng, T, minmn, n)
+                ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
+                ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+                Strunc = Diagonal(diagview(S)[ind])
+                Utrunc = U[:, ind]
+                Vᴴtrunc = Vᴴ[ind, :]
+                ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+                ΔUtrunc = ΔU[:, ind]
+                ΔVᴴtrunc = ΔVᴴ[ind, :]
+                test_reverse(svd_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
+                test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
+            end
+        end
+    end
+end
+
+@timedtestset "Polar AD Rules with eltype $T" for T in ETs  # reverse-mode checks for left_polar / right_polar
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)  # n < m, n == m and n > m
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        @testset for alg in PolarViaSVD.(
+                (
+                    LAPACK_QRIteration(),
+                    #LAPACK_DivideAndConquer(), # expensive on CI
+                )
+            )
+            @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
+                # Two independent `if`s (not `if`/`elseif`): for m == n both decompositions must be tested.
+                if m >= n
+                    W, P = left_polar(A; alg = alg)
+                    ΔW = randn(rng, T, size(W)...)
+                    ΔP = randn(rng, T, size(P)...)
+                    test_reverse(left_polar, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,))
+                    test_pullbacks_match(rng, left_polar!, left_polar, A, (W, P), (ΔW, ΔP), alg)
+                end
+                if m <= n
+                    P, Wᴴ = right_polar(A; alg = alg)
+                    ΔWᴴ = randn(rng, T, size(Wᴴ)...)
+                    ΔP = randn(rng, T, size(P)...)
+                    test_reverse(right_polar, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,))
+                    test_pullbacks_match(rng, right_polar!, right_polar, A, (P, Wᴴ), (ΔP, ΔWᴴ), alg)
+                end
+            end
+        end
+    end
+end
+
+@timedtestset "Orth and null with eltype $T" for T in ETs  # reverse-mode checks for left/right orth and null
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)  # n < m, n == m and n > m
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
+            @testset "left_orth" begin
+                @testset for alg in (:polar, :qr)
+                    n > m && alg == :polar && continue  # polar variant is not run on wide inputs (left_polar is only tested for m >= n)
+                    VC = left_orth(A; alg = alg)
+                    V, C = VC
+                    ΔV = randn(rng, T, size(V)...)
+                    ΔC = randn(rng, T, size(C)...)
+                    fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)  # single precision gets max_range = 1.0e-2
+                    test_reverse(left_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm)
+                    left_orth_alg!(A, VC) = left_orth!(A, VC; alg = alg)  # bind `alg` so the helper below sees positional arguments only
+                    left_orth_alg(A) = left_orth(A; alg = alg)
+                    test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, (V, C), (ΔV, ΔC))
+                end
+            end
+            @testset "right_orth" begin
+                @testset for alg in (:polar, :lq)
+                    n < m && alg == :polar && continue  # polar variant is not run on tall inputs (right_polar is only tested for m <= n)
+                    CVᴴ = right_orth(A; alg = alg)
+                    C, Vᴴ = CVᴴ
+                    ΔC = randn(rng, T, size(C)...)
+                    ΔVᴴ = randn(rng, T, size(Vᴴ)...)
+                    fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)  # single precision gets max_range = 1.0e-2
+                    test_reverse(right_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm)
+                    right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg)  # bind `alg` so the helper below sees positional arguments only
+                    right_orth_alg(A) = right_orth(A; alg = alg)
+                    test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, (C, Vᴴ), (ΔC, ΔVᴴ))
+                end
+            end
+            @testset "left_null" begin
+                ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))  # cotangent = (first left_orth factor) * random block
+                N = similar(ΔN)  # uninitialised buffer with the size and eltype of ΔN
+                test_reverse(left_null, RT, (A, TA); fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol)
+                left_null_qr!(A, N) = left_null!(A, N; alg = :qr)
+                left_null_qr(A) = left_null(A; alg = :qr)
+                test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN)
+            end
+            @testset "right_null" begin
+                ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]  # cotangent = random block * (second right_orth factor)
+                Nᴴ = similar(ΔNᴴ)  # uninitialised buffer with the size and eltype of ΔNᴴ
+                test_reverse(right_null, RT, (A, TA); fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, atol = atol, rtol = rtol)
+                right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq)
+                right_null_lq(A) = right_null(A; alg = :lq)
+                test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ)
+            end
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 37796e2a..66fe5889 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -34,6 +34,11 @@ if !is_buildkite
     @safetestset "Image and Null Space" begin
         include("orthnull.jl")
     end
+    if VERSION < v"1.12.0" # reconsider when Enzyme works on 1.12
+        @safetestset "Enzyme" begin
+            include("enzyme.jl")
+        end
+    end
     @safetestset "Mooncake" begin
         include("mooncake.jl")
     end