Skip to content

Commit 44ef21e

Browse files
pfackeldeymatthewfeickert
authored andcommitted
fix jax backend tolist for tracers in logging
1 parent 10488f0 commit 44ef21e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/pyhf/tensor/jax_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
config.update('jax_enable_x64', True)
44

5+
import jax
56
from jax import Array
67
import jax.numpy as jnp
78
from jax.scipy.special import gammaln, xlogy
@@ -14,6 +15,10 @@
1415
log = logging.getLogger(__name__)
1516

1617

18+
def currently_jitting():
19+
return isinstance(jnp.array(1) + 1, jax.core.Tracer)
20+
21+
1722
class _BasicPoisson:
1823
def __init__(self, rate):
1924
self.rate = rate
@@ -184,6 +189,9 @@ def conditional(self, predicate, true_callable, false_callable):
184189
return true_callable() if predicate else false_callable()
185190

186191
def tolist(self, tensor_in):
192+
if currently_jitting():
193+
# .aval is the abstract value and has a little nicer representation
194+
return tensor_in.aval
187195
try:
188196
return jnp.asarray(tensor_in).tolist()
189197
except (TypeError, ValueError):

0 commit comments

Comments
 (0)