@@ -265,22 +265,19 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
265265 return __dual_init (prob, alg, args... ; kwargs... )
266266end
267267
268- # Opt out for GenericLUFactorization
269- function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: GenericLUFactorization , args... ; kwargs... )
270- return __init (prob, alg, args... ; kwargs... )
271- end
268+ # NOTE: Removed GenericLUFactorization opt-out from init to fix type inference.
269+ # The special handling for GenericLUFactorization is now done in solve! instead.
270+ # This ensures init always returns DualLinearCache for type stability.
272271
273- # Opt out for SparspakFactorization
272+ # Opt out for SparspakFactorization (sparse solvers can't handle Duals in the same way)
274273function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: SparspakFactorization , args... ; kwargs... )
275274 return __init (prob, alg, args... ; kwargs... )
276275end
277276
277+ # NOTE: Removed the runtime conditional for DefaultLinearSolver that checked for
278+ # GenericLUFactorization. Now always use __dual_init for type stability.
278279function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: DefaultLinearSolver , args... ; kwargs... )
279- if alg. alg === DefaultAlgorithmChoice. GenericLUFactorization
280- return __init (prob, alg, args... ; kwargs... )
281- else
282- return __dual_init (prob, alg, args... ; kwargs... )
283- end
280+ return __dual_init (prob, alg, args... ; kwargs... )
284281end
285282
286283function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: Nothing ,
@@ -376,9 +373,25 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
376373 solve! (cache, getfield (cache, :linear_cache ). alg, args... ; kwargs... )
377374end
378375
376+ # Check if the algorithm should use the direct dual solve path
377+ # (algorithms that can work directly with Dual numbers without the primal/partials separation)
378+ function _use_direct_dual_solve (alg)
379+ return alg isa GenericLUFactorization
380+ end
381+
382+ function _use_direct_dual_solve (alg:: DefaultLinearSolver )
383+ return alg. alg === DefaultAlgorithmChoice. GenericLUFactorization
384+ end
385+
379386function SciMLBase. solve! (
380387 cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT < :
381388 ForwardDiff. Dual}
389+ # Check if this algorithm can work directly with Duals (e.g., GenericLUFactorization)
390+ # In that case, we solve the dual problem directly without separating primal/partials
391+ if _use_direct_dual_solve (getfield (cache, :linear_cache ). alg)
392+ return _solve_direct_dual! (cache, alg, args... ; kwargs... )
393+ end
394+
382395 primal_sol = linearsolve_forwarddiff_solve! (
383396 cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
384397
@@ -396,6 +409,39 @@ function SciMLBase.solve!(
396409 )
397410end
398411
412+ # Direct solve path for algorithms that can work with Dual numbers directly
413+ # This avoids the primal/partials separation overhead
414+ function _solve_direct_dual! (
415+ cache:: DualLinearCache{DT} , alg, args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
416+ # Get the dual A and b
417+ dual_A = getfield (cache, :dual_A )
418+ dual_b = getfield (cache, :dual_b )
419+
420+ # Use __init to create a regular LinearCache (bypasses ForwardDiff extension)
421+ # then solve! on that cache directly with the dual values
422+ dual_prob = LinearProblem (dual_A, dual_b)
423+ dual_cache = __init (dual_prob, getfield (cache, :linear_cache ). alg, args... ; kwargs... )
424+ dual_sol = SciMLBase. solve! (dual_cache)
425+
426+ # Update the cache
427+ if getfield (cache, :dual_u ) isa AbstractArray
428+ getfield (cache, :dual_u ) .= dual_sol. u
429+ else
430+ setfield! (cache, :dual_u , dual_sol. u)
431+ end
432+
433+ # Also update the primal cache for consistency
434+ primal_u = nodual_value .(dual_sol. u)
435+ if getfield (cache, :linear_cache ). u isa AbstractArray
436+ getfield (cache, :linear_cache ). u .= primal_u
437+ end
438+
439+ return SciMLBase. build_linear_solution (
440+ getfield (cache, :linear_cache ). alg, getfield (cache, :dual_u ), dual_sol. resid, cache;
441+ dual_sol. retcode, dual_sol. iters, dual_sol. stats
442+ )
443+ end
444+
399445function setA! (dc:: DualLinearCache , A)
400446 # Put the Dual-stripped versions in the LinearCache
401447 prop = nodual_value! (getproperty (dc. linear_cache, :A ), A) # Update in-place
0 commit comments