JAX compatibility

2025-06-12


Applies to Linux

JAX provides a NumPy-like API, which combines automatic differentiation and the Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine learning at scale.

JAX provides composable transformations of Python and NumPy programs, including automatic differentiation, just-in-time (JIT) compilation, automatic vectorization, and parallelization. To learn about JAX, including profiling and optimizations, see the official JAX documentation.
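The following is a minimal sketch of those composable transformations, using only the public jax.grad, jax.jit, and jax.vmap APIs. The function loss and the input values are illustrative assumptions, not taken from the JAX documentation. Because the code only uses backend-independent APIs, it should run unchanged on CPU or on an AMD GPU with a ROCm-enabled jaxlib.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Scalar loss written with the NumPy-like jax.numpy API (illustrative only).
    return jnp.sum((x * w) ** 2)


# Automatic differentiation: gradient of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Just-in-time compilation of the gradient function through XLA.
fast_grad = jax.jit(grad_loss)

# Automatic vectorization: apply fast_grad to each row of a batch of inputs,
# sharing the same w (in_axes=None) and mapping over axis 0 of x.
batched_grad = jax.vmap(fast_grad, in_axes=(None, 0))

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.array([[1.0, 1.0, 1.0],
                [2.0, 0.0, 1.0]])

single = fast_grad(w, xs[0])   # analytically 2 * x**2 * w
batch = batched_grad(w, xs)    # one gradient row per input row

print("backend:", jax.default_backend())
print(single)
print(batch)
```

The transformations compose in any order; here vmap wraps an already-jitted gradient, so the batched call is still compiled by XLA.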

ROCm support for JAX is upstreamed, and users can build the official source code with ROCm support:

Note

AMD releases official ROCm JAX Docker images quarterly alongside new ROCm releases. These images undergo full AMD testing. Community ROCm JAX Docker images follow upstream JAX releases and use the latest available ROCm version.

Use cases and recommendations

  • The nanoGPT in JAX blog explores the implementation and training of a Generative Pre-trained Transformer (GPT) model in JAX, inspired by Andrej Karpathy’s PyTorch-based nanoGPT. Comparing how essential GPT components, such as self-attention mechanisms and optimizers, are realized in PyTorch and JAX also highlights JAX’s unique features.

  • The Optimize GPT Training: Enabling Mixed Precision Training in JAX using ROCm on AMD GPUs blog post provides a comprehensive guide on enhancing the training efficiency of GPT models by implementing mixed precision techniques in JAX, specifically tailored for AMD GPUs utilizing the ROCm platform.

  • The Supercharging JAX with Triton Kernels on AMD GPUs blog demonstrates how to develop a custom fused dropout-activation kernel for matrices using Triton, integrate it with JAX, and benchmark its performance using ROCm.

  • The Distributed fine-tuning with JAX on AMD GPUs blog post outlines the process of fine-tuning a Bidirectional Encoder Representations from Transformers (BERT)-based large language model (LLM) using JAX for a text classification task. It discusses techniques for parallelizing the fine-tuning across multiple AMD GPUs and assesses the model’s performance on a holdout dataset. The fine-tuning uses a BERT-base-cased transformer model and the General Language Understanding Evaluation (GLUE) benchmark dataset on a multi-GPU setup.

  • The MI300X workload optimization guide provides detailed guidance on optimizing workloads for the AMD Instinct MI300X accelerator using ROCm. The page is aimed at helping users achieve optimal performance for deep learning and other high-performance computing tasks on the MI300X GPU.

For more use cases and recommendations, see ROCm JAX blog posts.
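The mixed precision blog post listed above covers that technique in depth. As a rough, self-contained illustration only (not the blog's code), the sketch below keeps float32 master weights, casts weights and inputs to bfloat16 for the matrix multiplication, and computes the loss and the SGD update in float32. The linear model, the names mixed_precision_loss and sgd_step, the learning rate, and the synthetic data are all hypothetical.

```python
import jax
import jax.numpy as jnp


def mixed_precision_loss(params, x, y):
    # Cast float32 master weights and inputs to bfloat16 for the matmul.
    w = params["w"].astype(jnp.bfloat16)
    b = params["b"].astype(jnp.bfloat16)
    preds = x.astype(jnp.bfloat16) @ w + b
    # Upcast to float32 before the reduction to limit rounding error in the loss.
    return jnp.mean((preds.astype(jnp.float32) - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr=0.1):
    loss, grads = jax.value_and_grad(mixed_precision_loss)(params, x, y)
    # Gradients flow back through astype, so they arrive in float32 like params.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Hypothetical synthetic regression problem: y = x @ true_w.
kx, _ = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (32, 4), dtype=jnp.float32)
true_w = jnp.array([[1.0], [-2.0], [0.5], [3.0]])
y = x @ true_w

params = {"w": jnp.zeros((4, 1), jnp.float32), "b": jnp.zeros((1,), jnp.float32)}
initial_loss = float(mixed_precision_loss(params, x, y))
for _ in range(100):
    params, loss = sgd_step(params, x, y)  # loss is measured before each update

print("weights dtype:", params["w"].dtype)
print("loss:", initial_loss, "->", float(loss))
```

Keeping the master weights in float32 means small updates are not lost to bfloat16 rounding, while the matmul itself runs in the cheaper 16-bit format.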

Docker image compatibility

AMD validates and publishes ready-made ROCm JAX Docker images on Docker Hub. The following Docker image tags and associated inventories represent the latest JAX version from the official Docker Hub and are validated for ROCm 6.4.1.

JAX Docker image components

Docker image  JAX     Linux         Python
------------  ------  ------------  -------
rocm/jax      0.4.35  Ubuntu 24.04  3.12.10
rocm/jax      0.4.35  Ubuntu 22.04  3.10.17

AMD also publishes community ROCm JAX Docker images on Docker Hub. The following Docker image tags and associated inventories are tested for ROCm 6.3.2.

JAX community Docker image components

Docker image        JAX    Linux         Python
------------------  -----  ------------  -------
rocm/jax-community  0.5.0  Ubuntu 22.04  3.12.8
rocm/jax-community  0.5.0  Ubuntu 22.04  3.11.11
rocm/jax-community  0.5.0  Ubuntu 22.04  3.10.16
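After starting one of the containers listed above, a quick sanity check is to ask JAX which backend and devices it found. The sketch below uses only jax.__version__, jax.default_backend(), and jax.devices(), and does not assume any particular image tag; the helper name describe_jax_environment is hypothetical. On a CPU-only installation the backend is reported as cpu, so seeing cpu inside a ROCm container can indicate that the GPUs were not passed through to the container or that the installed jaxlib lacks ROCm support.

```python
import jax


def describe_jax_environment():
    # Collect the JAX version, the default backend name, and the visible devices.
    devices = jax.devices()
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),
        "devices": [(d.id, d.platform, d.device_kind) for d in devices],
    }


info = describe_jax_environment()
print("JAX version:", info["jax_version"])
print("Default backend:", info["backend"])
for dev_id, platform, kind in info["devices"]:
    print(f"  device {dev_id}: platform={platform}, kind={kind}")
```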

Key ROCm libraries for JAX

The following ROCm libraries are potential targets that JAX on ROCm can use for various computational tasks. Which libraries are actually used depends on the specific implementation and the operations performed.

ROCm library  Version  Purpose
------------  -------  -------
hipBLAS       2.4.0    Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for matrix and vector operations.
hipBLASLt     0.12.1   Extends hipBLAS with additional features, such as epilogues fused into the matrix multiplication kernel or use of integer tensor cores.
hipCUB        3.4.0    Provides a C++ template library of parallel algorithms for reduction, scan, sort, and select.
hipFFT        1.0.18   Provides GPU-accelerated Fast Fourier Transform (FFT) operations.
hipRAND       2.12.0   Provides fast random number generation for GPUs.
hipSOLVER     2.4.0    Provides GPU-accelerated solvers for linear systems, eigenvalues, and singular value decompositions (SVD).
hipSPARSE     3.2.0    Accelerates operations on sparse matrices, such as sparse matrix-vector or matrix-matrix products.
hipSPARSELt   0.2.3    Accelerates matrix multiplication with structured sparsity, where one input matrix is sparse.
MIOpen        3.4.0    Provides optimized deep learning primitives such as convolutions, pooling, normalization, and activation functions.
RCCL          2.22.3   Provides optimized multi-GPU communication for operations like all-reduce, broadcast, and scatter.
rocThrust     3.3.0    Provides a C++ template library of parallel algorithms like sorting, reduction, and scanning.

Note

The table above lists ROCm libraries that JAX could potentially use. Not every library is used in every configuration; actual library usage depends on the specific operations and implementation details.
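To make the table more concrete, the sketch below runs a few JAX operations of the kinds these libraries serve: a dense matrix multiplication, an FFT, a symmetric eigendecomposition and SVD, a 2D convolution, and a cross-device sum. The mapping in the comments (for example, matmul to hipBLAS or hipBLASLt) is an assumption about typical XLA lowering on ROCm, not a guarantee; as noted above, which library actually services an operation depends on the JAX and XLA versions and on the operation itself. The pmap-based all-reduce only involves RCCL when more than one GPU is visible.

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.normal(k1, (64, 64), dtype=jnp.float32)
b = jax.random.normal(k2, (64, 64), dtype=jnp.float32)

# Dense matrix multiplication (GEMM): typically a hipBLAS/hipBLASLt workload.
c = a @ b

# One-dimensional FFT of a real signal: typically a hipFFT workload.
spectrum = jnp.fft.fft(a[0])

# Symmetric eigendecomposition and singular values: typically hipSOLVER workloads.
eigvals, _ = jnp.linalg.eigh(a @ a.T)
singular_values = jnp.linalg.svd(a, compute_uv=False)

# 2D convolution, NCHW input and OIHW kernel: typically an MIOpen workload.
img = jax.random.normal(k3, (1, 1, 8, 8), dtype=jnp.float32)
kernel = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)
conv = jax.lax.conv(img, kernel, window_strides=(1, 1), padding="SAME")

# All-reduce across local devices: with several GPUs this is typically an RCCL collective.
n_dev = jax.local_device_count()
per_device = jnp.arange(n_dev, dtype=jnp.float32)  # one scalar per device
total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(per_device)

print(c.shape, spectrum.dtype, eigvals.shape, singular_values.shape, conv.shape, total)
```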

Supported data types and modules

The following sections list the supported public JAX API data types and modules.

Supported data types

ROCm supports all JAX data types of the jax.dtypes module, jax.numpy.dtype, and default_dtype. The following table lists the data types that ROCm supports in JAX.

Data type   Description
----------  -------------------------------------------
bfloat16    16-bit bfloat (brain floating point).
bool        Boolean.
complex128  128-bit complex.
complex64   64-bit complex.
float16     16-bit (half precision) floating-point.
float32     32-bit (single precision) floating-point.
float64     64-bit (double precision) floating-point.
half        16-bit (half precision) floating-point.
int16       Signed 16-bit integer.
int32       Signed 32-bit integer.
int64       Signed 64-bit integer.
int8        Signed 8-bit integer.
uint16      Unsigned 16-bit (word) integer.
uint32      Unsigned 32-bit (dword) integer.
uint64      Unsigned 64-bit (qword) integer.
uint8       Unsigned 8-bit (byte) integer.

Note

JAX data type support depends on the libraries listed in Key ROCm libraries for JAX; the details are collected on the ROCm data types and precision support page.
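A minimal sketch for checking the data type table on a given system: it allocates a small array of each listed type on the default device and reports the resulting dtype and whether its element size matches. One JAX detail matters here: by default JAX disables 64-bit types and silently falls back to 32-bit ones, so the sketch enables the jax_enable_x64 flag first to get genuine float64, int64, uint64, and complex128 arrays. half is omitted because it is another name for float16. The DTYPES mapping and the check_dtypes helper are illustrative names, not part of JAX.

```python
import jax
import jax.numpy as jnp

# JAX uses 32-bit types unless x64 mode is enabled; turn it on before creating arrays.
jax.config.update("jax_enable_x64", True)

# Hypothetical mapping: table name -> (JAX scalar type, expected size in bytes).
DTYPES = {
    "bfloat16": (jnp.bfloat16, 2),
    "bool": (jnp.bool_, 1),
    "complex128": (jnp.complex128, 16),
    "complex64": (jnp.complex64, 8),
    "float16": (jnp.float16, 2),
    "float32": (jnp.float32, 4),
    "float64": (jnp.float64, 8),
    "int16": (jnp.int16, 2),
    "int32": (jnp.int32, 4),
    "int64": (jnp.int64, 8),
    "int8": (jnp.int8, 1),
    "uint16": (jnp.uint16, 2),
    "uint32": (jnp.uint32, 4),
    "uint64": (jnp.uint64, 8),
    "uint8": (jnp.uint8, 1),
}


def check_dtypes():
    # Allocate a 4-element array of each type on the default device and record
    # the dtype JAX actually produced plus whether its byte size matches.
    report = {}
    for name, (scalar_type, expected_bytes) in DTYPES.items():
        x = jnp.zeros((4,), dtype=scalar_type)
        report[name] = (str(x.dtype), x.dtype.itemsize == expected_bytes)
    return report


for name, (actual, size_ok) in check_dtypes().items():
    print(f"{name:>10}: dtype={actual}, size_ok={size_ok}")
```

If x64 mode were left off, the 64-bit rows would come back as their 32-bit counterparts, which is the most common source of apparent "missing" float64 support.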

Supported modules

For a complete and up-to-date list of JAX public modules (for example, jax.numpy, jax.scipy, jax.lax), their descriptions, and usage, see the official JAX API documentation.

Note

JAX has had full support for ROCm since version 0.1.56. For limitations specific to the ROCm backend, see the Known issues and important notes section. The list of JAX API modules is maintained by the JAX project and is subject to change; refer to the official JAX documentation for the most up-to-date information.
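For orientation only, the following sketch touches the three modules named above: jax.numpy for NumPy-style array math, jax.lax for lower-level primitives (here lax.scan for a running sum), and jax.scipy for SciPy-style routines (logsumexp and a dense linear solve). The inputs and the helper name accumulate are made up for illustration; the official JAX API documentation describes the full surface of each module.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy.linalg import solve
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# jax.numpy: NumPy-style array operations.
mean = jnp.mean(x)


# jax.lax: scan threads a carry through the sequence; here it builds a running sum.
def accumulate(carry, xi):
    carry = carry + xi
    return carry, carry  # (new carry, per-step output)


_, running_sum = lax.scan(accumulate, jnp.zeros((), dtype=x.dtype), x)

# jax.scipy: SciPy-style routines implemented on top of jax.numpy and jax.lax.
lse = logsumexp(x)                      # log(sum(exp(x))), computed stably
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
solution = solve(A, b)                  # solves A @ solution == b

print(mean, running_sum, lse, solution)
```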