@@ -465,7 +465,7 @@ def __init__(self, symbol):
465465 # add function def to first line
466466 python_str = (
467467 "def evaluate(constants, t=None, y=None, "
468- "y_dot=None, inputs=None, known_evals =None):\n " + python_str
468+ "inputs=None):\n " + python_str
469469 )
470470
471471 # calculate the final variable that will output the result of calling `evaluate`
@@ -491,21 +491,17 @@ def __init__(self, symbol):
491491 compiled_function = compile (python_str , result_var , "exec" )
492492 exec (compiled_function )
493493
494- def evaluate (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
494+ def __call__ (self , t = None , y = None , inputs = None ):
495495 """
496- Acts as a drop-in replacement for :func:`pybamm.Symbol. evaluate`
496+ evaluate function
497497 """
498498 # generated code assumes y is a column vector
499499 if y is not None and y .ndim == 1 :
500500 y = y .reshape (- 1 , 1 )
501501
502- result = self ._evaluate (self ._constants , t , y , y_dot , inputs , known_evals )
502+ result = self ._evaluate (self ._constants , t , y , inputs )
503503
504- # don't need known_evals, but need to reproduce Symbol.evaluate signature
505- if known_evals is not None :
506- return result , known_evals
507- else :
508- return result
504+ return result
509505
510506 def __getstate__ (self ):
511507 # Control the state of instances of EvaluatorPython
@@ -581,7 +577,7 @@ def __init__(self, symbol):
581577 python_str = python_str .replace ("\n " , "\n " )
582578
583579 # add function def to first line
584- args = "t=None, y=None, y_dot=None, inputs=None, known_evals =None"
580+ args = "t=None, y=None, inputs=None"
585581 if self ._arg_list :
586582 args = "," .join (self ._arg_list ) + ", " + args
587583 python_str = "def evaluate_jax({}):\n " .format (args ) + python_str
@@ -628,23 +624,23 @@ def get_jacobian(self):
628624 def get_sensitivities (self ):
629625 n = len (self ._arg_list )
630626
631- # forward mode autodiff wrt inputs, which is argument 3 after arg_list
632- jacobian_evaluate = jax .jacfwd (self ._evaluate_jax , argnums = 3 + n )
627+ # forward mode autodiff wrt inputs, which is argument 2 after arg_list
628+ jacobian_evaluate = jax .jacfwd (self ._evaluate_jax , argnums = 2 + n )
633629
634630 self ._sens_evaluate = jax .jit (
635631 jacobian_evaluate , static_argnums = self ._static_argnums
636632 )
637633
638634 return EvaluatorJaxSensitivities (self ._sens_evaluate , self ._constants )
639635
640- def debug (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
636+ def debug (self , t = None , y = None , inputs = None ):
641637 # generated code assumes y is a column vector
642638 if y is not None and y .ndim == 1 :
643639 y = y .reshape (- 1 , 1 )
644640
645641 # execute code
646642 jaxpr = jax .make_jaxpr (self ._evaluate_jax )(
647- * self ._constants , t , y , y_dot , inputs , known_evals
643+ * self ._constants , t , y , inputs
648644 ).jaxpr
649645 print ("invars:" , jaxpr .invars )
650646 print ("outvars:" , jaxpr .outvars )
@@ -654,52 +650,45 @@ def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
654650 print ()
655651 print ("jaxpr:" , jaxpr )
656652
657- def evaluate (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
653+ def __call__ (self , t = None , y = None , inputs = None ):
658654 """
659- Acts as a drop-in replacement for :func:`pybamm.Symbol. evaluate`
655+ evaluate function
660656 """
661657 # generated code assumes y is a column vector
662658 if y is not None and y .ndim == 1 :
663659 y = y .reshape (- 1 , 1 )
664660
665- result = self ._jit_evaluate (* self ._constants , t , y , y_dot , inputs , known_evals )
661+ result = self ._jit_evaluate (* self ._constants , t , y , inputs )
666662
667- # don't need known_evals, but need to reproduce Symbol.evaluate signature
668- if known_evals is not None :
669- return result , known_evals
670- else :
671- return result
663+ return result
672664
673665
674666class EvaluatorJaxJacobian :
675667 def __init__ (self , jac_evaluate , constants ):
676668 self ._jac_evaluate = jac_evaluate
677669 self ._constants = constants
678670
679- def evaluate (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
671+ def __call__ (self , t = None , y = None , inputs = None ):
680672 """
681- Acts as a drop-in replacement for :func:`pybamm.Symbol. evaluate`
673+ evaluate function
682674 """
683675 # generated code assumes y is a column vector
684676 if y is not None and y .ndim == 1 :
685677 y = y .reshape (- 1 , 1 )
686678
687679 # execute code
688- result = self ._jac_evaluate (* self ._constants , t , y , y_dot , inputs , known_evals )
680+ result = self ._jac_evaluate (* self ._constants , t , y , inputs )
689681 result = result .reshape (result .shape [0 ], - 1 )
690682
691- if known_evals is not None :
692- return result , known_evals
693- else :
694- return result
683+ return result
695684
696685
697686class EvaluatorJaxSensitivities :
698687 def __init__ (self , jac_evaluate , constants ):
699688 self ._jac_evaluate = jac_evaluate
700689 self ._constants = constants
701690
702- def evaluate (self , t = None , y = None , y_dot = None , inputs = None , known_evals = None ):
691+ def __call__ (self , t = None , y = None , inputs = None ):
703692 """
704693 Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
705694 """
@@ -708,9 +697,6 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
708697 y = y .reshape (- 1 , 1 )
709698
710699 # execute code
711- result = self ._jac_evaluate (* self ._constants , t , y , y_dot , inputs , known_evals )
700+ result = self ._jac_evaluate (* self ._constants , t , y , inputs )
712701
713- if known_evals is not None :
714- return result , known_evals
715- else :
716- return result
702+ return result
0 commit comments