Skip to content

Commit 142aae0

Browse files
Merge pull request #1959 from pybamm-team/i1863-casadi-idklu
I1898 make evaluator functions consistent across casadi, python and jax
2 parents 795a0c0 + 278b9f7 commit 142aae0

File tree

14 files changed

+504
-530
lines changed

14 files changed

+504
-530
lines changed

pybamm/expression_tree/operations/evaluate_python.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def __init__(self, symbol):
465465
# add function def to first line
466466
python_str = (
467467
"def evaluate(constants, t=None, y=None, "
468-
"y_dot=None, inputs=None, known_evals=None):\n" + python_str
468+
"inputs=None):\n" + python_str
469469
)
470470

471471
# calculate the final variable that will output the result of calling `evaluate`
@@ -491,21 +491,17 @@ def __init__(self, symbol):
491491
compiled_function = compile(python_str, result_var, "exec")
492492
exec(compiled_function)
493493

494-
def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
494+
def __call__(self, t=None, y=None, inputs=None):
495495
"""
496-
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
496+
evaluate function
497497
"""
498498
# generated code assumes y is a column vector
499499
if y is not None and y.ndim == 1:
500500
y = y.reshape(-1, 1)
501501

502-
result = self._evaluate(self._constants, t, y, y_dot, inputs, known_evals)
502+
result = self._evaluate(self._constants, t, y, inputs)
503503

504-
# don't need known_evals, but need to reproduce Symbol.evaluate signature
505-
if known_evals is not None:
506-
return result, known_evals
507-
else:
508-
return result
504+
return result
509505

510506
def __getstate__(self):
511507
# Control the state of instances of EvaluatorPython
@@ -581,7 +577,7 @@ def __init__(self, symbol):
581577
python_str = python_str.replace("\n", "\n ")
582578

583579
# add function def to first line
584-
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
580+
args = "t=None, y=None, inputs=None"
585581
if self._arg_list:
586582
args = ",".join(self._arg_list) + ", " + args
587583
python_str = "def evaluate_jax({}):\n".format(args) + python_str
@@ -628,23 +624,23 @@ def get_jacobian(self):
628624
def get_sensitivities(self):
629625
n = len(self._arg_list)
630626

631-
# forward mode autodiff wrt inputs, which is argument 3 after arg_list
632-
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)
627+
# forward mode autodiff wrt inputs, which is argument 2 after arg_list
628+
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=2 + n)
633629

634630
self._sens_evaluate = jax.jit(
635631
jacobian_evaluate, static_argnums=self._static_argnums
636632
)
637633

638634
return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)
639635

640-
def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
636+
def debug(self, t=None, y=None, inputs=None):
641637
# generated code assumes y is a column vector
642638
if y is not None and y.ndim == 1:
643639
y = y.reshape(-1, 1)
644640

645641
# execute code
646642
jaxpr = jax.make_jaxpr(self._evaluate_jax)(
647-
*self._constants, t, y, y_dot, inputs, known_evals
643+
*self._constants, t, y, inputs
648644
).jaxpr
649645
print("invars:", jaxpr.invars)
650646
print("outvars:", jaxpr.outvars)
@@ -654,52 +650,45 @@ def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
654650
print()
655651
print("jaxpr:", jaxpr)
656652

657-
def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
653+
def __call__(self, t=None, y=None, inputs=None):
658654
"""
659-
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
655+
evaluate function
660656
"""
661657
# generated code assumes y is a column vector
662658
if y is not None and y.ndim == 1:
663659
y = y.reshape(-1, 1)
664660

665-
result = self._jit_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
661+
result = self._jit_evaluate(*self._constants, t, y, inputs)
666662

667-
# don't need known_evals, but need to reproduce Symbol.evaluate signature
668-
if known_evals is not None:
669-
return result, known_evals
670-
else:
671-
return result
663+
return result
672664

673665

674666
class EvaluatorJaxJacobian:
675667
def __init__(self, jac_evaluate, constants):
676668
self._jac_evaluate = jac_evaluate
677669
self._constants = constants
678670

679-
def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
671+
def __call__(self, t=None, y=None, inputs=None):
680672
"""
681-
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
673+
evaluate function
682674
"""
683675
# generated code assumes y is a column vector
684676
if y is not None and y.ndim == 1:
685677
y = y.reshape(-1, 1)
686678

687679
# execute code
688-
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
680+
result = self._jac_evaluate(*self._constants, t, y, inputs)
689681
result = result.reshape(result.shape[0], -1)
690682

691-
if known_evals is not None:
692-
return result, known_evals
693-
else:
694-
return result
683+
return result
695684

696685

697686
class EvaluatorJaxSensitivities:
698687
def __init__(self, jac_evaluate, constants):
699688
self._jac_evaluate = jac_evaluate
700689
self._constants = constants
701690

702-
def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
691+
def __call__(self, t=None, y=None, inputs=None):
703692
"""
704693
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
705694
"""
@@ -708,9 +697,6 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
708697
y = y.reshape(-1, 1)
709698

710699
# execute code
711-
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
700+
result = self._jac_evaluate(*self._constants, t, y, inputs)
712701

713-
if known_evals is not None:
714-
return result, known_evals
715-
else:
716-
return result
702+
return result

pybamm/solvers/algebraic_solver.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
6767

6868
y0 = model.y0
6969
if isinstance(y0, casadi.DM):
70-
y0 = y0.full().flatten()
70+
y0 = y0.full()
71+
y0 = y0.flatten()
7172

7273
# The casadi algebraic solver can read rhs equations, but leaves them unchanged
7374
# i.e. the part of the solution vector that corresponds to the differential
@@ -80,7 +81,16 @@ def _integrate(self, model, t_eval, inputs_dict=None):
8081
len_rhs = model.rhs_eval(t_eval[0], y0, inputs).shape[0]
8182
y0_diff, y0_alg = np.split(y0, [len_rhs])
8283

83-
algebraic = model.algebraic_eval
84+
test_result = model.algebraic_eval(0, y0, inputs)
85+
86+
if isinstance(test_result, casadi.DM):
87+
def algebraic(t, y):
88+
result = model.algebraic_eval(t, y, inputs)
89+
return result.full().flatten()
90+
else:
91+
def algebraic(t, y):
92+
result = model.algebraic_eval(t, y, inputs)
93+
return result.flatten()
8494

8595
y_alg = np.empty((len(y0_alg), len(t_eval)))
8696

@@ -91,7 +101,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
91101
def root_fun(y_alg):
92102
"Evaluates algebraic using y"
93103
y = np.concatenate([y0_diff, y_alg])
94-
out = algebraic(t, y, inputs)
104+
out = algebraic(t, y)
95105
pybamm.logger.debug(
96106
"Evaluating algebraic equations at t={}, L2-norm is {}".format(
97107
t * model.timescale_eval, np.linalg.norm(out)

0 commit comments

Comments
 (0)