Train LLMs on AMD GPU: ROCm, MI300X/MI355X, ZAYA1 (2026)

Zyphra trained ZAYA1 entirely on AMD Instinct hardware, released it under Apache 2.0, and published enough about the stack to make the "can AMD do pretraining" conversation shift from "theoretically possible" to "someone already did it." CUDA has been the default for foundation-model training not because ROCm was categorically incapable, but because of ecosystem depth: FlashAttention 3, NCCL maturity, Megatron's kernel library, and years of battle-tested tooling. ZAYA1 proved you don't need all of that to train a production model.

This guide covers what the AMD pretraining stack actually looks like in 2026: hardware specs for MI300X and MI355X, the ROCm software stack with CUDA equivalents, how to port a training script, where the gotchas still live, cost math against NVIDIA at the same token budget, and how to run multi-node AMD training on Spheron.

One clarification up front: ZAYA1 refers specifically to ZAYA1-8B, a mixture-of-experts model with 8.4B total parameters and 760M active. Zyphra trained it on 1,024 AMD Instinct MI300X GPUs in a cluster co-built with IBM, using AMD Pensando Pollara networking. Per-GPU throughput figures are not publicly reported, so the throughput estimates in this guide are based on community benchmarks.

Why CUDA Lock-In Was Real (and What ZAYA1 Changed)

The CUDA moat in pretraining came from a stack of compounding dependencies. FlashAttention 3 is Hopper-specific and delivers 1.5-2x attention throughput over its predecessor. NCCL has years of tuning for distributed collectives. Megatron-Core's fused attention kernels and optimized data loaders are CUDA-native. The NVIDIA container registry ships ready-to-run images with everything pre-compiled.

On ROCm in 2023-2024, the equivalent pieces existed but with friction: Flash Attention 2 (ROCm/CK build) worked but needed manual configuration, RCCL lagged NCCL in throughput on some topologies, and torch.compile() on ROCm fell back from TorchInductor to eager execution more often. Teams with tight throughput requirements or custom CUDA kernels had real reasons to stay on NVIDIA.

What ZAYA1 changed is the risk perception. When a company trains a foundation model on AMD hardware, ships it under a permissive license, and it works, it removes the "unproven" label from the stack. Teams evaluating AMD for cost reasons now have a concrete proof point. For a broader comparison of what the ROCm stack looks like for inference workloads, see ROCm vs CUDA: AMD vs NVIDIA AI Stack Compared.

The 2026 AMD Pretraining Stack

The table below maps every major training stack component to its CUDA equivalent and its ROCm equivalent:

| Component | CUDA | ROCm | Notes |
|---|---|---|---|
| Container | NGC CUDA images | rocm/pytorch:rocm6.2_* | AMD's Docker Hub has versioned images per ROCm+PyTorch combo |
| PyTorch wheel | pip install torch --index-url https://download.pytorch.org/whl/cu121 | pip install torch --index-url https://download.pytorch.org/whl/rocm6.2 | Different --index-url only |
| Collective comms | NCCL | RCCL | API-compatible drop-in; reads the same NCCL_* env vars |
| Attention | FlashAttention 3 (Hopper-only) | Flash Attention 2 (ROCm/CK build) | FA3 not available on ROCm; FA2 closes most of the gap |
| Model parallelism | Megatron-Core (full CUDA) | Megatron-Core (ROCm partial) | DP and PP fully supported; some fused attention kernels fall back |
| ZeRO optimizer | DeepSpeed ZeRO-1/2/3 | DeepSpeed ZeRO-1/2/3 (DS v0.6+) | Full ROCm support |
| Distributed launch | torchrun, SLURM + srun | torchrun, SLURM + srun | Identical launch syntax |
| GPU monitoring | nvidia-smi | rocm-smi | rocm-smi --showmeminfo vram for memory |
| Profiling | Nsight Systems / Compute | rocprof, Omniperf | ROCm profiler ecosystem is less mature but functional |
| Device string | cuda / CUDA_VISIBLE_DEVICES | cuda (mapped to HIP in PyTorch ROCm builds) / HIP_VISIBLE_DEVICES | HIP also honors CUDA_VISIBLE_DEVICES; no model-code changes needed |

The PyTorch device-string mapping is the key insight: tensor.cuda() and .to("cuda") work without changes on ROCm because PyTorch's ROCm build routes the cuda device type and the torch.cuda APIs to HIP. You get an AMD GPU without touching your model code. The friction is at the edges: custom CUDA kernels, FlashAttention 3 dependencies, and any code that calls into CUDA-specific C extensions.
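One place that edge friction shows up is the attention backend flag. The sketch below picks a Hugging Face attn_implementation string from what the current build supports. It assumes you pass in torch.version.hip (a version string on ROCm builds, None on CUDA builds) plus two availability flags you determine yourself; the function name pick_attn_implementation is hypothetical.

```python
from typing import Optional


def pick_attn_implementation(hip_version: Optional[str], has_fa3: bool, has_fa2: bool) -> str:
    """Return an attn_implementation string for transformers' from_pretrained().

    hip_version: torch.version.hip (a string on ROCm builds, None on CUDA builds).
    has_fa3 / has_fa2: whether the FlashAttention 3 / 2 packages import cleanly.
    """
    on_rocm = hip_version is not None
    # FA3 is Hopper-only, so never request it on a ROCm build.
    if not on_rocm and has_fa3:
        return "flash_attention_3"
    # On ROCm this resolves to the CK build of Flash Attention 2.
    if has_fa2:
        return "flash_attention_2"
    # PyTorch's built-in scaled_dot_product_attention works on both stacks.
    return "sdpa"
```

On a ROCm node you would call pick_attn_implementation(torch.version.hip, has_fa3=False, has_fa2=True) and pass the result as attn_implementation= to AutoModelForCausalLM.from_pretrained, so the same script runs on both vendors.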

For the parallelism strategy decisions (FSDP2 vs ZeRO-3 vs Megatron 3D), see the distributed LLM training guide, which covers all three in depth with multi-node torchrun configs.

MI300X vs MI355X: Which Hardware for Pretraining

| Spec | MI300X | MI355X |
|---|---|---|
| HBM memory | 192 GB HBM3 | 288 GB HBM3e |
| Memory bandwidth | 5.3 TB/s | 8 TB/s |
| FP16 TFLOPS | ~1,307 dense (2,615 with sparsity) | ~2,500 dense (5,000 with sparsity) |
| FP8 training | Available | Available |
| UALink support | No | No (Helios rack only) |

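The 192 GB memory on MI300X is the most practically significant number. A 70B model in BF16 weighs 140 GB. On an 80 GB NVIDIA H100, those weights alone do not fit on one GPU without quantization or model parallelism; on MI300X they fit with 52 GB to spare. Training needs more than the weights, though: BF16 gradients plus FP32 master weights and Adam moments bring mixed-precision training state to roughly 16 bytes per parameter, or about 1.1 TB for 70B. Fully sharded across one 8×MI300X node, that is ~140 GB per GPU, while 80 GB H100s need at least two nodes. Fewer shards mean less tensor parallelism and less inter-node traffic, which is a meaningful throughput advantage.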
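The memory arithmetic is easy to redo for other model sizes. This sketch uses the common 16-bytes-per-parameter accounting for mixed-precision Adam (BF16 weights and gradients, plus FP32 master weights and two FP32 Adam moments, as in the ZeRO paper) and ignores activations, so treat the GPU counts as lower bounds. The function names are illustrative.

```python
import math

# BF16 weights (2) + BF16 grads (2) + FP32 master weights (4) + Adam m (4) + Adam v (4)
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4


def training_state_gb(num_params: int, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Total weights + gradients + optimizer state in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_gpu_state_gb(num_params: int, num_gpus: int) -> float:
    """Per-GPU share when everything is fully sharded (FSDP FULL_SHARD / ZeRO-3)."""
    return training_state_gb(num_params) / num_gpus


def min_gpus(num_params: int, gpu_mem_gb: float, activation_headroom_gb: float = 0.0) -> int:
    """Smallest GPU count whose sharded state fits, leaving headroom for activations."""
    usable = gpu_mem_gb - activation_headroom_gb
    if usable <= 0:
        raise ValueError("activation headroom exceeds GPU memory")
    return math.ceil(training_state_gb(num_params) / usable)


if __name__ == "__main__":
    params = 70_000_000_000
    print(f"70B training state: {training_state_gb(params):,.0f} GB")
    print(f"per MI300X across 8 GPUs: {per_gpu_state_gb(params, 8):.0f} GB")
    print(f"min H100 80 GB: {min_gpus(params, 80)}, min MI300X 192 GB: {min_gpus(params, 192)}")
```

For 70B this gives 1,120 GB of state, 140 GB per GPU across eight MI300X, and a floor of 14 H100s versus 6 MI300X before any activation headroom; raising activation_headroom_gb shows how quickly the H100 count climbs past two nodes.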

When to use MI300X:

  • 7B to 70B pretraining where the model fits in 192 GB
  • Cost-optimized runs: MI300X market rates run roughly 55-60% lower than H100 SXM5 on-demand
  • Teams migrating from multi-GPU NVIDIA setups and looking to reduce per-node count

When to use MI355X:

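  • 70B+ models where you want FSDP sharding to stay inside one node: roughly 1.1 TB of mixed-precision training state for 70B (about 16 bytes per parameter) leaves little activation headroom on an 8×MI300X node (1.5 TB) but fits more comfortably on 8×MI355X (2.3 TB)
  • Long-sequence pretraining (128K+ tokens), where 288 GB per GPU holds more activation memory and 8 TB/s speeds up memory-bound kernels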
  • Pre-production scale runs where you want to minimize inter-node communication overhead

For a more detailed hardware spec comparison between AMD and NVIDIA silicon, the AMD MI300X vs NVIDIA H200 guide covers the hardware in depth.

Porting a Training Run from CUDA to ROCm

Container and device flags

CUDA:

bash
docker run --gpus all \
  -it nvcr.io/nvidia/pytorch:24.01-py3

ROCm:

bash
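docker run \
  --device=/dev/kfd \
  --device=/dev/dri \
  --group-add video \
  --security-opt seccomp=unconfined \
  -it rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_2.4.0

The --device=/dev/kfd and --device=/dev/dri flags expose the AMD GPU compute driver (KFD) and the DRM render nodes instead of the NVIDIA device nodes, and --group-add video gives the container user permission to open them. AMD's Docker examples also include --security-opt seccomp=unconfined; keep it unless you have validated a tighter seccomp profile for ROCm.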
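A quick way to confirm those flags took effect is to check for the device nodes from inside the container before launching anything heavier. This is a minimal sketch: the exists parameter is only there so the check can be tested without real hardware, and rocm-smi remains the fuller check.

```python
import os
from typing import Callable, List

# Device nodes a ROCm container needs: the KFD compute interface and the DRM render nodes
REQUIRED_DEVICE_NODES = ("/dev/kfd", "/dev/dri")


def missing_rocm_device_nodes(exists: Callable[[str], bool] = os.path.exists) -> List[str]:
    """Return the required device nodes that are not visible in this container."""
    return [node for node in REQUIRED_DEVICE_NODES if not exists(node)]
```

Calling missing_rocm_device_nodes() at the top of a training entrypoint and aborting on a non-empty result turns a missing --device flag into an immediate, readable error instead of a confusing "no GPU found" later.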

Environment variables

bash
# CUDA version
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0

# ROCm equivalent
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7    # native HIP device mask (HIP also honors CUDA_VISIBLE_DEVICES)
export NCCL_SOCKET_IFNAME=eth0                 # same variable, RCCL reads it
export NCCL_DEBUG=INFO                         # RCCL keeps NCCL's debug variable

RCCL reads the same NCCL_* environment variables as NCCL, and HIP on AMD GPUs honors CUDA_VISIBLE_DEVICES, so an existing torchrun launch script usually runs unchanged. HIP_VISIBLE_DEVICES is the native variable if you want the device mask to be explicit about the platform.
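If you launch torchrun from a Python driver rather than a shell script, the same settings can be assembled as an environment dict. A minimal sketch, assuming eth0 as the interface (as in the shell example) and a hypothetical helper name; it sets HIP_VISIBLE_DEVICES, the native HIP device mask, and drops CUDA_VISIBLE_DEVICES so only one mask applies.

```python
import os
from typing import Dict, Iterable, Mapping, Optional


def rocm_launch_env(
    gpu_ids: Iterable[int],
    socket_ifname: str = "eth0",
    debug: bool = True,
    base: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Build an environment for a torchrun subprocess on a ROCm node."""
    env = dict(os.environ if base is None else base)
    # HIP also honors CUDA_VISIBLE_DEVICES; drop it so only one device mask is in effect.
    env.pop("CUDA_VISIBLE_DEVICES", None)
    env["HIP_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    # RCCL reads the NCCL_* variables.
    env["NCCL_SOCKET_IFNAME"] = socket_ifname
    if debug:
        env["NCCL_DEBUG"] = "INFO"
    return env
```

Usage would look like subprocess.run(["torchrun", "--nproc_per_node=8", "pretrain.py"], env=rocm_launch_env(range(8)), check=True), where pretrain.py stands in for your own training script.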

Custom CUDA kernel porting

bash
# Convert a custom CUDA kernel to HIP
hipify-perl my_cuda_kernel.cu > my_hip_kernel.hip

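# Rebuild with hipcc instead of nvcc as a shared library
# (gfx942 targets MI300X; MI355X is gfx950)
hipcc -O3 -fPIC -shared --offload-arch=gfx942 -o my_hip_kernel.so my_hip_kernel.hip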

hipify-perl handles most CUDA API calls automatically. What it cannot convert: PTX assembly (requires a full rewrite as GCN/CDNA assembly) and CUDA intrinsics with no HIP equivalent. Direct cuBLAS/cuDNN calls are only partly covered, so review those by hand and target rocBLAS (or hipBLAS) and MIOpen.
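Before running hipify-perl over a large kernel tree, it can help to list the spots it will not handle. The sketch below is a rough regex pass, not a parser: it flags inline PTX asm statements and direct cuBLAS/cuDNN calls line by line, so expect false positives in comments and strings. The pattern list is an assumption to extend for your codebase.

```python
import re
from typing import List, Tuple

# Constructs that need manual porting or review after hipify-perl
PATTERNS = {
    "inline PTX": re.compile(r"\b(?:asm|__asm__)\s*(?:volatile|__volatile__)?\s*\("),
    "cuBLAS call": re.compile(r"\bcublas\w*\s*\("),
    "cuDNN call": re.compile(r"\bcudnn\w*\s*\("),
}


def audit_cuda_source(text: str) -> List[Tuple[int, str]]:
    """Return (line_number, label) pairs for lines that need manual attention."""
    findings = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        for label, pattern in PATTERNS.items():
            if pattern.search(line):
                findings.append((lineno, label))
    return findings
```

Pointing it at every .cu file (for example via pathlib.Path("csrc").rglob("*.cu"), with csrc as a placeholder directory) gives a rough porting worklist before you commit to a timeline.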

DeepSpeed ROCm install

bash
# Standard install: C++/HIP ops are JIT-compiled the first time they are used
pip install deepspeed

# Pre-build ops at install time (no JIT delay at startup; requires the ROCm dev toolchain)
DS_BUILD_OPS=1 pip install deepspeed
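The DeepSpeed configuration itself is the same on ROCm as on CUDA. A minimal ZeRO-3 config for BF16 pretraining might look like the sketch below; the batch sizes and clipping value are placeholder assumptions, while the keys are standard DeepSpeed config fields.

```python
import json


def zero3_config(micro_batch_per_gpu: int, grad_accum_steps: int) -> dict:
    """Minimal DeepSpeed ZeRO-3 config for BF16 training (values are placeholders)."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "gradient_clipping": 1.0,
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,                  # shard params, grads, and optimizer state
            "overlap_comm": True,        # overlap RCCL/NCCL collectives with compute
            "contiguous_gradients": True,
            "stage3_gather_16bit_weights_on_model_save": True,
        },
    }


if __name__ == "__main__":
    print(json.dumps(zero3_config(micro_batch_per_gpu=4, grad_accum_steps=8), indent=2))
```

The dict can be passed directly as config= to deepspeed.initialize, or dumped to a JSON file for the deepspeed launcher; nothing in it is ROCm-specific.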

Gotchas that will burn time

1. FlashAttention 3 is Hopper-specific. If your training script sets attn_implementation="flash_attention_3", it will fail on ROCm. Use attn_implementation="flash_attention_2" instead. FA2 (ROCm/CK build) is available and delivers most of the throughput benefit of FA3 for sequence lengths under 8K.

2. bitsandbytes ROCm support is partial. BnB's 4-bit quantization ops rely on CUDA kernels. If you need quantized fine-tuning (QLoRA-style), use LoRA adapters on a GPTQ- or AWQ-quantized base model instead. Full-precision and BF16 pretraining on ROCm does not need BnB.

3. torch.compile() on ROCm falls back more often. The compiler works, but compilation takes longer and TorchInductor sends more ops back to eager mode than on CUDA, as of 2026. Benchmark your throughput with and without torch.compile() before committing to a long pretraining run.

4. Custom PTX assembly requires manual rewrite. hipify-perl will not help with PTX. If your pretraining stack uses PTX-level attention kernels or memory access patterns, those need full rewrites as CDNA assembly.

5. FP8 pretraining numerics may differ. MI300X and MI355X support FP8 training via ROCm, but the FP8 number representation and rounding behavior may differ from H100 FP8. If you are porting an FP8 pretraining recipe from H100, validate loss curves carefully for the first 500-1000 steps before assuming it is equivalent.
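For the FP8 check in item 5, one simple approach is to log per-step loss from a short BF16 reference run and from the FP8 run (same seed and data order), smooth both, and flag the first step where they drift apart. A sketch under those assumptions; the 50-step window and 2% tolerance are arbitrary starting points, not validated thresholds.

```python
from typing import List, Optional, Sequence


def trailing_mean(values: Sequence[float], window: int) -> List[float]:
    """Mean of the last `window` values at each step (shorter window at the start)."""
    out, total = [], 0.0
    for i, v in enumerate(values):
        total += v
        if i >= window:
            total -= values[i - window]
        out.append(total / min(i + 1, window))
    return out


def first_divergent_step(
    reference: Sequence[float],
    candidate: Sequence[float],
    window: int = 50,
    rel_tol: float = 0.02,
) -> Optional[int]:
    """First step where smoothed candidate loss differs from reference by > rel_tol."""
    n = min(len(reference), len(candidate))
    ref_s = trailing_mean(reference[:n], window)
    cand_s = trailing_mean(candidate[:n], window)
    for step in range(n):
        if abs(cand_s[step] - ref_s[step]) > rel_tol * abs(ref_s[step]):
            return step
    return None
```

Smoothing matters because single-step losses are noisy; comparing trailing means avoids flagging ordinary batch-to-batch jitter while still catching a sustained offset early in the 500-1000 step window.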

For context on custom kernel development in CUDA that would need porting, see the CUDA 13 tile programming guide.

Cost Comparison: AMD vs NVIDIA at the Same Token Budget

AMD MI300X is not currently listed in Spheron's GPU inventory, so the AMD pricing below reflects market-rate estimates from providers offering MI-series capacity. The NVIDIA prices come from the Spheron API (fetched 25 Jun 2026):

| GPU | On-demand $/hr | Spot $/hr | Memory | Est. BF16 tokens/sec/GPU |
|---|---|---|---|---|
| H100 SXM5 | $5.01 | $2.91 | 80 GB HBM3 | 2,500-3,000 |
| H200 SXM5 | $5.55 | $3.31 | 141 GB HBM3e | 2,800-3,400 |
| A100 80G SXM4 | $1.69 | $0.82 | 80 GB HBM2e | 1,400-1,800 |
| MI300X (market est.) | ~$1.80-2.30 | varies | 192 GB HBM3 | 2,200-2,800 |
| MI355X (market est.) | ~$2.50-3.20 | varies | 288 GB HBM3e | 2,600-3,200 |

Pricing fluctuates based on GPU availability. The prices above are as of 25 Jun 2026 and may have changed. Check current GPU pricing for live rates.

Cost-per-1T-tokens math:

The formula: (1T / (throughput * 3600)) * $/hr gives the total cost of 1T tokens, where throughput is tokens/sec per GPU and $/hr is the per-GPU rate. The first term is the GPU-hours the token budget requires, so the total is the same whether one GPU or a whole cluster does the work (assuming per-GPU throughput holds as you scale out).

For H100 SXM5 at 2,750 tokens/sec and $5.01/hr:

  • (1,000,000,000,000 / (2750 * 3600)) * 5.01 = ~$506,000. An 8-GPU node finishes ~8× faster at 8× the hourly rate, so GPU-hours and total cost stay at ~$506,000.

For MI300X at 2,500 tokens/sec and $2.05/hr (midpoint market estimate):

  • (1,000,000,000,000 / (2500 * 3600)) * 2.05 = ~$228,000, again independent of how many GPUs share the work.
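The same arithmetic as a small helper, useful for plugging in your own measured throughput and quoted rates. It assumes per-GPU throughput stays constant as you add GPUs (perfect scaling), which real multi-node runs only approximate; the function names are illustrative.

```python
from typing import Tuple

SECONDS_PER_HOUR = 3600


def token_budget_cost(
    total_tokens: float,
    tokens_per_sec_per_gpu: float,
    usd_per_gpu_hour: float,
    num_gpus: int = 1,
) -> Tuple[float, float]:
    """Return (wall_clock_hours, total_usd) for a token budget."""
    gpu_hours = total_tokens / (tokens_per_sec_per_gpu * SECONDS_PER_HOUR)
    wall_clock_hours = gpu_hours / num_gpus   # more GPUs finish sooner...
    total_usd = gpu_hours * usd_per_gpu_hour  # ...but the GPU-hour bill is unchanged
    return wall_clock_hours, total_usd


def usd_per_million_tokens(tokens_per_sec_per_gpu: float, usd_per_gpu_hour: float) -> float:
    """Cost of one million tokens at a given per-GPU throughput and hourly rate."""
    return usd_per_gpu_hour / (tokens_per_sec_per_gpu * SECONDS_PER_HOUR / 1e6)


if __name__ == "__main__":
    _, h100 = token_budget_cost(1e12, 2750, 5.01, num_gpus=8)
    _, mi300x = token_budget_cost(1e12, 2500, 2.05, num_gpus=8)
    print(f"H100: ${h100:,.0f}  MI300X: ${mi300x:,.0f}  saving: {1 - mi300x / h100:.0%}")
```

Running it reproduces the ~$506,000 and ~$228,000 figures (a ~55% saving); swapping in your own benchmarked tokens/sec is the part that actually decides the comparison.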

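The AMD advantage comes from two directions: lower $/hr at comparable throughput, and fewer GPUs needed to hold large models thanks to 192+ GB per GPU. A 70B model that needs 2 H100 nodes for FSDP sharding may fit on 1 MI300X node, halving the minimum cluster size (and its hourly bill) for that workload. Total cost per token still follows per-GPU throughput, so the per-token saving from that consolidation comes from less inter-node communication overhead, not from the smaller node count by itself.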

The caveats: these MI300X throughput estimates are based on community benchmarks, not controlled Spheron measurements. Real throughput depends on framework version, attention implementation, and your specific model architecture. The token budget context for why 1T tokens matters for pretraining is covered in detail in the AI pretraining data curation guide.

Launching Multi-Node AMD Training on Spheron

Step 1: Provision AMD GPU instances

Go to app.spheron.ai and check current AMD GPU availability under the compute marketplace. Use spot instances for experimental short runs; use on-demand for multi-day pretraining where preemption would lose significant checkpoint progress.

Step 2: Pull the ROCm container

bash
docker pull rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_2.4.0

AMD publishes versioned images on Docker Hub under rocm/pytorch. Always pin to a specific version for reproducible training environments.

Step 3: Install training dependencies

bash
# Inside the ROCm container
pip install transformers datasets accelerate

# DeepSpeed with ROCm ops
DS_BUILD_OPS=1 pip install deepspeed

# Megatron-Core (from source for ROCm)
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .

Step 4: Multi-node torchrun launch

bash
# On each node (set NODE_RANK to 0-3 and MASTER_ADDR to node 0's private IP)
export NODE_RANK=<0-3>
export MASTER_ADDR=<node0_private_ip>
export MASTER_PORT=29500
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export NCCL_DEBUG=INFO
export NCCL_SOCKET_IFNAME=eth0

torchrun \
  --nnodes=4 \
  --node_rank=$NODE_RANK \
  --nproc_per_node=8 \
  --master_addr=$MASTER_ADDR \
  --master_port=$MASTER_PORT \
  pretrain.py \
  --model-name my-7b-config \
  --bf16

On ROCm builds, torch.distributed's "nccl" backend is backed by RCCL, so collective communication needs no code changes. The --nnodes and --node_rank flags are identical to CUDA-based setups.
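torchrun's own --max-restarts flag restarts workers inside a running job. For relaunching the whole command (for example from a container entrypoint after a crash or SIGTERM), a thin outer loop is enough, provided the training script resumes from the latest checkpoint on startup. A sketch; the attempt limit, the backoff, and the helper name are arbitrary choices.

```python
import subprocess
import time
from typing import Mapping, Optional, Sequence


def run_with_restarts(
    cmd: Sequence[str],
    max_attempts: int = 3,
    backoff_seconds: float = 30.0,
    env: Optional[Mapping[str, str]] = None,
) -> int:
    """Run cmd until it exits 0 or max_attempts is used up; return the last exit code."""
    returncode = 1
    for attempt in range(1, max_attempts + 1):
        returncode = subprocess.run(list(cmd), env=env).returncode
        if returncode == 0:
            return 0
        print(f"attempt {attempt}/{max_attempts} exited with code {returncode}", flush=True)
        if attempt < max_attempts:
            time.sleep(backoff_seconds)  # brief pause before relaunching
    return returncode
```

A typical call wraps the torchrun command above, e.g. run_with_restarts(["torchrun", "--nnodes=4", f"--node_rank={node_rank}", "--nproc_per_node=8", "pretrain.py", "--bf16"]), with node_rank supplied by your provisioning script.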

Step 5: Spot preemption handler

python
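import signal
import sys

checkpoint_dir = "/checkpoints"

# The training loop keeps this dict current (state["step"] = step, etc.).
# save_checkpoint() is your own save routine; it is not defined here.
state = {"step": 0, "model": None, "optimizer": None}

def handle_sigterm(signum, frame):
    step = state["step"]
    print(f"SIGTERM received at step {step}, saving emergency checkpoint...", flush=True)
    save_checkpoint(
        step=step,
        model=state["model"],
        optimizer=state["optimizer"],
        path=f"{checkpoint_dir}/emergency_step_{step}.pt",
    )
    # Non-zero exit tells a restart wrapper the run did not finish cleanly
    sys.exit(1)

signal.signal(signal.SIGTERM, handle_sigterm)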

On Spheron spot instances, the preemption signal fires 30-120 seconds before reclamation. This handler writes an emergency checkpoint before the instance exits. On restart, the training loop polls checkpoint_dir for the latest checkpoint and resumes from there.
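The resume side can be as simple as scanning checkpoint_dir for the highest step number. This sketch assumes filenames end in step_<N>.pt, as the emergency checkpoints in the handler do; the listdir parameter exists only so the function can be tested without a filesystem.

```python
import os
import re
from typing import Callable, List, Optional, Tuple

# Matches names like emergency_step_1200.pt or full_step_1000.pt
STEP_RE = re.compile(r"step_(\d+)\.pt$")


def latest_checkpoint(
    checkpoint_dir: str,
    listdir: Callable[[str], List[str]] = os.listdir,
) -> Optional[Tuple[int, str]]:
    """Return (step, path) of the highest-step checkpoint, or None if there is none."""
    try:
        names = listdir(checkpoint_dir)
    except FileNotFoundError:
        return None
    best: Optional[Tuple[int, str]] = None
    for name in names:
        match = STEP_RE.search(name)
        if match:
            step = int(match.group(1))
            if best is None or step > best[0]:
                best = (step, os.path.join(checkpoint_dir, name))
    return best
```

If your incremental checkpoints omit model weights, filter the names to full and emergency saves before picking the latest, since only those support a cold restart.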

Step 6: Checkpoint configuration

Save incremental checkpoints (optimizer state + step counter) every 100-200 steps. Save full checkpoints every 500-1000 steps. Store checkpoints to a persistent NFS mount or S3-compatible store accessible from all nodes.

python
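INCREMENTAL_EVERY = 100   # steps between incremental saves
FULL_EVERY = 500          # steps between full saves (a multiple of INCREMENTAL_EVERY)

# save_full_checkpoint / save_incremental_checkpoint are your own helpers.
# Check the full cadence first so step 500 writes one full save, not two.
if step > 0 and step % FULL_EVERY == 0:
    # Full: slower, saves everything needed for cold restart
    save_full_checkpoint(step, model, optimizer, lr_scheduler)
elif step > 0 and step % INCREMENTAL_EVERY == 0:
    # Incremental: fast, saves optimizer state
    save_incremental_checkpoint(step, optimizer_state, lr_scheduler_state)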
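Because an emergency save can be cut off mid-write, it helps to write each checkpoint to a temporary file and rename it into place only once it is complete. This is a general technique rather than something the guide's tooling requires. It relies on os.replace being atomic within one POSIX filesystem, so verify the behavior on your NFS mount; S3-style object stores have no rename and need a different approach.

```python
import os
import tempfile
from typing import Callable


def atomic_save(path: str, write_fn: Callable[[str], None]) -> None:
    """Call write_fn(tmp_path), then atomically move the result to path.

    Readers see either the previous file or the complete new one, never a partial write.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # Temp file in the same directory so os.replace stays on one filesystem;
    # the .partial suffix keeps it from looking like a finished .pt checkpoint.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_", suffix=".partial")
    os.close(fd)
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

With PyTorch this would be called as atomic_save(f"{checkpoint_dir}/full_step_{step}.pt", lambda tmp: torch.save(state_dict, tmp)), where state_dict is whatever your save routine collects.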

For complete checkpointing patterns including distributed checkpoint formats and NFS/S3 configuration for spot training, see the spot GPU training resilience guide.


Spheron gives you access to AMD Instinct and NVIDIA GPU capacity in one marketplace, so you can run the same token budget on both and pick based on real cost, not vendor claims.


Quick Setup Guide

  1. Pull the AMD ROCm PyTorch Docker image

    Start from AMD's official image: docker run --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined -it rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_2.4.0. Verify GPU visibility with rocm-smi. This replaces the NVIDIA docker run --gpus all flag.

  2. Install RCCL and ROCm-compatible training libraries

    RCCL (ROCm Communication Collectives Library) is the AMD equivalent of NCCL and ships with the ROCm toolkit. Install DeepSpeed with ROCm support: DS_BUILD_OPS=1 pip install deepspeed. Install Megatron-Core from source on the ROCm PyTorch image. For vLLM inference after training, use pip install vllm --extra-index-url https://download.pytorch.org/whl/rocm6.2.

  3. Port CUDA training script to ROCm

    Run hipify-perl on any custom CUDA kernels: hipify-perl my_cuda_kernel.cu > my_hip_kernel.hip. Most PyTorch code needs no changes: PyTorch's ROCm build maps 'cuda' devices and the torch.cuda.* APIs to HIP. Only CUDA-specific extensions and intrinsics need porting. Remove FlashAttention 3 dependencies and use the ROCm/CK build of Flash Attention 2 instead.

  4. Configure multi-node RCCL for AMD training

    Set NCCL_SOCKET_IFNAME to your network interface and NCCL_DEBUG=INFO for diagnostics (RCCL reads the NCCL_* variables). RCCL is API-compatible with NCCL, so torchrun launch scripts work unchanged. HIP honors CUDA_VISIBLE_DEVICES, and HIP_VISIBLE_DEVICES is the native way to control GPU assignment. For InfiniBand-equipped nodes, RCCL activates RDMA automatically.

  5. Launch a pretraining job on Spheron with spot resilience

    Check MI300X or MI355X availability on Spheron at app.spheron.ai and provision instances where capacity is listed. Use spot pricing for experimental runs and on-demand for production pretraining where preemption tolerance is low. Configure checkpointing to a mounted NFS volume or S3-compatible store. Wrap your torchrun launch with a restart loop that polls the checkpoint directory on restart.

Frequently Asked Questions

Yes. Zyphra's ZAYA1 (Apache 2.0) was trained entirely on AMD Instinct hardware using ROCm, proving CUDA is no longer a hard requirement for foundation-model pretraining. The key stack is ROCm 6.x, PyTorch (AMD wheel), RCCL for inter-GPU communication, and Megatron-Core or FSDP2 for model parallelism. Most of the CUDA-to-ROCm porting is handled automatically by the HIP compatibility layer.

MI300X ships with 192 GB HBM3 at 5.3 TB/s. MI355X carries 288 GB HBM3e at 8 TB/s: more memory and more bandwidth, which helps on long-sequence pretraining and on model sizes that would need more GPUs or nodes on MI300X. For most 7B-70B pretraining runs, MI300X is the more cost-effective choice. For 70B+ runs, MI355X's 288 GB per GPU (about 2.3 TB per 8-GPU node) makes it easier to keep FSDP sharding inside a single node instead of across nodes.

Start from AMD's official ROCm PyTorch Docker image (rocm/pytorch, pinned to a specific version tag rather than latest). Most PyTorch code runs unchanged because PyTorch's ROCm build maps 'cuda' device calls to HIP. Custom CUDA kernels need hipify-perl conversion. Replace NCCL with RCCL (which is API-compatible). Replace 'nvidia-smi' with 'rocm-smi'. Install vLLM, DeepSpeed, and other libraries from their ROCm wheels. The trickiest porting targets are fused attention kernels and custom PTX assembly, which require manual rewriting.

Megatron-Core works on ROCm with caveats: data parallelism and pipeline parallelism are fully supported; some fused attention kernels fall back to a slower PyTorch path without Flash Attention. DeepSpeed ZeRO-3 has full ROCm support (v0.6+). RCCL replaces NCCL for collective communication and is API-compatible, so launch scripts need no changes beyond the device flag.

At comparable throughput and current cloud pricing, MI300X typically runs 50-60% cheaper per token than H100 SXM5 on-demand. In a conservative scenario where ROCm throughput is lower (around 2,200 tok/s vs H100's 3,000 tok/s) and MI300X sits at the top of its market range ($2.30/hr), the saving narrows to roughly 37%. At the midpoint, H100 SXM5 on-demand is ~$5.01/hr at current Spheron rates (~$0.51 per million tokens at 2,750 tok/s), while MI300X at a $2.05/hr market rate runs ~$0.23 per million tokens at 2,500 tok/s, a ~55% saving. The gap widens on spot pricing. The memory advantage also means fewer GPUs are needed for large models. Check current GPU pricing for live rates.

Save incremental checkpoints (optimizer states + step counter) every 100-200 steps to persistent network storage. Save full checkpoints every 500-1000 steps. Configure a preemption signal handler to flush an incremental checkpoint on SIGTERM before the spot instance is reclaimed. On Spheron, the preemption webhook fires before reclamation, giving you 30-120 seconds to write a checkpoint.
