Skip to content

Commit 815a521

Browse files
Merge pull request #3739 from AayushSabharwal/as/semilinear-odeprob
feat: add `SemilinearODEFunction` and `SemilinearODEProblem`
2 parents ab2beb1 + 9d176a7 commit 815a521

File tree

14 files changed

+1024
-21
lines changed

14 files changed

+1024
-21
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2626
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2727
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2828
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
29+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2930
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -44,6 +45,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
4445
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4546
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4647
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
48+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4749
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4850
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4951
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -114,6 +116,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
114116
EnumX = "1.0.4"
115117
ExprTools = "0.1.10"
116118
FMI = "0.14"
119+
FillArrays = "1.13.0"
117120
FindFirstFunctions = "1"
118121
ForwardDiff = "0.10.3, 1"
119122
FunctionWrappers = "1.1"
@@ -141,6 +144,7 @@ OrdinaryDiffEq = "6.82.0"
141144
OrdinaryDiffEqCore = "1.34.0"
142145
OrdinaryDiffEqDefault = "1.2"
143146
OrdinaryDiffEqNonlinearSolve = "1.5.0"
147+
PreallocationTools = "0.4.27"
144148
PrecompileTools = "1"
145149
Pyomo = "0.1.0"
146150
REPL = "1"

docs/src/API/codegen.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ ModelingToolkit.build_explicit_observed_function
2323
ModelingToolkit.generate_control_function
2424
ModelingToolkit.generate_update_A
2525
ModelingToolkit.generate_update_b
26+
ModelingToolkit.generate_semiquadratic_functions
27+
ModelingToolkit.generate_semiquadratic_jacobian
28+
ModelingToolkit.get_semiquadratic_W_sparsity
2629
```
2730

2831
For functions such as jacobian calculation which require symbolic computation, there

docs/src/API/problems.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ SciMLBase.ODEFunction
1717
SciMLBase.ODEProblem
1818
SciMLBase.DAEFunction
1919
SciMLBase.DAEProblem
20+
ModelingToolkit.SemilinearODEFunction
21+
ModelingToolkit.SemilinearODEProblem
2022
SciMLBase.SDEFunction
2123
SciMLBase.SDEProblem
2224
SciMLBase.DDEFunction

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ const DQ = DynamicQuantities
104104
import DifferentiationInterface as DI
105105
using ADTypes: AutoForwardDiff
106106
import SciMLPublic: @public
107+
import PreallocationTools
108+
import PreallocationTools: DiffCache
109+
import FillArrays
107110

108111
export @derivatives
109112

@@ -262,6 +265,7 @@ export IntervalNonlinearProblem
262265
export OptimizationProblem, constraints
263266
export SteadyStateProblem
264267
export JumpProblem
268+
export SemilinearODEFunction, SemilinearODEProblem
265269
export alias_elimination, flatten
266270
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
267271
instream

src/problems/docs.jl

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
struct SemilinearODEFunction{iip, spec} end
2+
struct SemilinearODEProblem{iip, spec} end
3+
14
const U0_P_DOCS = """
25
The order of unknowns is determined by `unknowns(sys)`. If the system is split
36
[`is_split`](@ref) create an [`MTKParameters`](@ref) object. Otherwise, a parameter vector.
@@ -92,6 +95,15 @@ function problem_ctors(prob, istd)
9295
end
9396
end
9497

98+
function problem_ctors(prob::Type{<:SemilinearODEProblem}, istd)
99+
@assert istd
100+
"""
101+
SciMLBase.$prob(sys::System, op, tspan::NTuple{2}; kwargs...)
102+
SciMLBase.$prob{iip}(sys::System, op, tspan::NTuple{2}; kwargs...)
103+
SciMLBase.$prob{iip, specialize}(sys::System, op, tspan::NTuple{2}; stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false, kwargs...)
104+
"""
105+
end
106+
95107
function prob_fun_common_kwargs(T, istd)
96108
return """
97109
- `check_compatibility`: Whether to check if the given system `sys` contains all the
@@ -103,7 +115,8 @@ function prob_fun_common_kwargs(T, istd)
103115
"""
104116
end
105117

106-
function problem_docstring(prob, func, istd; init = true, extra_body = "")
118+
function problem_docstring(prob, func, istd; init = true, extra_body = "",
119+
extra_kwargs = "", extra_kwargs_desc = "")
107120
if func isa DataType
108121
func = "`$func`"
109122
end
@@ -127,8 +140,9 @@ function problem_docstring(prob, func, istd; init = true, extra_body = "")
127140
$PROBLEM_KWARGS
128141
$(istd ? TIME_DEPENDENT_PROBLEM_KWARGS : "")
129142
$(prob_fun_common_kwargs(prob, istd))
130-
143+
$(extra_kwargs)
131144
All other keyword arguments are forwarded to the $func constructor.
145+
$(extra_kwargs_desc)
132146
133147
$PROBLEM_INTERNALS_HEADER
134148
@@ -186,6 +200,32 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
186200
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
187201
"""
188202

203+
const SEMILINEAR_EXTRA_BODY = """
204+
This is a special form of an ODE which uses a `SplitFunction` internally. The equations are
205+
separated into linear, quadratic and general terms and phrased as matrix operations. See
206+
[`calculate_semiquadratic_form`](@ref) for information on how the equations are split. This
207+
formulation allows leveraging split ODE solvers such as `KenCarp4` and is useful for systems
208+
where the stiff and non-stiff terms can be separated out in such a manner. Typically the linear
209+
part of the equations is the stiff part, but the keywords `stiff_linear`, `stiff_quadratic` and `stiff_nonlinear` can
210+
be used to control which parts are considered as stiff.
211+
"""
212+
213+
const SEMILINEAR_A_B_C_KWARGS = """
214+
- `stiff_linear`: Whether the linear part of the equations should be part of the stiff function
215+
in the split form. Has no effect if the equations have no linear part.
216+
- `stiff_quadratic`: Whether the quadratic part of the equations should be part of the stiff
217+
function in the split form. Has no effect if the equations have no quadratic part.
218+
- `stiff_nonlinear`: Whether the non-linear non-quadratic part of the equations should be part of
219+
the stiff function in the split form. Has no effect if the equations have no such
220+
non-linear non-quadratic part.
221+
"""
222+
223+
const SEMILINEAR_A_B_C_CONSTRAINT = """
224+
Note that all three of `stiff_linear`, `stiff_quadratic`, `stiff_nonlinear` cannot be identical, and at least
225+
two of `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref) must be
226+
non-`nothing`. In other words, both of the functions in the split form must be non-empty.
227+
"""
228+
189229
for (mod, prob, func, istd, kws) in [
190230
(SciMLBase, :ODEProblem, ODEFunction, true, (;)),
191231
(SciMLBase, :SteadyStateProblem, ODEFunction, false, (;)),
@@ -201,12 +241,23 @@ for (mod, prob, func, istd, kws) in [
201241
(SciMLBase, :NonlinearProblem, NonlinearFunction, false, (;)),
202242
(SciMLBase, :NonlinearLeastSquaresProblem, NonlinearFunction, false, (;)),
203243
(SciMLBase, :SCCNonlinearProblem, NonlinearFunction, false, (; init = false)),
204-
(SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false))
244+
(SciMLBase, :OptimizationProblem, OptimizationFunction, false, (; init = false)),
245+
(ModelingToolkit,
246+
:SemilinearODEProblem,
247+
:SemilinearODEFunction,
248+
true,
249+
(; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS,
250+
extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT))
205251
]
206-
@eval @doc problem_docstring($mod.$prob, $func, $istd) $mod.$prob
252+
kwexpr = Expr(:parameters)
253+
for (k, v) in pairs(kws)
254+
push!(kwexpr.args, Expr(:kw, k, v))
255+
end
256+
@eval @doc problem_docstring($kwexpr, $mod.$prob, $func, $istd) $mod.$prob
207257
end
208258

209-
function function_docstring(func, istd, optionals)
259+
function function_docstring(
260+
func, istd, optionals; extra_body = "", extra_kwargs = "", extra_kwargs_desc = "")
210261
return """
211262
$func(sys::System; kwargs...)
212263
$func{iip}(sys::System; kwargs...)
@@ -216,6 +267,8 @@ function function_docstring(func, istd, optionals)
216267
function should be in-place. `specialization` is a `SciMLBase.AbstractSpecalize`
217268
subtype indicating the level of specialization of the $func.
218269
270+
$(extra_body)
271+
219272
Beyond the arguments listed below, this constructor accepts all keyword arguments
220273
supported by the DifferentialEquations.jl `solve` function. For a complete list
221274
and detailed descriptions, see the [DifferentialEquations.jl solve documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
@@ -236,9 +289,11 @@ function function_docstring(func, istd, optionals)
236289
sparse matrices. Also controls whether the mass matrix is sparse, wherever applicable.
237290
$(prob_fun_common_kwargs(func, istd))
238291
$(process_optional_function_kwargs(optionals))
292+
$(extra_kwargs)
239293
- `kwargs...`: Additional keyword arguments passed to the solver
240294
241295
All other keyword arguments are forwarded to the `$func` struct constructor.
296+
$(extra_kwargs_desc)
242297
"""
243298
end
244299

@@ -329,20 +384,30 @@ function process_optional_function_kwargs(choices::Vector{Symbol})
329384
join(map(Base.Fix1(getindex, OPTIONAL_FN_KWARGS_DICT), choices), "\n")
330385
end
331386

332-
for (mod, func, istd, optionals) in [
333-
(SciMLBase, :ODEFunction, true, [:jac, :tgrad]),
334-
(SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac]),
335-
(SciMLBase, :DAEFunction, true, [:jac, :tgrad]),
336-
(SciMLBase, :DDEFunction, true, Symbol[]),
337-
(SciMLBase, :SDEFunction, true, [:jac, :tgrad]),
338-
(SciMLBase, :SDDEFunction, true, Symbol[]),
339-
(SciMLBase, :DiscreteFunction, true, Symbol[]),
340-
(SciMLBase, :ImplicitDiscreteFunction, true, Symbol[]),
341-
(SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac]),
342-
(SciMLBase, :IntervalNonlinearFunction, false, Symbol[]),
343-
(SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j])
387+
for (mod, func, istd, optionals, kws) in [
388+
(SciMLBase, :ODEFunction, true, [:jac, :tgrad], (;)),
389+
(SciMLBase, :ODEInputFunction, true, [:inputfn, :jac, :tgrad, :controljac], (;)),
390+
(SciMLBase, :DAEFunction, true, [:jac, :tgrad], (;)),
391+
(SciMLBase, :DDEFunction, true, Symbol[], (;)),
392+
(SciMLBase, :SDEFunction, true, [:jac, :tgrad], (;)),
393+
(SciMLBase, :SDDEFunction, true, Symbol[], (;)),
394+
(SciMLBase, :DiscreteFunction, true, Symbol[], (;)),
395+
(SciMLBase, :ImplicitDiscreteFunction, true, Symbol[], (;)),
396+
(SciMLBase, :NonlinearFunction, false, [:resid_prototype, :jac], (;)),
397+
(SciMLBase, :IntervalNonlinearFunction, false, Symbol[], (;)),
398+
(SciMLBase, :OptimizationFunction, false, [:jac, :grad, :hess, :cons_h, :cons_j], (;)),
399+
(ModelingToolkit,
400+
:SemilinearODEFunction,
401+
true,
402+
[:jac],
403+
(; extra_body = SEMILINEAR_EXTRA_BODY, extra_kwargs = SEMILINEAR_A_B_C_KWARGS,
404+
extra_kwargs_desc = SEMILINEAR_A_B_C_CONSTRAINT))
344405
]
345-
@eval @doc function_docstring($mod.$func, $istd, $optionals) $mod.$func
406+
kwexpr = Expr(:parameters)
407+
for (k, v) in pairs(kws)
408+
push!(kwexpr.args, Expr(:kw, k, v))
409+
end
410+
@eval @doc function_docstring($kwexpr, $mod.$func, $istd, $optionals) $mod.$func
346411
end
347412

348413
@doc """

0 commit comments

Comments
 (0)