Skip to content

Commit 25f159c

Browse files
Fix type inference for ForwardDiff Dual problems
The `init` function for `DualAbstractLinearProblem` with `DefaultLinearSolver` previously had a runtime conditional that checked whether the auto-selected algorithm was `GenericLUFactorization`. This caused the return type to be `Union{LinearCache, DualLinearCache}`, breaking type inference. This fix: 1. Removes the runtime conditional from `init`, ensuring it always returns `DualLinearCache` for type stability 2. Moves the `GenericLUFactorization` optimization to `solve!` instead, where it checks at solve-time and uses direct dual solving if applicable 3. Adds `_use_direct_dual_solve` helper functions to detect when direct dual solving should be used 4. Adds `_solve_direct_dual!` function that solves the dual system directly without separating primal/partials This preserves the performance optimization for `GenericLUFactorization` (which can work directly with any number type including Duals) while ensuring `init` has a concrete return type. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent d40b7bf commit 25f159c

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,22 +265,19 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
265265
return __dual_init(prob, alg, args...; kwargs...)
266266
end
267267

268-
# Opt out for GenericLUFactorization
269-
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactorization, args...; kwargs...)
270-
return __init(prob, alg, args...; kwargs...)
271-
end
268+
# NOTE: Removed GenericLUFactorization opt-out from init to fix type inference.
269+
# The special handling for GenericLUFactorization is now done in solve! instead.
270+
# This ensures init always returns DualLinearCache for type stability.
272271

273-
# Opt out for SparspakFactorization
272+
# Opt out for SparspakFactorization (sparse solvers can't handle Duals in the same way)
274273
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SparspakFactorization, args...; kwargs...)
275274
return __init(prob, alg, args...; kwargs...)
276275
end
277276

277+
# NOTE: Removed the runtime conditional for DefaultLinearSolver that checked for
278+
# GenericLUFactorization. Now always use __dual_init for type stability.
278279
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
279-
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
280-
return __init(prob, alg, args...; kwargs...)
281-
else
282-
return __dual_init(prob, alg, args...; kwargs...)
283-
end
280+
return __dual_init(prob, alg, args...; kwargs...)
284281
end
285282

286283
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::Nothing,
@@ -376,9 +373,25 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
376373
solve!(cache, getfield(cache, :linear_cache).alg, args...; kwargs...)
377374
end
378375

376+
# Check if the algorithm should use the direct dual solve path
377+
# (algorithms that can work directly with Dual numbers without the primal/partials separation)
378+
function _use_direct_dual_solve(alg)
379+
return alg isa GenericLUFactorization
380+
end
381+
382+
function _use_direct_dual_solve(alg::DefaultLinearSolver)
383+
return alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
384+
end
385+
379386
function SciMLBase.solve!(
380387
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <:
381388
ForwardDiff.Dual}
389+
# Check if this algorithm can work directly with Duals (e.g., GenericLUFactorization)
390+
# In that case, we solve the dual problem directly without separating primal/partials
391+
if _use_direct_dual_solve(getfield(cache, :linear_cache).alg)
392+
return _solve_direct_dual!(cache, alg, args...; kwargs...)
393+
end
394+
382395
primal_sol = linearsolve_forwarddiff_solve!(
383396
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
384397

@@ -396,6 +409,37 @@ function SciMLBase.solve!(
396409
)
397410
end
398411

412+
# Direct solve path for algorithms that can work with Dual numbers directly
413+
# This avoids the primal/partials separation overhead
414+
function _solve_direct_dual!(
415+
cache::DualLinearCache{DT}, alg, args...; kwargs...) where {DT <: ForwardDiff.Dual}
416+
# Reconstruct the dual A and b
417+
dual_A = getfield(cache, :dual_A)
418+
dual_b = getfield(cache, :dual_b)
419+
420+
# Solve directly with Duals using the generic LU path
421+
# This works because GenericLUFactorization doesn't use BLAS and can handle any number type
422+
dual_u = dual_A \ dual_b
423+
424+
# Update the cache
425+
if getfield(cache, :dual_u) isa AbstractArray
426+
getfield(cache, :dual_u) .= dual_u
427+
else
428+
setfield!(cache, :dual_u, dual_u)
429+
end
430+
431+
# Also update the primal cache for consistency
432+
primal_u = nodual_value.(dual_u)
433+
if getfield(cache, :linear_cache).u isa AbstractArray
434+
getfield(cache, :linear_cache).u .= primal_u
435+
end
436+
437+
return SciMLBase.build_linear_solution(
438+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), nothing, cache;
439+
retcode = ReturnCode.Success, iters = 1, stats = nothing
440+
)
441+
end
442+
399443
function setA!(dc::DualLinearCache, A)
400444
# Put the Dual-stripped versions in the LinearCache
401445
prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place

0 commit comments

Comments
 (0)