PyTorch compatibility

2025-06-18

Applies to Linux

PyTorch is an open-source tensor library designed for deep learning. PyTorch on ROCm provides mixed-precision and large-scale training using the MIOpen and RCCL libraries.

ROCm support for PyTorch is upstreamed into the official PyTorch repository. Because ROCm and PyTorch have independent compatibility considerations, PyTorch on ROCm follows two distinct release cycles: the ROCm PyTorch release and the official PyTorch release.

PyTorch includes tooling that generates HIP source code from the CUDA backend. This approach allows PyTorch to support ROCm without requiring manual code modifications. For more information, see HIPIFY.

ROCm development is aligned with the stable release of PyTorch, while upstream PyTorch testing uses the stable release of ROCm to maintain consistency.
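The HIPIFY-generated backend keeps PyTorch's CUDA-facing API. As a result, ROCm builds expose AMD GPUs through `torch.cuda` and the `"cuda"` device string. The minimal sketch below shows one way to check which backend a build targets. `describe_backend` is a hypothetical helper name. `torch.version.hip` holds a version string only on ROCm builds and is `None` otherwise.

```python
import torch


def describe_backend() -> dict:
    """Summarize which GPU backend this PyTorch build targets."""
    hip = torch.version.hip    # version string on ROCm builds, None otherwise
    cuda = torch.version.cuda  # version string on CUDA builds, None otherwise
    if hip is not None:
        backend = "rocm"
    elif cuda is not None:
        backend = "cuda"
    else:
        backend = "cpu"
    return {
        "backend": backend,
        "hip_version": hip,
        # On ROCm builds, the torch.cuda API is backed by HIP and reports AMD GPUs
        "gpu_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }


info = describe_backend()
print(info)

# The same "cuda" device string addresses an AMD GPU on a ROCm build
device = torch.device("cuda" if info["gpu_available"] else "cpu")
x = torch.ones(3, device=device)
print(x.device)
```

Because the device string is unchanged, most existing CUDA-targeted PyTorch scripts run on ROCm without edits.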

Use cases and recommendations

  • Using ROCm for AI: training a model explains how to use the ROCm platform to train AI models. It covers the steps, tools, and best practices for optimizing training workflows on AMD GPUs using PyTorch features.

  • Single-GPU fine-tuning and inference describes and demonstrates how to use the ROCm platform for the fine-tuning and inference of machine learning models, particularly large language models (LLMs), on systems with a single GPU. This topic provides a detailed guide for setting up, optimizing, and executing fine-tuning and inference workflows in such environments.

  • Multi-GPU fine-tuning and inference optimization describes and demonstrates the fine-tuning and inference of machine learning models on systems with multiple GPUs.

  • The Instinct MI300X workload optimization guide provides detailed guidance on optimizing workloads for the AMD Instinct MI300X accelerator using ROCm. This guide helps users achieve optimal performance for deep learning and other high-performance computing tasks on the MI300X accelerator.

  • The Inception with PyTorch documentation describes how PyTorch integrates with ROCm for AI workloads. It outlines the use of PyTorch on the ROCm platform and focuses on using AMD GPU hardware efficiently for training and inference tasks in AI applications.

For more use cases and recommendations, see ROCm PyTorch blog posts.

Docker image compatibility

AMD validates and publishes PyTorch images with ROCm backends on Docker Hub. The following Docker image tags and associated inventories were tested on ROCm 6.4.1.

Key ROCm libraries for PyTorch

PyTorch functionality on ROCm is determined by its underlying library dependencies. These ROCm components affect the capabilities, performance, and feature set available to developers.

  • Composable Kernel (version 1.1.0)
    Purpose: Enables faster execution of core operations such as matrix multiplication (GEMM), convolutions, and transformations.
    Used in: torch.permute, torch.Tensor.view, torch.matmul, torch.mm, torch.bmm, torch.nn.Conv2d, torch.nn.Conv3d, and torch.nn.MultiheadAttention.

  • hipBLAS (version 2.4.0)
    Purpose: Provides GPU-accelerated Basic Linear Algebra Subprograms (BLAS) for matrix and vector operations.
    Used in: Matrix multiplication, matrix-vector products, and tensor contractions, in both dense and batched linear algebra operations.

  • hipBLASLt (version 0.12.1)
    Purpose: Extends the hipBLAS library with additional features, such as epilogues fused into the matrix multiplication kernel and use of integer tensor cores.
    Used in: torch.matmul, torch.mm, and the matrix multiplications inside convolutional and linear layers.

  • hipCUB (version 3.4.0)
    Purpose: Provides a C++ template library of parallel algorithms for reduction, scan, sort, and select.
    Used in: torch.sum, torch.cumsum, and torch.sort. Operations on irregular shapes often involve scanning, sorting, and filtering, which hipCUB handles efficiently.

  • hipFFT (version 1.0.18)
    Purpose: Provides GPU-accelerated Fast Fourier Transform (FFT) operations.
    Used in: Functions of the torch.fft module.

  • hipRAND (version 2.12.0)
    Purpose: Provides fast random number generation on GPUs.
    Used in: torch.rand, torch.randn, and stochastic layers such as torch.nn.Dropout.

  • hipSOLVER (version 2.4.0)
    Purpose: Provides GPU-accelerated solvers for linear systems, eigenvalue problems, and singular value decomposition (SVD).
    Used in: torch.linalg.solve, torch.linalg.eig, and torch.linalg.svd.

  • hipSPARSE (version 3.2.0)
    Purpose: Accelerates operations on sparse matrices, such as sparse matrix-vector and matrix-matrix products.
    Used in: Sparse tensor operations in torch.sparse.

  • hipSPARSELt (version 0.2.3)
    Purpose: Accelerates matrix-matrix products in which one operand has structured (2:4) sparsity.
    Used in: Sparse tensor operations in torch.sparse.

  • hipTensor (version 1.5.0)
    Purpose: Optimizes high-performance tensor operations, such as contractions.
    Used in: Tensor algebra, especially in deep learning and scientific computing.

  • MIOpen (version 3.4.0)
    Purpose: Optimizes deep learning primitives such as convolutions, pooling, normalization, and activation functions.
    Used in: Convolutional neural networks (CNNs), recurrent neural networks (RNNs), and other layers, including torch.nn.Conv2d, torch.nn.ReLU, and torch.nn.LSTM.

  • MIGraphX (version 2.12.0)
    Purpose: Adds graph-level optimizations, ONNX model support, mixed-precision support, and ahead-of-time (AOT) compilation.
    Used in: Speeding up inference and executing ONNX models for compatibility with other frameworks, including models built from layers such as torch.nn.Conv2d, torch.nn.ReLU, and torch.nn.LSTM.

  • MIVisionX (version 3.2.0)
    Purpose: Accelerates computer vision and AI workloads such as preprocessing, augmentation, and inference.
    Used in: Faster data preprocessing and augmentation pipelines for datasets such as ImageNet or COCO. Integrates with PyTorch's torch.utils.data and torchvision workflows.

  • rocAL (version 2.2.0)
    Purpose: Accelerates the data pipeline by offloading intensive preprocessing and augmentation tasks. rocAL is part of MIVisionX.
    Used in: PyTorch's torch.utils.data and torchvision data loading workflows.

  • RCCL (version 2.22.3)
    Purpose: Optimizes multi-GPU communication for collective operations such as AllReduce and Broadcast.
    Used in: Distributed data parallel training (torch.nn.parallel.DistributedDataParallel), where it handles communication in multi-GPU setups.

  • rocDecode (version 0.10.0)
    Purpose: Provides hardware-accelerated data decoding, particularly for image, video, and other dataset formats.
    Used in: Can be integrated into torch.utils.data, torchvision.transforms, and torch.distributed.

  • rocJPEG (version 0.8.0)
    Purpose: Provides hardware-accelerated JPEG image decoding and encoding.
    Used in: GPU-accelerated torchvision.io.decode_jpeg and torchvision.io.encode_jpeg. Can be integrated into torch.utils.data and torchvision.

  • RPP (version 1.9.10)
    Purpose: Speeds up data augmentation, transformation, and other preprocessing steps.
    Used in: PyTorch's torch.utils.data and torchvision data loading workflows, to speed up data processing.

  • rocThrust (version 3.3.0)
    Purpose: Provides a C++ template library of parallel algorithms such as sorting, reduction, and scanning.
    Used in: Backend operations for tensor computations that require parallel processing.

  • rocWMMA (version 1.7.0)
    Purpose: Accelerates warp-level matrix multiply-accumulate operations, speeding up mixed-precision matrix multiplication (GEMM) and accumulation.
    Used in: Linear layers (torch.nn.Linear), convolutional layers (torch.nn.Conv2d), attention layers, and general tensor operations that involve matrix products, such as torch.matmul and torch.bmm.
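The sketch below calls several of the PyTorch operations listed above. It runs on the GPU when one is available and otherwise on the CPU backends. Each comment names the library that the list above associates with that operation. Treat this mapping as indicative only: the library a call actually dispatches to depends on the PyTorch build, tensor shapes, dtypes, and backend settings.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

# Random tensors (hipRAND on ROCm, per the list above)
a = torch.randn(64, 64, device=device)
b = torch.randn(64, 64, device=device)

# Dense GEMM (hipBLAS / hipBLASLt on ROCm)
c = a @ b

# Linear solve (hipSOLVER on ROCm): find x such that a @ x == b
x = torch.linalg.solve(a, b)

# FFT round trip (hipFFT on ROCm)
signal = torch.randn(128, device=device)
recovered = torch.fft.ifft(torch.fft.fft(signal)).real

# Scan and sort (hipCUB / rocThrust on ROCm)
prefix = torch.cumsum(torch.arange(1, 6, device=device), dim=0)
sorted_vals, _ = torch.sort(torch.tensor([3, 1, 2], device=device))

print(c.shape, prefix.tolist(), sorted_vals.tolist())
```

Profiling this script with a GPU profiler is one way to see which ROCm library kernels a particular build launches.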

Supported modules and data types

The following section outlines the supported data types, modules, and ___domain libraries available in PyTorch on ROCm.

Supported data types

The tensor data type is specified using the dtype attribute or argument. PyTorch supports many data types for different use cases.

The following list shows the torch.Tensor data types:

  • torch.float8_e4m3fn: 8-bit floating point, e4m3 (4 exponent bits, 3 mantissa bits)

  • torch.float8_e5m2: 8-bit floating point, e5m2 (5 exponent bits, 2 mantissa bits)

  • torch.float16 or torch.half: 16-bit floating point (IEEE 754 half precision)

  • torch.bfloat16: 16-bit brain floating point (8 exponent bits, 7 mantissa bits)

  • torch.float32 or torch.float: 32-bit floating point

  • torch.float64 or torch.double: 64-bit floating point

  • torch.complex32 or torch.chalf: 32-bit complex numbers

  • torch.complex64 or torch.cfloat: 64-bit complex numbers

  • torch.complex128 or torch.cdouble: 128-bit complex numbers

  • torch.uint8: 8-bit integer (unsigned)

  • torch.uint16: 16-bit integer (unsigned); not natively supported in ROCm

  • torch.uint32: 32-bit integer (unsigned); not natively supported in ROCm

  • torch.uint64: 64-bit integer (unsigned); not natively supported in ROCm

  • torch.int8: 8-bit integer (signed)

  • torch.int16 or torch.short: 16-bit integer (signed)

  • torch.int32 or torch.int: 32-bit integer (signed)

  • torch.int64 or torch.long: 64-bit integer (signed)

  • torch.bool: Boolean

  • torch.quint8: Quantized 8-bit integer (unsigned)

  • torch.qint8: Quantized 8-bit integer (signed)

  • torch.qint32: Quantized 32-bit integer (signed)

  • torch.quint4x2: Quantized 4-bit integer (unsigned)

Note

Unsigned types, except uint8, have limited support in eager mode. They exist primarily to support use with torch.compile.

See ROCm precision support for the native hardware support of data types.
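The example below shows how a dtype is set as an argument and read back as an attribute. It sticks to a few widely available types from the list above and checks their per-element storage size. The set of types inspected is an arbitrary choice made for this sketch.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dtype passed as an argument to a factory function ...
x = torch.zeros(4, dtype=torch.bfloat16, device=device)
# ... and read back through the dtype attribute
print(x.dtype)  # torch.bfloat16

# Converting an existing tensor to another data type
y = torch.arange(4, device=device).to(torch.float16)

# Aliases from the list above refer to the same dtype object
assert torch.half is torch.float16
assert torch.cfloat is torch.complex64

# Bytes per element for a few of the listed types
sizes = {
    dt: torch.empty(0, dtype=dt).element_size()
    for dt in (torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.complex64)
}
print(sizes)
```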

Supported modules

For a complete and up-to-date list of PyTorch core modules (for example, torch, torch.nn, torch.cuda, torch.backends.cuda, and torch.backends.cudnn), their descriptions, and usage, see the official PyTorch documentation.

Core PyTorch functionality on ROCm includes tensor operations, neural network layers, automatic differentiation, distributed training, mixed-precision training, compilation features, and ___domain-specific libraries for audio, vision, text processing, and more.
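The sketch below combines distributed training and mixed precision in a single training step. On ROCm builds, PyTorch's `"nccl"` distributed backend name is served by RCCL, so the code is the same as on CUDA. To run without a GPU, the sketch uses a one-process "world" and falls back to the `gloo` backend and CPU autocast. The following are assumptions for illustration:

- the helper name `train_step_single_process`
- the file-based rendezvous
- float16 on the GPU versus bfloat16 on the CPU

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_step_single_process() -> tuple:
    use_gpu = torch.cuda.is_available()
    # On ROCm builds the "nccl" backend name is backed by RCCL; gloo is the CPU fallback
    backend = "nccl" if use_gpu else "gloo"
    device = torch.device("cuda:0" if use_gpu else "cpu")
    amp_dtype = torch.float16 if use_gpu else torch.bfloat16

    # File-based rendezvous for a one-process world (illustration only)
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(backend=backend, init_method=f"file://{init_file}",
                            rank=0, world_size=1)
    try:
        model = nn.Linear(16, 4).to(device)
        ddp_model = DDP(model, device_ids=[0] if use_gpu else None)
        opt = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

        inputs = torch.randn(8, 16, device=device)
        targets = torch.randn(8, 4, device=device)

        # Mixed precision: eligible ops such as the linear layer run in amp_dtype
        with torch.autocast(device_type=device.type, dtype=amp_dtype):
            out = ddp_model(inputs)
        loss = nn.functional.mse_loss(out.float(), targets)

        opt.zero_grad()
        loss.backward()  # DDP all-reduces gradients across ranks during backward
        opt.step()
        return out.dtype, loss.item()
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(train_step_single_process())
```

In a real multi-GPU job you would typically launch one process per GPU, for example with torchrun. For float16 training you would also add a gradient scaler (torch.amp.GradScaler in recent PyTorch releases).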

Supported ___domain libraries

PyTorch offers specialized ___domain libraries with GPU acceleration that build on its core features to support specific application areas. The following list shows the PyTorch ___domain libraries that are compatible with ROCm.

  • torchaudio: Audio and signal processing library for PyTorch. Provides utilities for audio I/O, signal and data processing functions, datasets, model implementations, and application components for audio and speech processing tasks.
    Note: To get GPU acceleration with torchaudio.transforms, explicitly move the audio data (waveform tensor) to the GPU using .to('cuda'), and move the transform module to the same device.

  • torchtune: PyTorch-native library for fine-tuning large language models (LLMs). Supports the full fine-tuning workflow and is compatible with popular production inference systems.
    Note: Only the official release exists.

  • torchvision: Computer vision library that is part of the PyTorch project. Provides popular datasets, model architectures, and common image transformations for computer vision applications.

  • torchtext: Text processing library for PyTorch. Provides data processing utilities and popular datasets for natural language processing, including tokenization, vocabulary management, and text embeddings.
    Note: torchtext does not implement ROCm-specific kernels. ROCm acceleration is provided through the underlying PyTorch framework and ROCm library integration. Only the official release exists.

  • torchdata: Beta library of common modular data loading primitives for building flexible and performant data pipelines. Some of its features are still in the prototype stage.

  • torchrec: PyTorch ___domain library for common sparsity and parallelism primitives needed for large-scale recommender systems. It lets authors train models with large embedding tables sharded across many GPUs.
    Note: torchrec does not implement ROCm-specific kernels. ROCm acceleration is provided through the underlying PyTorch framework and ROCm library integration.

  • torchserve: Performant, flexible, and easy-to-use tool for serving PyTorch models in production. Provides features for model management, batch processing, and scalable deployment.
    Note: torchserve is no longer actively maintained. The last official release shipped with PyTorch 2.4.

  • torchrl: Open-source, Python-first reinforcement learning library for PyTorch with a focus on high modularity and good runtime performance. Provides low- and high-level RL abstractions and reusable functionals for cost functions, returns, and data processing.
    Note: Only the official release exists.

  • tensordict: Dictionary-like class that simplifies operations on batches of tensors. It improves code readability, compactness, and modularity by abstracting tailored operations, and reduces errors through automatic operation dispatching.
    Note: Only the official release exists.