JAX: jax-v0.7.2 Release

Release date:
September 15, 2025
Previous version:
jax-v0.7.1 (released August 19, 2025)
Magnitude:
14,732 Diff Delta
Contributors:
33 total committers
Commits:

163 Commits in this Release

Top Contributors in jax-v0.7.2

allanrenucci
justinjfu
bchetioui
jakevdp
dimitar-asenov
ZacCranko
cperivol
IvyZX
WindQAQ
18praveenb

Release Notes Published

  • Breaking changes:

    • jax.dlpack.from_dlpack no longer accepts a DLPack capsule. Passing a capsule was deprecated, and that support has now been removed. The function must be called with an array-like object that implements both __dlpack__ and __dlpack_device__.
  • Changes:

    • The minimum supported NumPy version is now 2.0. Because NumPy 2.0 support requires SciPy 1.13 or newer, the minimum supported SciPy version is now also 1.13.
    • JAX now represents constants in its internal jaxpr representation as LiteralArray, a private JAX type that duck-types as numpy.ndarray. Because this type can reach user code (for example, through custom_jvp rules), it may break code that uses isinstance(x, np.ndarray). If this breaks your code, convert these values to ordinary NumPy arrays with np.asarray(x).
  • Bug fixes:

    • arr.view(dtype=None) now returns the array unchanged, matching NumPy's semantics. Previously it returned the array with a float dtype.
    • jax.random.randint now produces a less-biased distribution for 8-bit and 16-bit integer types (JAX issue #27742). To restore the previous, biased behavior, you may temporarily set the jax_safer_randint configuration option to False; note that this option is temporary and will be removed in a future release.
  • Deprecations:

    • The parameters enable_xla and native_serialization for jax2tf.convert are deprecated and will be removed in a future version of JAX. They were used for jax2tf with non-native serialization, which has now been removed.
    • Setting the config option jax_pmap_no_rank_reduction to False is deprecated. By default, jax_pmap_no_rank_reduction is set to True, so jax.pmap shards are not rank-reduced: each shard keeps the same rank as its enclosing array.
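
For the from_dlpack breaking change listed above, here is a minimal sketch of how calling code might adapt. The helper import_via_dlpack is hypothetical, not part of JAX: it checks for __dlpack__ and __dlpack_device__ before calling jax.dlpack.from_dlpack, so a bare capsule fails early with an explanatory TypeError. A JAX array stands in for the producer because it implements the DLPack protocol; other producers implementing both methods should behave the same way.

```python
import jax.dlpack
import jax.numpy as jnp


def import_via_dlpack(obj):
    """Hypothetical helper: import an external array into JAX via DLPack.

    From jax-v0.7.2 on, jax.dlpack.from_dlpack needs an object exposing both
    __dlpack__ and __dlpack_device__; raw DLPack capsules are rejected.
    This helper fails early with a clearer message instead.
    """
    if not (hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__")):
        raise TypeError(
            "expected an array implementing __dlpack__ and __dlpack_device__; "
            "pass the producer array itself, not a DLPack capsule"
        )
    return jax.dlpack.from_dlpack(obj)


# A JAX array implements the DLPack protocol, so it can act as the producer here.
src = jnp.arange(4.0)
out = import_via_dlpack(src)
```

Passing the producer object, rather than the result of its __dlpack__() call, is what lets from_dlpack query the device through __dlpack_device__ first.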
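
The new NumPy and SciPy floors from the Changes section can be checked before upgrading. This sketch uses a hypothetical helper, meets_jax_072_minimums, which compares only the major and minor components of each version string; it is not a JAX API.

```python
import re

# Minimum versions stated in the jax-v0.7.2 release notes.
MIN_NUMPY = (2, 0)
MIN_SCIPY = (1, 13)


def major_minor(version: str) -> tuple:
    """Return (major, minor) from a version string such as "1.13.1" or "2.1.0rc1"."""
    m = re.match(r"(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(m.group(1)), int(m.group(2))


def meets_jax_072_minimums(numpy_version: str, scipy_version: str) -> bool:
    """Hypothetical pre-upgrade check: True if both versions meet the new floors."""
    return (major_minor(numpy_version) >= MIN_NUMPY
            and major_minor(scipy_version) >= MIN_SCIPY)
```

In practice the arguments would come from numpy.__version__ and scipy.__version__. A dedicated parser such as packaging.version handles more version formats than this regular expression.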
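
To illustrate the LiteralArray change without relying on JAX internals, the sketch below uses a hypothetical DuckArray class as a stand-in for any array-like value that is not an np.ndarray subclass. It shows why an isinstance(x, np.ndarray) check fails for such values and how np.asarray(x), the fix suggested in the notes, yields a plain NumPy array.

```python
import numpy as np


class DuckArray:
    """Hypothetical stand-in for a duck-typed array (like JAX's private LiteralArray):
    array-like, but not a subclass of np.ndarray."""

    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def shape(self):
        return self._data.shape

    @property
    def dtype(self):
        return self._data.dtype

    def __array__(self, dtype=None, copy=None):
        # Hook used by np.asarray; honours the dtype and copy arguments NumPy may pass.
        arr = self._data if dtype is None else self._data.astype(dtype)
        return arr.copy() if copy else arr


x = DuckArray([1.0, 2.0, 3.0])
is_ndarray_before = isinstance(x, np.ndarray)  # False: the kind of check that can break
y = np.asarray(x)                              # convert, as the notes suggest
is_ndarray_after = isinstance(y, np.ndarray)   # True
```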
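
For the arr.view(dtype=None) fix, code that must also run on releases before jax-v0.7.2 can special-case None itself. view_compat is a hypothetical helper sketched under that assumption; on jax-v0.7.2 and later, calling x.view(dtype=None) directly should give the same result.

```python
import jax.numpy as jnp


def view_compat(arr, dtype=None):
    """Hypothetical helper giving NumPy's view(dtype=None) semantics on any JAX version.

    From jax-v0.7.2, arr.view(dtype=None) already returns arr unchanged; earlier
    releases returned a float-typed result, so the None case is handled explicitly.
    """
    if dtype is None:
        return arr
    return arr.view(dtype)


x = jnp.arange(3, dtype=jnp.int32)
same = view_compat(x)               # dtype stays int32
bits = view_compat(x, jnp.float32)  # reinterpret the int32 bits as float32
```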
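
A sketch of drawing 8-bit integers with jax.random.randint, one of the cases affected by the bias fix. The bin counts are only a rough sanity check of uniformity, not a rigorous bias test. The jax_safer_randint line is left commented out because the notes describe that option as temporary, so it may not be present in every JAX version.

```python
import jax
import jax.numpy as jnp

# Opt-out switch named in the release notes; temporary and slated for removal,
# so it is left commented out here:
# jax.config.update("jax_safer_randint", False)

key = jax.random.key(0)
# 8-bit output: one of the dtypes whose distribution changed in jax-v0.7.2.
samples = jax.random.randint(key, (10_000,), minval=0, maxval=100, dtype=jnp.int8)

# Per-value counts; for a uniform sampler each of the 100 bins should hold
# roughly 10_000 / 100 = 100 samples.
counts = jnp.bincount(samples.astype(jnp.int32), length=100)
```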