Skip to content

Commit 3dccd52

Browse files
feat: auto-add state variables to output variables (#4700)
* feat: auto-add state variables to output variables * add changelog * fix lead acid model tests --------- Co-authored-by: Eric G. Kratz <[email protected]>
1 parent 8aaaab1 commit 3dccd52

File tree

5 files changed

+103
-76
lines changed

5 files changed

+103
-76
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
- Made composite electrode model compatible with particle size distribution ([#4687](https://github.com/pybamm-team/PyBaMM/pull/4687))
1010
- Added `Symbol.post_order()` method to return an iterable that steps through the tree in post-order fashion. ([#4684](https://github.com/pybamm-team/PyBaMM/pull/4684))
1111
- Added two more submodels (options) for the SEI: Lars von Kolzenberg (2020) model and Tunneling Limit model ([#4394](https://github.com/pybamm-team/PyBaMM/pull/4394))
12-
12+
- Automatically add state variables of the model to the output variables if they are not already present ([#4700](https://github.com/pybamm-team/PyBaMM/pull/4700))
1313

1414
## Breaking changes
1515

src/pybamm/discretisations/discretisation.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ def process_model(self, model, inplace=True):
193193

194194
model_disc.bcs = self.bcs
195195

196+
# pre-process variables so that all state variables are included
197+
pre_processed_variables = self._pre_process_variables(
198+
model.variables, model.initial_conditions
199+
)
200+
196201
pybamm.logger.verbose(f"Discretise initial conditions for {model.name}")
197202
ics, concat_ics = self.process_initial_conditions(model)
198203
model_disc.initial_conditions = ics
@@ -202,7 +207,8 @@ def process_model(self, model, inplace=True):
202207
# Note that we **do not** discretise the keys of model.rhs,
203208
# model.initial_conditions and model.boundary_conditions
204209
pybamm.logger.verbose(f"Discretise variables for {model.name}")
205-
model_disc.variables = self.process_dict(model.variables)
210+
211+
model_disc.variables = self.process_dict(pre_processed_variables)
206212

207213
# Process parabolic and elliptic equations
208214
pybamm.logger.verbose(f"Discretise model equations for {model.name}")
@@ -657,6 +663,46 @@ def create_mass_matrix(self, model):
657663

658664
return mass_matrix, mass_matrix_inv
659665

666+
def _pre_process_variables(
667+
self,
668+
variables: dict[str, pybamm.Symbol],
669+
initial_conditions: dict[pybamm.Variable, pybamm.Symbol],
670+
):
671+
"""
672+
Pre-process variables before discretisation. This involves:
673+
- ensuring that all the state variables are included in the variables,
674+
any missing are added
675+
676+
Parameters
677+
----------
678+
variables : dict
679+
Dictionary of variables to pre-process
680+
initial_conditions : dict
681+
Dictionary of initial conditions
682+
683+
Returns
684+
-------
685+
dict
686+
Pre-processed variables (copy of input variables with any missing state)
687+
688+
Raises
689+
------
690+
:class:`pybamm.ModelError`
691+
If any state variable names are already included but with
692+
incorrect expressions
693+
"""
694+
new_variables = {k: v for k, v in variables.items()}
695+
for var in initial_conditions.keys():
696+
if var.name not in new_variables:
697+
new_variables[var.name] = var
698+
else:
699+
if new_variables[var.name] != var:
700+
raise pybamm.ModelError(
701+
f"Variable '{var.name}' should have expression "
702+
f"'{var}', but has expression '{new_variables[var.name]}'"
703+
)
704+
return new_variables
705+
660706
def process_dict(self, var_eqn_dict, ics=False):
661707
"""Discretise a dictionary of {variable: equation}, broadcasting if necessary
662708
(can be model.rhs, model.algebraic, model.initial_conditions or
@@ -1008,7 +1054,6 @@ def _concatenate_in_order(self, var_eqn_dict, check_complete=False, sparse=False
10081054
def check_model(self, model):
10091055
"""Perform some basic checks to make sure the discretised model makes sense."""
10101056
self.check_initial_conditions(model)
1011-
self.check_variables(model)
10121057

10131058
def check_initial_conditions(self, model):
10141059
# Check initial conditions are a numpy array
@@ -1049,40 +1094,6 @@ def check_initial_conditions(self, model):
10491094
f"{model.algebraic[var].shape} and initial_conditions.shape = {model.initial_conditions[var].shape} for variable '{var}'."
10501095
)
10511096

1052-
def check_variables(self, model):
1053-
"""
1054-
Check variables in variable list against rhs.
1055-
Be lenient with size check if the variable in model.variables is broadcasted, or
1056-
a concatenation
1057-
(if broadcasted, variable is a multiplication with a vector of ones)
1058-
"""
1059-
for rhs_var in model.rhs.keys():
1060-
if rhs_var.name in model.variables.keys():
1061-
var = model.variables[rhs_var.name]
1062-
1063-
different_shapes = not np.array_equal(
1064-
model.rhs[rhs_var].shape, var.shape
1065-
)
1066-
1067-
not_concatenation = not isinstance(var, pybamm.Concatenation)
1068-
1069-
not_mult_by_one_vec = not (
1070-
isinstance(
1071-
var, (pybamm.Multiplication, pybamm.MatrixMultiplication)
1072-
)
1073-
and (
1074-
pybamm.is_matrix_one(var.left)
1075-
or pybamm.is_matrix_one(var.right)
1076-
)
1077-
)
1078-
1079-
if different_shapes and not_concatenation and not_mult_by_one_vec:
1080-
raise pybamm.ModelError(
1081-
"variable and its eqn must have the same shape after "
1082-
"discretisation but variable.shape = "
1083-
f"{var.shape} and rhs.shape = {model.rhs[rhs_var].shape} for variable '{var}'. "
1084-
)
1085-
10861097
def is_variable_independent(self, var, all_vars_in_eqns):
10871098
pybamm.logger.verbose("Removing independent blocks.")
10881099
if not isinstance(var, pybamm.Variable):

src/pybamm/expression_tree/concatenations.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55
import copy
66
from collections import defaultdict
7+
from typing import Optional
78

89
import numpy as np
910
import sympy
@@ -146,9 +147,9 @@ def _concatenation_new_copy(self, children, perform_simplifications: bool = True
146147
children before creating the new copy.
147148
"""
148149
if perform_simplifications:
149-
return concatenation(*children)
150+
return concatenation(*children, name=self.name)
150151
else:
151-
return self.__class__(*children)
152+
return self.__class__(*children, name=self.name)
152153

153154
def _concatenation_jac(self, children_jacs):
154155
"""Calculate the Jacobian of a concatenation."""
@@ -468,17 +469,18 @@ def _concatenation_new_copy(self, children, perform_simplifications=True):
468469
class ConcatenationVariable(Concatenation):
469470
"""A Variable representing a concatenation of variables."""
470471

471-
def __init__(self, *children):
472-
# Name is the intersection of the children names (should usually make sense
473-
# if the children have been named consistently)
474-
name = intersect(children[0].name, children[1].name)
475-
for child in children[2:]:
476-
name = intersect(name, child.name)
477-
if len(name) == 0:
478-
name = None
479-
# name is unchanged if its length is 1
480-
elif len(name) > 1:
481-
name = name[0].capitalize() + name[1:]
472+
def __init__(self, *children, name: Optional[str] = None):
473+
if name is None:
474+
# Name is the intersection of the children names (should usually make sense
475+
# if the children have been named consistently)
476+
name = intersect(children[0].name, children[1].name)
477+
for child in children[2:]:
478+
name = intersect(name, child.name)
479+
if len(name) == 0:
480+
name = None
481+
# name is unchanged if its length is 1
482+
elif len(name) > 1:
483+
name = name[0].capitalize() + name[1:]
482484

483485
if len(children) > 0:
484486
if all(child.scale == children[0].scale for child in children):
@@ -523,7 +525,7 @@ def intersect(s1: str, s2: str):
523525
return intersect.lstrip().rstrip()
524526

525527

526-
def simplified_concatenation(*children):
528+
def simplified_concatenation(*children, name: Optional[str] = None):
527529
"""Perform simplifications on a concatenation."""
528530
# remove children that are None
529531
children = list(filter(lambda x: x is not None, children))
@@ -534,29 +536,29 @@ def simplified_concatenation(*children):
534536
elif len(children) == 1:
535537
return children[0]
536538
elif all(isinstance(child, pybamm.Variable) for child in children):
537-
return pybamm.ConcatenationVariable(*children)
539+
return pybamm.ConcatenationVariable(*children, name=name)
538540
else:
539541
# Create Concatenation to easily read domains
540-
concat = Concatenation(*children)
542+
concat = Concatenation(*children, name=name)
541543
if all(
542544
isinstance(child, pybamm.Broadcast) and child.child == children[0].child
543545
for child in children
544546
):
545547
unique_child = children[0].orphans[0]
546548
if isinstance(children[0], pybamm.PrimaryBroadcast):
547-
return pybamm.PrimaryBroadcast(unique_child, concat.domain)
549+
return pybamm.PrimaryBroadcast(unique_child, concat.domain, name=name)
548550
else:
549551
return pybamm.FullBroadcast(
550-
unique_child, broadcast_domains=concat.domains
552+
unique_child, broadcast_domains=concat.domains, name=name
551553
)
552554
else:
553555
return concat
554556

555557

556-
def concatenation(*children):
558+
def concatenation(*children, name: Optional[str] = None):
557559
"""Helper function to create concatenations."""
558560
# TODO: add option to turn off simplifications
559-
return simplified_concatenation(*children)
561+
return simplified_concatenation(*children, name=name)
560562

561563

562564
def simplified_numpy_concatenation(*children):

src/pybamm/models/submodels/oxygen_diffusion/full_oxygen_diffusion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ def get_fundamental_variables(self):
3636
domain="positive electrode",
3737
auxiliary_domains={"secondary": "current collector"},
3838
)
39-
c_ox_s_p = pybamm.concatenation(c_ox_s, c_ox_p)
39+
c_ox_s_p = pybamm.concatenation(
40+
c_ox_s,
41+
c_ox_p,
42+
name="Separator and positive electrode oxygen concentration [mol.m-3]",
43+
)
4044
variables = {
4145
"Separator and positive electrode oxygen concentration [mol.m-3]": c_ox_s_p
4246
}

tests/unit/test_discretisations/test_discretisation.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,9 +1036,9 @@ def test_concatenation_2D(self):
10361036
assert expr.children[2].evaluate(0, y).shape == (105, 1)
10371037

10381038
def test_exceptions(self):
1039-
c_n = pybamm.Variable("c", domain=["negative electrode"])
1039+
c_n = pybamm.Variable("c_n", domain=["negative electrode"])
10401040
N_n = pybamm.grad(c_n)
1041-
c_s = pybamm.Variable("c", domain=["separator"])
1041+
c_s = pybamm.Variable("c_s", domain=["separator"])
10421042
N_s = pybamm.grad(c_s)
10431043
model = pybamm.BaseModel()
10441044
model.rhs = {c_n: pybamm.div(N_n), c_s: pybamm.div(N_s)}
@@ -1049,22 +1049,6 @@ def test_exceptions(self):
10491049
}
10501050

10511051
disc = get_discretisation_for_testing()
1052-
1053-
# check raises error if different sized key and output var
1054-
model.variables = {c_n.name: c_s}
1055-
with pytest.raises(pybamm.ModelError, match="variable and its eqn"):
1056-
disc.process_model(model)
1057-
1058-
# check doesn't raise if concatenation
1059-
model.variables = {c_n.name: pybamm.concatenation(2 * c_n, 3 * c_s)}
1060-
disc.process_model(model, inplace=False)
1061-
1062-
# check doesn't raise if broadcast
1063-
model.variables = {
1064-
c_n.name: pybamm.PrimaryBroadcast(
1065-
pybamm.InputParameter("a"), ["negative electrode"]
1066-
)
1067-
}
10681052
disc.process_model(model)
10691053

10701054
# Check setting up a 0D spatial method with 1D mesh raises error
@@ -1277,3 +1261,29 @@ def test_independent_rhs_with_event(self):
12771261
disc = pybamm.Discretisation(remove_independent_variables_from_rhs=True)
12781262
disc.process_model(model)
12791263
assert len(model.rhs) == 3
1264+
1265+
def test_pre_process_variables(self):
1266+
a = pybamm.Variable("a")
1267+
b = pybamm.Variable("b")
1268+
model = pybamm.BaseModel()
1269+
model.rhs = {a: b, b: a}
1270+
model.initial_conditions = {
1271+
a: pybamm.Scalar(0),
1272+
b: pybamm.Scalar(1),
1273+
}
1274+
model.variables = {
1275+
"a": a, # correct
1276+
# b missing
1277+
}
1278+
disc = pybamm.Discretisation()
1279+
disc_model = disc.process_model(model, inplace=False)
1280+
assert list(disc_model.variables.keys()) == ["a", "b"]
1281+
1282+
model.variables = {
1283+
"a": a,
1284+
"b": 2 * a,
1285+
}
1286+
with pytest.raises(
1287+
pybamm.ModelError, match="Variable 'b' should have expression"
1288+
):
1289+
disc.process_model(model, inplace=False)

0 commit comments

Comments
 (0)