Skip to content

Commit 4882e97

Browse files
Use LinearSolve's solve instead of manual backslash
Instead of manually calling `dual_A \ dual_b`, create a LinearProblem with the dual values and use solve() to go through the proper LinearSolve infrastructure. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 25f159c commit 4882e97

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -413,30 +413,30 @@ end
413413
# This avoids the primal/partials separation overhead
414414
function _solve_direct_dual!(
415415
cache::DualLinearCache{DT}, alg, args...; kwargs...) where {DT <: ForwardDiff.Dual}
416-
# Reconstruct the dual A and b
416+
# Get the dual A and b
417417
dual_A = getfield(cache, :dual_A)
418418
dual_b = getfield(cache, :dual_b)
419419

420-
# Solve directly with Duals using the generic LU path
421-
# This works because GenericLUFactorization doesn't use BLAS and can handle any number type
422-
dual_u = dual_A \ dual_b
420+
# Solve by creating a LinearProblem with the dual values and using LinearSolve
421+
dual_prob = LinearProblem(dual_A, dual_b)
422+
dual_sol = solve(dual_prob, getfield(cache, :linear_cache).alg, args...; kwargs...)
423423

424424
# Update the cache
425425
if getfield(cache, :dual_u) isa AbstractArray
426-
getfield(cache, :dual_u) .= dual_u
426+
getfield(cache, :dual_u) .= dual_sol.u
427427
else
428-
setfield!(cache, :dual_u, dual_u)
428+
setfield!(cache, :dual_u, dual_sol.u)
429429
end
430430

431431
# Also update the primal cache for consistency
432-
primal_u = nodual_value.(dual_u)
432+
primal_u = nodual_value.(dual_sol.u)
433433
if getfield(cache, :linear_cache).u isa AbstractArray
434434
getfield(cache, :linear_cache).u .= primal_u
435435
end
436436

437437
return SciMLBase.build_linear_solution(
438-
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), nothing, cache;
439-
retcode = ReturnCode.Success, iters = 1, stats = nothing
438+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), dual_sol.resid, cache;
439+
dual_sol.retcode, dual_sol.iters, dual_sol.stats
440440
)
441441
end
442442

0 commit comments

Comments
 (0)