Skip to content

Commit b8e097e

Browse files
committed
#1887 fix sens subset test
1 parent e39ec59 commit b8e097e

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

pybamm/solvers/base_solver.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,32 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
248248
y_and_S = y_casadi
249249

250250
# if we will change the equations to include the explicit sensitivity
251-
# equations, then we also need to update the mass matrix and bounds
251+
# equations, then we also need to update the mass matrix and bounds.
252+
# First, we reset the mass matrix and bounds back to their original form
253+
# if they have been extended
254+
if model.bounds[0].shape[0] > model.len_rhs_and_alg:
255+
model.bounds = (
256+
model.bounds[0][: model.len_rhs_and_alg],
257+
model.bounds[1][: model.len_rhs_and_alg],
258+
)
259+
if (
260+
model.mass_matrix is not None
261+
and model.mass_matrix.shape[0] > model.len_rhs_and_alg
262+
):
263+
if model.mass_matrix_inv is not None:
264+
model.mass_matrix_inv = pybamm.Matrix(
265+
model.mass_matrix_inv.entries[
266+
: model.len_rhs, : model.len_rhs
267+
]
268+
)
269+
model.mass_matrix = pybamm.Matrix(
270+
model.mass_matrix.entries[
271+
: model.len_rhs_and_alg, : model.len_rhs_and_alg
272+
]
273+
)
274+
275+
# now we can extend them by the number of sensitivity parameters
276+
# if needed
252277
if calculate_sensitivities_explicit:
253278
if model.len_rhs != 0:
254279
n_inputs = model.len_rhs_sens // model.len_rhs
@@ -276,28 +301,6 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
276301
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
277302
)
278303
)
279-
else:
280-
# take care if calculate_sensitivites used then not used
281-
if model.bounds[0].shape[0] > model.len_rhs_and_alg:
282-
model.bounds = (
283-
model.bounds[0][: model.len_rhs_and_alg],
284-
model.bounds[1][: model.len_rhs_and_alg],
285-
)
286-
if (
287-
model.mass_matrix is not None
288-
and model.mass_matrix.shape[0] > model.len_rhs_and_alg
289-
):
290-
if model.mass_matrix_inv is not None:
291-
model.mass_matrix_inv = pybamm.Matrix(
292-
model.mass_matrix_inv.entries[
293-
: model.len_rhs, : model.len_rhs
294-
]
295-
)
296-
model.mass_matrix = pybamm.Matrix(
297-
model.mass_matrix.entries[
298-
: model.len_rhs_and_alg, : model.len_rhs_and_alg
299-
]
300-
)
301304

302305
def process(symbol, name, use_jacobian=None):
303306
def report(string):

tests/unit/test_solvers/test_casadi_solver.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,6 @@ def test_solve_sensitivity_subset(self):
911911
calculate_sensitivities=["r"],
912912
)
913913
np.testing.assert_allclose(solution.y[0], -1 + 0.2 * solution.t)
914-
np.testing.assert_allclose(
915-
solution.sensitivities["r"],
916-
(2 * solution.t)[:, np.newaxis],
917-
)
918914
self.assertTrue("p" not in solution.sensitivities)
919915
self.assertTrue("q" not in solution.sensitivities)
920916
np.testing.assert_allclose(solution.sensitivities["r"], 1)

0 commit comments

Comments
 (0)