Release date: April 16, 2026
Magnitude: 62,264 (diff delta)
Contributors: 59 total committers
Commits: 587 in this release
Authored December 2, 2025
Release notes for jax-v0.10.0:
New features:
- Added ResizeMethod.CUBIC_PYTORCH to jax.image.resize to match
PyTorch's bicubic resize (#15768).
- We now support differentiation of jax.lax.linalg.qr for wide
matrices and when full_matrices is True.
- LAPACK operations are now parallelized along the batch dimension on CPU.
- Added a perturb_singular argument to jax.lax.linalg.tridiagonal_solve
to handle singular matrices by perturbing near-zero pivots in the LU
decomposition. This is useful for solving numerically singular systems
when computing eigenvectors by inverse iteration.
- jax.scipy.linalg.eigh_tridiagonal now supports computing
eigenvectors on CPU and GPU.
- Added the jax.numpy.ndarray.byteswap method.
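The inverse-iteration use case mentioned for perturb_singular can be sketched as follows. This is a minimal illustration, not JAX's own eigenvector routine: it takes the smallest eigenvalue from jax.scipy.linalg.eigh_tridiagonal with eigvals_only=True, offsets the shift by a small hypothetical SHIFT_OFFSET so that T - mu*I stays safely nonsingular, and then repeatedly solves with jax.lax.linalg.tridiagonal_solve. It deliberately does not pass perturb_singular, so it also runs on releases that predate that argument.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve
from jax.scipy.linalg import eigh_tridiagonal

# Symmetric tridiagonal matrix T: main diagonal d, off-diagonal e.
d = jnp.array([2.0, 2.0, 2.0, 2.0])
e = jnp.array([-1.0, -1.0, -1.0])
n = d.shape[0]

# tridiagonal_solve takes sub-, main and super-diagonals as length-n arrays;
# dl[0] and du[-1] fall outside the matrix and are set to zero.
dl = jnp.concatenate([jnp.zeros(1), e])
du = jnp.concatenate([e, jnp.zeros(1)])

# Hypothetical offset that keeps the shifted system away from exact singularity.
SHIFT_OFFSET = 1e-3

def inverse_iteration(mu, num_iters=5):
    """Approximate the eigenvector of T whose eigenvalue is closest to mu."""
    v = jnp.ones((n, 1)) / jnp.sqrt(n)  # starting vector, shape (n, 1)
    for _ in range(num_iters):
        # Solve (T - mu*I) w = v, then normalize w.
        v = tridiagonal_solve(dl, d - mu, du, v)
        v = v / jnp.linalg.norm(v)
    return v[:, 0]

lam = jnp.min(eigh_tridiagonal(d, e, eigvals_only=True))
vec = inverse_iteration(lam + SHIFT_OFFSET)

# Dense copy of T, used only to check the residual T v - lam v.
T = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
residual = jnp.linalg.norm(T @ vec - lam * vec)
```

Each solve shrinks the components along other eigenvectors by roughly SHIFT_OFFSET divided by their distance from the shift, which is why a handful of iterations suffices here.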
Breaking changes:
- PartitionSpec objects no longer report themselves to be equal to tuples.
Convert tuples to PartitionSpec objects before testing equality.
- The .vma property has been removed from jax.core.ShapedArray. Use
.manual_axis_type.varying instead.
- JAX CPU devices now report their names as cpu:0, cpu:1, etc., instead
of TFRT_CPU_0, TFRT_CPU_1, etc.
- The config option jax_pmap_shmap_merge has been removed. jax.pmap
now always uses the new implementation that wraps
jax.jit(jax.shard_map). Please see
https://docs.jax.dev/en/latest/migrate_pmap.html for more information.
- jax.device_put_sharded and jax.device_put_replicated have been removed
from the public API and now raise an AttributeError when accessed.
Please see
https://docs.jax.dev/en/latest/migrate_pmap.html#drop-in-replacements for
drop-in replacements.
- The C++ pmap infrastructure has been removed. The following public APIs
  are no longer available:
  - jax.sharding.PmapSharding
  - From jaxlib.xla_extension: PmapFunction, pmap, NoSharding, Chunked,
    Unstacked, ShardedAxis, Replicated, ShardingSpec.
  - From jax.interpreters.pxla: MapTracer, PmapExecutable,
    parallel_callable, shard_args, xla_pmap_p, Chunked, NoSharding,
    Replicated, ShardedAxis, ShardingSpec, Unstacked, spec_to_indices.
- The deprecated keyword arguments a, a_min, and a_max to
jax.numpy.clip have been removed.
- The functions jax.numpy.hstack, jax.numpy.vstack, jax.numpy.dstack,
jax.numpy.column_stack, jax.numpy.atleast_1d, jax.numpy.atleast_2d,
and jax.numpy.atleast_3d no longer accept non-ArrayLike inputs.
Passing such inputs previously issued a DeprecationWarning.
- jax.scipy.stats.rankdata now returns floating point values in
all cases, following a similar change in the SciPy 1.18 release.
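Several of the breaking changes above have simple migrations, sketched here in Python. This is a minimal illustration rather than code from the JAX documentation: the variable names are hypothetical, and the jnp.clip call assumes the positional-array plus min/max signature that replaced the removed a, a_min, and a_max keywords.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import rankdata
from jax.sharding import PartitionSpec as P

# PartitionSpec vs. tuple: build a PartitionSpec from the tuple before comparing.
spec_tuple = ("data", None)
spec = P("data", None)
specs_match = spec == P(*spec_tuple)

# CPU devices: identify them by platform and id instead of parsing the
# device name, which changed from "TFRT_CPU_0" to "cpu:0".
cpu_devices = jax.devices("cpu")
cpu_ids = [dev.id for dev in cpu_devices]

# jnp.clip: pass the array positionally and use min/max keywords.
x = jnp.array([-1.0, 0.5, 2.0])
clipped = jnp.clip(x, min=0.0, max=1.0)

# Stacking helpers: convert plain Python sequences with jnp.asarray first.
parts = [[1, 2], [3]]
stacked = jnp.hstack([jnp.asarray(p) for p in parts])
row = jnp.atleast_2d(jnp.asarray([1, 2, 3]))

# rankdata now returns floats; cast explicitly if integer ranks are needed.
ranks = rankdata(jnp.array([10, 20, 20, 30]), method="min")
int_ranks = ranks.astype(jnp.int32)
```

The explicit conversions also behave the same on releases before this one, which makes them a low-risk way to prepare code for the upgrade.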
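Because jax.pmap now wraps jax.jit(jax.shard_map) and jax.device_put_sharded is no longer public, code can move to an explicit mesh. The sketch below is one possible approach under stated assumptions, not the drop-in replacements from the migration guide: it builds a one-axis mesh over all devices with a hypothetical axis name "i", places per-device shards with jax.device_put and a NamedSharding, and runs a psum-based normalization under shard_map. The import fallback assumes older releases expose shard_map as jax.experimental.shard_map.shard_map.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # top-level API in recent releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# One mesh axis, hypothetically named "i", spanning every device.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("i",))
n = devices.size

# In place of jax.device_put_sharded(shards, devices): stack the per-device
# shards and split the leading axis across the mesh.
shards = [jnp.full((4,), float(k + 1)) for k in range(n)]
xs = jax.device_put(jnp.stack(shards), NamedSharding(mesh, P("i")))

def normalize(block):
    # block is this device's slice of xs; psum adds the partial sums
    # from all devices along mesh axis "i".
    total = jax.lax.psum(jnp.sum(block), "i")
    return block / total

f = jax.jit(shard_map(normalize, mesh=mesh, in_specs=P("i"), out_specs=P("i")))
out = f(xs)
```

Passing mesh, in_specs, and out_specs as keywords keeps the call valid for both the top-level and the experimental shard_map.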
Deprecations:
- A number of internal APIs in jax.core have been newly deprecated, and
some have been moved to jax.extend.core. These include CallPrimitive,
DebugInfo, DropVar, Effect, Effects, InconclusiveDimensionOperation,
JaxprTypeError, check_jaxpr, concrete_or_error, find_top_trace,
gensym, get_opaque_trace_state, jaxprs_in_params, new_jaxpr_eqn,
no_effects, nonempty_axis_env_DO_NOT_USE, primal_dtype_to_tangent_dtype,
unsafe_am_i_under_a_jit_DO_NOT_USE, unsafe_am_i_under_a_vmap_DO_NOT_USE,
unsafe_get_axis_names_DO_NOT_USE, valid_jaxtype, JaxprPpContext,
JaxprPpSettings, OutputType, abstract_token, aval_mapping_handlers,
call, concretization_function_error, custom_typechecks, is_concrete,
is_constant_dim, is_constant_shape, literalable_types, no_axis_name,
pytype_aval_mappings, and trace_ctx.
Changes:
- The minimum supported SciPy version is now 1.14.
- The vma parameter of jax.ShapeDtypeStruct has been replaced with
manual_axis_type, of type jax.sharding.ManualAxisType. The .vma property
has been replaced with .manual_axis_type.varying.
- Removed the experimental jax.experimental.custom_dce.custom_dce.
- jax.scipy.linalg.cho_solve, jax.scipy.linalg.lu_solve, and
jax.scipy.linalg.solve_triangular now show a deprecation warning for
batched 1D solves with b.ndim > 1. In the future, these will be treated as
batched 2D solves.
- Added version 10 of the jax.export serialization format. This version
is an optimization for the case where the same abstract value, abstract
mesh, or sharding occurs multiple times.
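To avoid the new deprecation warning for batched 1D solves, and the later change in how b is interpreted, the batching can be made explicit. The sketch below assumes a batch of lower-triangular systems with illustrative shapes and shows two options: mapping a plain 1D solve with jax.vmap, or adding a trailing length-1 axis so that b is already a batch of matrices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

batch, n = 4, 3
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Batch of well-conditioned lower-triangular matrices (diagonal entries in [2, 3)).
a = jnp.tril(jax.random.uniform(key_a, (batch, n, n))) + 2.0 * jnp.eye(n)
b = jax.random.normal(key_b, (batch, n))  # one right-hand-side vector per matrix

# Option 1: vmap a plain 1D solve over the batch axis (b is 1D inside the map).
x_vmap = jax.vmap(lambda m, v: solve_triangular(m, v, lower=True))(a, b)

# Option 2: make each right-hand side an (n, 1) matrix, solve, then drop the axis.
x_col = solve_triangular(a, b[..., None], lower=True)[..., 0]

# Residual check: a @ x should reproduce b for every batch element.
residual = jnp.max(jnp.abs(jnp.einsum("bij,bj->bi", a, x_vmap) - b))
```

Both options express the intent without relying on how a 2D b is interpreted, so they should give the same answer before and after the planned change.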
Bug fixes:
- Fixed a bug that led to differing output between CPU and GPU for
non-symmetric multidimensional IRFFTs (#29325).
- Fixed an error when tiny matrices were passed to
jax.lax.linalg.tridiagonal_solve on GPU (#32487).
- Fixed a bug in jax.scipy.fft.dctn and idctn where axes=None
incorrectly defaulted to all axes when s was specified; they now use the
last len(s) axes, matching SciPy behavior (#29426).
- Fixed a bug where calling jax.distributed.initialize() on a GCE TPU
Managed Instance Group raised an IndexError (#36593). When
jax.distributed.initialize() is called on a GCE VM, it uses the GCE
metadata server to learn the addresses of all participating tasks. On
Managed Instance Groups, this metadata uses a format JAX did not expect,
which led to the exception. We now parse this format correctly.
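Code that must behave identically on releases before and after the dctn/idctn fix can pass axes explicitly whenever s is given. A minimal sketch with an illustrative 2D input: transforming only the last axis with dctn matches a 1D dct along that axis, and idctn with the same arguments inverts it. The sketch uses norm="ortho" so the round trip is an exact inverse.

```python
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn, idctn

x = jnp.arange(12.0).reshape(3, 4)

# With s=(4,), the fixed default is the last len(s) == 1 axis; spelling out
# axes=(-1,) gives that result regardless of the JAX version.
y = dctn(x, s=(4,), axes=(-1,), norm="ortho")
y_ref = dct(x, n=4, axis=-1, norm="ortho")

# idctn with matching s, axes, and norm recovers the input.
x_back = idctn(y, s=(4,), axes=(-1,), norm="ortho")
```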