JAX v0.6.1
New features
- Added `jax.lax.axis_size`, which returns the size of the mapped axis
given its name.
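A minimal sketch of looking up a mapped axis size by name, assuming JAX v0.6.1 or later and a `jax.vmap` with a named axis. The helper `mapped_axis_size`, the function `scale_by_batch_size`, and the axis name `"batch"` are hypothetical and only for illustration. The fallback to `jax.lax.psum(1, axis_name)` is a common idiom on older releases and is included only so the snippet also runs there.

```python
import jax
import jax.numpy as jnp


def mapped_axis_size(axis_name):
    # Hypothetical helper: prefer jax.lax.axis_size (JAX >= v0.6.1);
    # otherwise sum 1 over the named axis, which also yields its size.
    if hasattr(jax.lax, "axis_size"):
        return jax.lax.axis_size(axis_name)
    return jax.lax.psum(1, axis_name)


def scale_by_batch_size(x):
    # Inside vmap, "batch" names the mapped axis (size 4 in the call below).
    return x * mapped_axis_size("batch")


out = jax.vmap(scale_by_batch_size, axis_name="batch")(jnp.arange(4.0))
print(out)
```

Each element is multiplied by the size of the `"batch"` axis, so the result is the input scaled by 4.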
Changes
- Additional checking for the versions of CUDA package dependencies was
re-enabled, having been accidentally disabled in a previous release.
- JAX nightly packages are now published to Artifact Registry. To install
these packages, see the JAX installation guide.
- `jax.sharding.PartitionSpec` no longer inherits from a tuple.
- `jax.ShapeDtypeStruct` is now immutable. Please use the `.update` method to
update your `ShapeDtypeStruct` instead of doing in-place updates.
Deprecations
- `jax.custom_derivatives.custom_jvp_call_jaxpr_p` is deprecated and will be
removed in JAX v0.7.0.