Skip to content

Commit 9435a0a

Browse files
author
jax authors
committed
Merge pull request #18138 from mattjj:shmap-axis-env-fix
PiperOrigin-RevId: 574540561
2 parents d55085f + 3bfe1d2 commit 9435a0a

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

jax/experimental/shard_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,7 @@ def _partial_eval_jaxpr_custom_rule(
14291429
with core.extend_axis_env_nd(mesh.shape.items()):
14301430
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
14311431
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
1432-
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
1432+
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
14331433
ins_known, _ = partition_list(unks_in, eqn.invars)
14341434
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
14351435
_, ins_staged = partition_list(inst_in, eqn.invars)

tests/shard_map_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,18 @@ def body(q, k, v):
11351135

11361136
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
11371137

1138+
def test_axis_env_extension_regression(self):
1139+
def foo(x):
1140+
i = jax.lax.axis_index('x')
1141+
return jnp.exp(x) + i.astype('float')
1142+
1143+
@partial(jax.remat, policy=lambda *args, **kwargs: True)
1144+
def bar(x):
1145+
return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),),
1146+
out_specs=P('x'), check_rep=False)(x)
1147+
1148+
jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash
1149+
11381150

11391151
class FunSpec(NamedTuple):
11401152
name: str

0 commit comments

Comments
 (0)