Skip to content

Commit b7e281c

Browse files
Fix infinite recursion in _solve_direct_dual! and update tests
- Use __init instead of solve to avoid creating another DualLinearCache which would recursively call _solve_direct_dual! again - Update tests to use nameof check since extension types aren't directly accessible 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent bdec814 commit b7e281c

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,11 @@ function _solve_direct_dual!(
417417
dual_A = getfield(cache, :dual_A)
418418
dual_b = getfield(cache, :dual_b)
419419

420-
# Solve by creating a LinearProblem with the dual values and using LinearSolve
420+
# Use __init to create a regular LinearCache (bypasses ForwardDiff extension)
421+
# then solve! on that cache directly with the dual values
421422
dual_prob = LinearProblem(dual_A, dual_b)
422-
dual_sol = solve(dual_prob, getfield(cache, :linear_cache).alg, args...; kwargs...)
423+
dual_cache = __init(dual_prob, getfield(cache, :linear_cache).alg, args...; kwargs...)
424+
dual_sol = SciMLBase.solve!(dual_cache)
423425

424426
# Update the cache
425427
if getfield(cache, :dual_u) isa AbstractArray

test/forwarddiff_overloads.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,20 @@ A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
202202

203203
prob = LinearProblem(A, b)
204204

205+
# Helper to check if type is DualLinearCache (extension type not directly accessible)
206+
is_dual_cache(x) = nameof(typeof(x)) == :DualLinearCache
207+
205208
# GenericLUFactorization now returns DualLinearCache for type stability
206209
# (the optimization for GenericLU happens at solve-time instead of init-time)
207-
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearSolveForwardDiffExt.DualLinearCache
210+
@test is_dual_cache(init(prob, GenericLUFactorization()))
208211

209212
# Test inference with explicit algorithm
210-
@test (@inferred init(prob, LUFactorization())) isa LinearSolve.LinearSolveForwardDiffExt.DualLinearCache
211-
@test (@inferred init(prob, GenericLUFactorization())) isa LinearSolve.LinearSolveForwardDiffExt.DualLinearCache
213+
@test is_dual_cache(@inferred init(prob, LUFactorization()))
214+
@test is_dual_cache(@inferred init(prob, GenericLUFactorization()))
212215

213216
# Test inference with default algorithm (nothing) - this was the main bug
214217
# Previously returned Union{LinearCache, DualLinearCache} due to runtime conditional
215-
@test (@inferred init(prob, nothing)) isa LinearSolve.LinearSolveForwardDiffExt.DualLinearCache
218+
@test is_dual_cache(@inferred init(prob, nothing))
216219

217220
# Test that SparspakFactorization still opts out (sparse solvers can't handle Duals the same way)
218221
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

0 commit comments

Comments
 (0)