JAX v0.7.1
New features
- JAX now ships Python 3.14 and 3.14t wheels.
- JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
offered free-threading builds on Linux.
Changes
- Exposed `jax.set_mesh`, which acts as both a global setter and a context manager. Removed `jax.sharding.use_mesh` in favor of `jax.set_mesh`.
- JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain supported.
- `jax.lax.dot` now implements the general dot product via the optional `dimension_numbers` argument.
Deprecations:
- `jax.lax.zeros_like_array` is deprecated. Please use `jax.numpy.zeros_like` instead.
- Attempting to import `jax.experimental.host_callback` now results in a `DeprecationWarning`, and will result in an `ImportError` starting in JAX v0.8.0. Its APIs have raised `NotImplementedError` since JAX version 0.4.35.
- In `jax.lax.dot`, passing the `precision` and `preferred_element_type` arguments by position is deprecated. Pass them by explicit keyword instead.
- Several dozen internal APIs have been deprecated from `jax.interpreters.ad`, `jax.interpreters.batching`, and `jax.interpreters.partial_eval`; they are used rarely if ever outside JAX itself, and most are deprecated without any public replacement.
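A short migration sketch for the `zeros_like_array` and `jax.lax.dot` deprecations listed above. It uses `jax.numpy.zeros_like` in place of `jax.lax.zeros_like_array`, and it passes `precision` and `preferred_element_type` to `jax.lax.dot` by keyword. The shapes, dtypes, and the choice of `Precision.HIGHEST` are illustrative assumptions.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float32)

# Replacement for the deprecated jax.lax.zeros_like_array(x).
zeros = jnp.zeros_like(x)

# bfloat16 inputs, accumulated into a float32 result.
a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Pass precision and preferred_element_type by keyword.
# The positional form lax.dot(a, b, precision, dtype) is deprecated.
c = lax.dot(
    a,
    b,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)
```

The keyword form also works on JAX releases before 0.7.1, so call sites can be migrated before upgrading.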