Skip to content

Commit f9f2fdd

Browse files
Merge pull request #855 from ChrisRackauckas-Claude/fix-forwarddiff-inference
Fix type inference for ForwardDiff Dual problems
2 parents d40b7bf + b7e281c commit f9f2fdd

File tree

2 files changed

+80
-14
lines changed

2 files changed

+80
-14
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,22 +265,19 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
265265
return __dual_init(prob, alg, args...; kwargs...)
266266
end
267267

268-
# Opt out for GenericLUFactorization
269-
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::GenericLUFactorization, args...; kwargs...)
270-
return __init(prob, alg, args...; kwargs...)
271-
end
268+
# NOTE: Removed GenericLUFactorization opt-out from init to fix type inference.
269+
# The special handling for GenericLUFactorization is now done in solve! instead.
270+
# This ensures init always returns DualLinearCache for type stability.
272271

273-
# Opt out for SparspakFactorization
272+
# Opt out for SparspakFactorization (sparse solvers can't handle Duals in the same way)
274273
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SparspakFactorization, args...; kwargs...)
275274
return __init(prob, alg, args...; kwargs...)
276275
end
277276

277+
# NOTE: Removed the runtime conditional for DefaultLinearSolver that checked for
278+
# GenericLUFactorization. Now always use __dual_init for type stability.
278279
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::DefaultLinearSolver, args...; kwargs...)
279-
if alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
280-
return __init(prob, alg, args...; kwargs...)
281-
else
282-
return __dual_init(prob, alg, args...; kwargs...)
283-
end
280+
return __dual_init(prob, alg, args...; kwargs...)
284281
end
285282

286283
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::Nothing,
@@ -376,9 +373,25 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
376373
solve!(cache, getfield(cache, :linear_cache).alg, args...; kwargs...)
377374
end
378375

376+
# Check if the algorithm should use the direct dual solve path
377+
# (algorithms that can work directly with Dual numbers without the primal/partials separation)
378+
function _use_direct_dual_solve(alg)
379+
return alg isa GenericLUFactorization
380+
end
381+
382+
function _use_direct_dual_solve(alg::DefaultLinearSolver)
383+
return alg.alg === DefaultAlgorithmChoice.GenericLUFactorization
384+
end
385+
379386
function SciMLBase.solve!(
380387
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <:
381388
ForwardDiff.Dual}
389+
# Check if this algorithm can work directly with Duals (e.g., GenericLUFactorization)
390+
# In that case, we solve the dual problem directly without separating primal/partials
391+
if _use_direct_dual_solve(getfield(cache, :linear_cache).alg)
392+
return _solve_direct_dual!(cache, alg, args...; kwargs...)
393+
end
394+
382395
primal_sol = linearsolve_forwarddiff_solve!(
383396
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
384397

@@ -396,6 +409,39 @@ function SciMLBase.solve!(
396409
)
397410
end
398411

412+
# Direct solve path for algorithms that can work with Dual numbers directly
413+
# This avoids the primal/partials separation overhead
414+
function _solve_direct_dual!(
415+
cache::DualLinearCache{DT}, alg, args...; kwargs...) where {DT <: ForwardDiff.Dual}
416+
# Get the dual A and b
417+
dual_A = getfield(cache, :dual_A)
418+
dual_b = getfield(cache, :dual_b)
419+
420+
# Use __init to create a regular LinearCache (bypasses ForwardDiff extension)
421+
# then solve! on that cache directly with the dual values
422+
dual_prob = LinearProblem(dual_A, dual_b)
423+
dual_cache = __init(dual_prob, getfield(cache, :linear_cache).alg, args...; kwargs...)
424+
dual_sol = SciMLBase.solve!(dual_cache)
425+
426+
# Update the cache
427+
if getfield(cache, :dual_u) isa AbstractArray
428+
getfield(cache, :dual_u) .= dual_sol.u
429+
else
430+
setfield!(cache, :dual_u, dual_sol.u)
431+
end
432+
433+
# Also update the primal cache for consistency
434+
primal_u = nodual_value.(dual_sol.u)
435+
if getfield(cache, :linear_cache).u isa AbstractArray
436+
getfield(cache, :linear_cache).u .= primal_u
437+
end
438+
439+
return SciMLBase.build_linear_solution(
440+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), dual_sol.resid, cache;
441+
dual_sol.retcode, dual_sol.iters, dual_sol.stats
442+
)
443+
end
444+
399445
function setA!(dc::DualLinearCache, A)
400446
# Put the Dual-stripped versions in the LinearCache
401447
prop = nodual_value!(getproperty(dc.linear_cache, :A), A) # Update in-place

test/forwarddiff_overloads.jl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,20 +196,40 @@ overload_x_p = solve!(cache, UMFPACKFactorization())
196196
backslash_x_p = A \ b
197197
@test ≈(overload_x_p, backslash_x_p, rtol = 1e-9)
198198

199-
# Test that GenericLU doesn't create a DualLinearCache
199+
# Test type inference for init with ForwardDiff Dual numbers
200+
# This ensures init returns a concrete type (not a Union) for type stability
200201
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
201202

202203
prob = LinearProblem(A, b)
203-
@test init(prob, GenericLUFactorization()) isa LinearSolve.LinearCache
204204

205-
@test init(prob) isa LinearSolve.LinearCache
205+
# Helper to check if type is DualLinearCache (extension type not directly accessible)
206+
is_dual_cache(x) = nameof(typeof(x)) == :DualLinearCache
206207

207-
# Test that SparspakFactorization doesn't create a DualLinearCache
208+
# GenericLUFactorization now returns DualLinearCache for type stability
209+
# (the optimization for GenericLU happens at solve-time instead of init-time)
210+
@test is_dual_cache(init(prob, GenericLUFactorization()))
211+
212+
# Test inference with explicit algorithm
213+
@test is_dual_cache(@inferred init(prob, LUFactorization()))
214+
@test is_dual_cache(@inferred init(prob, GenericLUFactorization()))
215+
216+
# Test inference with default algorithm (nothing) - this was the main bug
217+
# Previously returned Union{LinearCache, DualLinearCache} due to runtime conditional
218+
@test is_dual_cache(@inferred init(prob, nothing))
219+
220+
# Test that SparspakFactorization still opts out (sparse solvers can't handle Duals the same way)
208221
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
209222

210223
prob = LinearProblem(sparse(A), b)
211224
@test init(prob, SparspakFactorization()) isa LinearSolve.LinearCache
212225

226+
# Test that solve still works correctly with GenericLUFactorization
227+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
228+
prob = LinearProblem(A, b)
229+
sol_generic = solve(prob, GenericLUFactorization())
230+
backslash_result = A \ b
231+
@test ≈(sol_generic.u, backslash_result, rtol = 1e-9)
232+
213233
# Test ComponentArray with ForwardDiff (Issue SciML/DifferentialEquations.jl#1110)
214234
# This tests that ArrayInterface.restructure preserves ComponentArray structure
215235

0 commit comments

Comments
 (0)