Engineering

Spot GPU Training Resilience: Checkpointing, Preemption Recovery, and Fault-Tolerant LLM Fine-Tuning (2026)

Written by Mitrasish, Co-founder | May 15, 2026
Tags: spot GPU training, preemption checkpointing LLM, fault-tolerant GPU training, interruptible GPU fine-tuning, spot instance ML training, FSDP checkpointing, DeepSpeed ZeRO-3, GPU Cloud

Spot GPU pricing is 50-70% cheaper than on-demand. Most teams avoid it for long fine-tuning runs because they are afraid of losing progress when an instance gets reclaimed. That fear is reasonable if you have no checkpointing infrastructure. Once you have it, spot interruptions become a minor annoyance (averaging 10-15 minutes of recovery time) rather than a training disaster. The engineering is not complicated. This guide covers exactly what to build: preemption signal handling, checkpoint strategies for FSDP and ZeRO-3, optimizer state preservation, async offload to avoid stalling training, and self-healing job controllers.

For the cost argument, the spot GPU training case study documents how a 12-person AI startup completed a 70B model fine-tune for $11,200 on spot GPUs. That post covers the business logic. This one covers the implementation.

Spot GPU Economics: What 60% Savings Actually Means for a Long Fine-Tune

Here is what the difference looks like on an 8x H100 SXM5 cluster running for a full week:

Configuration | GPU | Count | Duration | On-Demand Rate | Spot Rate | On-Demand Total | Spot Total | Savings
8x H100 SXM5, 1-week fine-tune | H100 SXM5 | 8 | 168 hrs | $4.00/hr/GPU | $1.69/hr/GPU | $5,376 | $2,271 | ~58%

Pricing fluctuates based on GPU availability. The prices above are as of 15 May 2026 and may have changed. Check current GPU pricing for live rates.

The on-demand rate above is the current lowest H100 SXM5 on-demand price on Spheron. The spot rate is the lowest available at time of writing. The actual gap you see on a given day depends on availability, but 50-60% cheaper than lowest on-demand is the typical range.

That $3,105 difference for a single week-long run is significant. For teams running 3-4 training iterations before converging on a final model, the spot strategy saves $9,000-12,000 on compute alone. For the broader cost-optimization picture across training and inference, see the GPU cost optimization playbook.

Anatomy of a Preemption Signal

When a spot instance is reclaimed, the sequence looks like this:

  1. The cloud provider flags the instance for reclamation (internally, no notification yet)
  2. A SIGTERM signal is delivered to the process group of the main training process
  3. The metadata endpoint (if the provider has one) changes state to indicate pending shutdown
  4. The instance is shut down, typically 30-120 seconds after SIGTERM

What can go wrong if you do nothing: your training script exits immediately on SIGTERM (Python's default), leaving no checkpoint at all. Or worse, a checkpoint write was in progress when the signal arrived, leaving a corrupted partial file that causes cryptic errors when you try to resume.

On Spheron, a preemption webhook fires before instance reclamation, giving you a notification window to trigger cleanup. Check your provider's documentation for the preemption webhook endpoint and payload format.
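
If you consume the webhook directly, a small receiver can translate the notification into a SIGTERM for the training process, so the signal-handler path below covers both cases. The /preempt route and JSON payload here are illustrative assumptions, not a documented format.

python
# Minimal sketch of a preemption-webhook receiver. The /preempt path and the
# JSON payload are hypothetical - check your provider's docs for the real format.
import json
import os
import signal
from http.server import BaseHTTPRequestHandler, HTTPServer

TRAINING_PID_FILE = "/tmp/training.pid"  # written by the training script at startup

class PreemptionHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        if self.path != "/preempt":
            self.send_response(404)
            self.end_headers()
            return
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length) or b"{}")
        print(f"Preemption webhook received: {payload}")
        # Forward as SIGTERM so the training script's signal handler fires
        with open(TRAINING_PID_FILE) as f:
            os.kill(int(f.read().strip()), signal.SIGTERM)
        self.send_response(200)
        self.end_headers()

if __name__ == "__main__":
    HTTPServer(("0.0.0.0", 8080), PreemptionHandler).serve_forever()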

The fix is straightforward: register a SIGTERM handler that saves an emergency checkpoint before exiting.

python
import signal
import sys

_trainer = None

def handle_preemption(signum, frame):
    print("Preemption signal received. Saving emergency checkpoint...")
    if _trainer is not None:
        _trainer.save_checkpoint(output_dir="/persistent-storage/checkpoints/emergency")
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_preemption)

# After you create your trainer:
_trainer = trainer

Keep the checkpoint write under 30 seconds. If a full checkpoint takes longer (common for 70B models), use an incremental checkpoint that only saves optimizer state and RNG state. Those are much smaller and write quickly. The full weight checkpoint can happen on a regular schedule; the SIGTERM handler should only flush the incremental state.
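
As a sketch of that split, the handler below flushes only optimizer, scheduler, and RNG state on SIGTERM. save_resume_state and the paths are illustrative names; the full weight checkpoint is assumed to happen elsewhere on its regular schedule.

python
# Sketch: emergency handler that flushes only the small incremental state.
# Function names and paths are illustrative assumptions.
import signal
import sys
import torch

def save_resume_state(optimizer, scheduler, global_step, path):
    torch.save({
        "step": global_step,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": scheduler.state_dict(),
        "rng_cpu": torch.random.get_rng_state(),
        "rng_cuda": torch.cuda.get_rng_state_all(),
    }, path)

def install_emergency_handler(optimizer, scheduler, get_step):
    def _handler(signum, frame):
        # Optimizer + RNG state only: ~70GB/rank for a 70B model vs ~122GB full,
        # so it fits inside a 30-120 second preemption window
        save_resume_state(optimizer, scheduler, get_step(),
                          "/persistent-storage/checkpoints/emergency_resume.pt")
        sys.exit(0)
    signal.signal(signal.SIGTERM, _handler)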

Good GPU health monitoring can catch degradation signals before the preemption actually arrives, giving you more time to write a clean checkpoint.

Checkpointing Strategy: Full vs Incremental vs Sharded

Three checkpoint types serve different purposes:

Full Checkpoints

A full checkpoint saves model weights, optimizer states, LR scheduler state, and RNG state. For a 70B BF16 model, this is roughly:

  • Model weights: 140GB (70B × 2 bytes)
  • Optimizer states (AdamW, m + v, fp32): 560GB
  • Master weights (fp32, mixed precision): 280GB
  • Total per rank with FSDP across 8 GPUs: ~122GB

At typical NVMe write speeds (3-5 GB/s), a full checkpoint for a 70B model takes 25-45 seconds per rank. Use these every 500-1000 steps for disaster recovery, not after every interruption.
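
The same arithmetic generalizes to other model sizes. A rough sizing helper, assuming BF16 weights, fp32 AdamW moments, and an fp32 master copy as in the breakdown above:

python
# Rough checkpoint-size estimator matching the breakdown above:
# BF16 weights (2 B/param), fp32 AdamW m+v (8 B/param), fp32 master copy
# (4 B/param). Adjust for other optimizers or precision setups.
def full_checkpoint_gb(params_billion, world_size=8):
    weights = params_billion * 2        # GB, BF16
    optimizer = params_billion * 8      # GB, AdamW m + v in fp32
    master = params_billion * 4         # GB, fp32 master weights
    total = weights + optimizer + master
    return total, total / world_size    # (cluster total, per-rank with FSDP)

total, per_rank = full_checkpoint_gb(70)
print(f"{total:.0f} GB total, ~{per_rank:.0f} GB per rank")  # 980 GB, ~122 GB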

Incremental (Resume) Checkpoints

An incremental checkpoint skips model weights entirely. It saves optimizer states, LR scheduler state, step counter, and RNG state. For the same 70B model, the fp32 AdamW optimizer state (m and v moments, 70B × 2 × 4 bytes) totals ~560GB, sharding to ~70GB per rank across 8 GPUs. Write time is roughly 15-25 seconds per rank at typical NVMe speeds, about half that of a full checkpoint.

Here is a SpotCheckpointCallback that extends the basic pattern with async write and integrity verification:

python
import os
import threading
import torch
from transformers import TrainerCallback

class SpotCheckpointCallback(TrainerCallback):
    def __init__(self, save_dir, full_interval=500, resume_interval=100):
        self.save_dir = save_dir
        self.full_interval = full_interval
        self.resume_interval = resume_interval
        self._write_thread = None

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.resume_interval != 0:
            return

        checkpoint_path = os.path.join(
            self.save_dir,
            f"resume_step_{state.global_step}.pt"
        )

        # Collect state synchronously (must happen on training thread)
        resume_state = {
            "step": state.global_step,
            "optimizer": kwargs["optimizer"].state_dict(),
            "lr_scheduler": kwargs["lr_scheduler"].state_dict(),
            "rng_cpu": torch.random.get_rng_state(),
            "rng_cuda": torch.cuda.get_rng_state_all(),
        }

        # Write in background thread to avoid blocking GPU
        if self._write_thread and self._write_thread.is_alive():
            self._write_thread.join()  # Don't stack multiple writes

        def _write():
            tmp_path = checkpoint_path + ".tmp"
            torch.save(resume_state, tmp_path)
            # Atomic rename avoids partial writes visible on crash
            os.replace(tmp_path, checkpoint_path)

        self._write_thread = threading.Thread(target=_write, daemon=True)
        self._write_thread.start()

    def verify_checkpoint(self, checkpoint_path):
        """Call before resuming to catch corrupted files."""
        try:
            state = torch.load(checkpoint_path, map_location="cpu")
            required_keys = {"step", "optimizer", "lr_scheduler", "rng_cpu", "rng_cuda"}
            return required_keys.issubset(state.keys())
        except Exception:
            return False

The atomic rename (os.replace) is important. A partial write followed by a crash leaves a .tmp file, not a corrupted checkpoint. On resume, you fall back to the previous clean checkpoint.
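
A resume routine that applies this fallback can scan checkpoints newest-first and take the first one that verifies; find_latest_valid_checkpoint below is an illustrative helper built on the verify_checkpoint method above.

python
# Illustrative: pick the newest checkpoint that passes integrity verification,
# falling back to older ones if the newest write was interrupted.
import glob
import re

def find_latest_valid_checkpoint(save_dir, callback):
    paths = glob.glob(f"{save_dir}/resume_step_*.pt")  # .tmp files are ignored
    # Sort by step number, newest first
    paths.sort(key=lambda p: int(re.search(r"resume_step_(\d+)", p).group(1)),
               reverse=True)
    for path in paths:
        if callback.verify_checkpoint(path):
            return path
    return None  # no valid checkpoint; start from scratch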

Sharded Checkpoints (FSDP / ZeRO-3)

With FSDP or DeepSpeed ZeRO-3, each GPU rank writes its own shard. The rank-to-shard mapping is deterministic: rank 0 always writes the first N parameters, rank 1 writes the next N, and so on. This means sharded checkpoints are fast (each rank writes only its fraction of the total state) but not portable across world sizes.

python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

rank = dist.get_rank()

# Option A: Sharded save (fast, not portable across world sizes)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_state = model.state_dict()
    torch.save(sharded_state, f"/checkpoints/rank_{rank}/model.pt")

# Option B: Full state dict (slow, portable across world sizes)
full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_cfg):
    # All ranks must call model.state_dict() - FULL_STATE_DICT uses dist.all_gather
    # internally so every rank must participate; only rank 0 writes to disk.
    full_state = model.state_dict()
    if rank == 0:
        torch.save(full_state, "/checkpoints/model_full.pt")

Use FULL_STATE_DICT for any checkpoint you might resume on different hardware. Use SHARDED_STATE_DICT for frequent incremental saves during training where you know the hardware topology will not change.

Optimizer State and RNG State Preservation

Optimizer state is as critical as model weights for a clean resume. The Adam moment buffers encode the training momentum accumulated over hundreds of steps. Losing them resets the optimizer to a cold start, which causes a visible loss spike and wastes 50-200 steps of convergence.

python
# Save
resume_state = {
    "step": global_step,
    "optimizer": optimizer.state_dict(),
    "lr_scheduler": scheduler.state_dict(),
    # CPU RNG state (controls data shuffling, dropout, etc.)
    "rng_cpu": torch.random.get_rng_state(),
    # Per-GPU CUDA RNG state (each rank is independent in distributed training)
    "rng_cuda": torch.cuda.get_rng_state_all(),
}
torch.save(resume_state, "/persistent-storage/checkpoints/resume_latest.pt")

# Load
resume = torch.load("/persistent-storage/checkpoints/resume_latest.pt", map_location="cpu")
optimizer.load_state_dict(resume["optimizer"])
scheduler.load_state_dict(resume["lr_scheduler"])
torch.random.set_rng_state(resume["rng_cpu"])
torch.cuda.set_rng_state_all(resume["rng_cuda"])
global_step = resume["step"]

In distributed training, each rank has its own independent CUDA RNG state. torch.cuda.get_rng_state_all() returns a list with one entry per GPU on the current node. Save and restore all of them to preserve exact reproducibility.

For DeepSpeed ZeRO-3, optimizer state checkpointing is handled by the engine directly:

python
# Save (with ZeRO-3 enabled in the engine config, DeepSpeed writes per-rank
# optimizer state shards alongside the model shards automatically)
model_engine.save_checkpoint(
    save_dir="/persistent-storage/checkpoints",
    tag=f"step_{global_step}",
    save_latest=True,
)

# If you need to consolidate shards to a single fp32 model file:
# python zero_to_fp32.py /persistent-storage/checkpoints/step_500/ /output/model_fp32.pt

DeepSpeed saves the ZeRO-3 optimizer state shards whenever ZeRO is enabled in the engine config. On resume, keep load_optimizer_states=True (the default) in load_checkpoint; loading module weights only cold-starts the optimizer.

Async Checkpoint Offload: Don't Stall Training

Synchronous checkpoint writes block GPU compute for the duration of the write. For a 70B model:

  • 200GB checkpoint at 4 GB/s NVMe = 50 seconds of GPU idle time
  • At $13.52/hr for an 8x H100 cluster, that is $0.19 per checkpoint write
  • With checkpoints every 500 steps and ~6,000 tokens/sec throughput on 70B, checkpoints happen roughly every 100 minutes
  • Synchronous writes add ~0.8% of total training time in overhead

That sounds small, but it adds up over a week-long run, and the GPU idle spikes are jarring. More importantly, with async writes you can checkpoint more frequently without paying the compute cost.

PyTorch 2.4+ includes torch.distributed.checkpoint.async_save, which writes checkpoints in background threads while training continues:

python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict

# PyTorch 2.4+ required for async_save
# For older versions, use threading.Thread directly (see SpotCheckpointCallback above)

def save_checkpoint_async(model, optimizer, step, save_dir):
    model_state, optim_state = get_state_dict(model, optimizer)

    checkpoint_future = dcp.async_save(
        {"model": model_state, "optimizer": optim_state},
        checkpoint_id=f"{save_dir}/step_{step}",
    )
    # Training continues here while the write happens in background
    return checkpoint_future  # call .result() before the next checkpoint

# In training loop:
checkpoint_future = None
for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 200 == 0:
        if checkpoint_future is not None:
            checkpoint_future.result()  # Ensure previous write finished
        checkpoint_future = save_checkpoint_async(model, optimizer, step, "/checkpoints")

Storage targets by speed and cost:

  • Local NVMe: 4-7 GB/s, does not survive preemption
  • Network-attached NVMe: 1-3 GB/s, survives preemption (required for spot training)
  • Object storage (S3/GCS): 200-500 MB/s, survives preemption, cheapest at scale but adds 30-60s write latency

For spot training, network-attached NVMe is the right default. Fast enough to keep checkpoint overhead low, durable across preemptions.
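
Before relying on those throughput numbers, measure your actual checkpoint volume. A quick write probe (the mount path is illustrative):

python
# Quick-and-dirty write throughput probe for the checkpoint volume.
# The mount path is illustrative; point it at your actual persistent volume.
import os
import time

def measure_write_gbps(path="/persistent-storage/checkpoints/.throughput_test",
                       size_gb=1):
    data = os.urandom(1024 * 1024)  # 1 MiB block, reused for each write
    start = time.monotonic()
    with open(path, "wb") as f:
        for _ in range(size_gb * 1024):
            f.write(data)
        f.flush()
        os.fsync(f.fileno())  # force data to disk, not just the page cache
    elapsed = time.monotonic() - start
    os.remove(path)
    return size_gb / elapsed

print(f"{measure_write_gbps():.2f} GB/s")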

Checkpoint Frequency Math: Balancing Wasted Work vs Overhead

The optimal checkpoint interval depends on preemption probability and checkpoint write time. Here is the math:

Expected wasted compute per hour = (checkpoint_interval_hours / 2) × preemption_probability_per_hour
Checkpoint overhead per hour = (checkpoint_write_time_seconds / 3600) × checkpoints_per_hour

For a 70B model fine-tune on an 8x H100 SXM5 cluster, assuming ~6,000 tokens/sec aggregate throughput and global batch size of ~72K tokens per step (8 GPUs × ~2 sequences × 4,096 tokens, typical for SFT), each step takes ~12 seconds:

Checkpoint interval | Steps | Wasted compute (avg) | Checkpoint overhead | Recommended for
50 steps (~10 min) | 50 | 5 min | 8 min/hr | High preemption frequency (more than 1/hr)
100 steps (~20 min) | 100 | 10 min | 4 min/hr | Moderate spot environments
500 steps (~100 min) | 500 | 50 min | under 1 min/hr | Stable spot with low interruption rate

The practical recommendation: run incremental checkpoints every 100 steps (saves only optimizer state, writes in 15-25 seconds per rank), and full checkpoints every 500 steps. With async offload, the incremental checkpoints add near-zero overhead to training throughput.
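
To apply the formulas above to your own numbers, a small calculator helps. The defaults below mirror the 70B example (~12 s/step) and assume ~80 s synchronous writes and 0.5 preemptions/hour; replace both with measured values.

python
# Compare checkpoint intervals using the expected-waste and overhead formulas
# above. Defaults mirror the 70B example; tune to your measured rates.
def hourly_cost_minutes(interval_steps, step_time_s=12.0,
                        write_time_s=80.0, preemptions_per_hour=0.5):
    interval_hours = interval_steps * step_time_s / 3600
    # Expected wasted compute per hour: on average half an interval is lost
    # per preemption event
    wasted = (interval_hours / 2) * preemptions_per_hour * 60
    # Synchronous write overhead per hour (async offload drives this toward 0)
    overhead = (write_time_s / 60) * (1 / interval_hours)
    return wasted + overhead

for steps in (50, 100, 500):
    print(f"{steps} steps: {hourly_cost_minutes(steps):.1f} min/hr total cost")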

Multi-Node Spot Resilience: Handling Partial Cluster Preemption

Multi-node spot is harder than single-node because a partial preemption (2 of 4 nodes reclaimed) leaves the remaining nodes hung waiting for distributed collectives that never complete.

NCCL does not detect single-node failures gracefully by default. You will see timeouts like Watchdog caught collective operation timeout or ncclInternalError several minutes after the node goes down, not an immediate clean failure. To avoid that wait, set NCCL timeouts aggressively so a hung job surfaces quickly.

bash
# Renamed from NCCL_ASYNC_ERROR_HANDLING in PyTorch 2.2+
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

# NCCL_TIMEOUT is not a real env var - set the watchdog timeout in Python:
# dist.init_process_group(backend="nccl", timeout=timedelta(seconds=120))
# 120s is aggressive but safer than 60s for large all-reduces on 70B FSDP jobs;
# 60s risks false positives during legitimate large collectives.
# Default is 1800s; tune down to 120-300s for faster failure detection on spot.

Two strategies for handling partial preemption:

Strategy A: Treat any node loss as full-cluster failure. The surviving nodes checkpoint their state and exit cleanly. The job controller re-provisions a full replacement cluster and resumes from the latest checkpoint. Simple to implement, wastes the surviving nodes' compute during re-provisioning (typically 5-15 minutes). This is the right default for most teams.
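
A minimal Strategy A controller is a loop around provision, train, teardown. Everything provider-facing below (provision_cluster, launch_training, teardown) is a hypothetical stub to be replaced with your provider's API:

python
# Sketch of a Strategy A job controller. provision_cluster(), launch_training(),
# and teardown() are hypothetical stand-ins for your provider's API and launcher.
import time

MAX_RESTARTS = 20

def provision_cluster(gpus, gpu_type, spot):
    """Hypothetical: request a spot cluster from your provider's API."""
    raise NotImplementedError

def launch_training(cluster, resume):
    """Hypothetical: start train.py on the cluster and return its exit code."""
    raise NotImplementedError

def teardown(cluster):
    """Hypothetical: release any instances still allocated."""
    raise NotImplementedError

def run_with_self_healing():
    restarts = 0
    while restarts < MAX_RESTARTS:
        cluster = provision_cluster(gpus=8, gpu_type="H100-SXM5", spot=True)
        # The training script itself resumes from the latest checkpoint on start
        exit_code = launch_training(cluster, resume="latest")
        teardown(cluster)
        if exit_code == 0:
            return  # run completed
        restarts += 1
        print(f"Cluster lost (exit {exit_code}); re-provisioning ({restarts}/{MAX_RESTARTS})")
        time.sleep(30)  # brief backoff before requesting replacement capacity
    raise RuntimeError("Exceeded restart budget; investigate before resubmitting")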

Strategy B: Elastic training with torchelastic. PyTorch Elastic (torchrun --rdzv-backend=c10d) can scale world size down when nodes leave and up when replacements arrive. More complex to set up, but allows training to continue on the surviving nodes while replacements are provisioned.

bash
# torchrun with c10d rendezvous for elastic training (recommended; no external dependency)
# --nnodes=2:4 is the elastic range: minimum 2 nodes, maximum 4
torchrun \
  --nnodes=2:4 \
  --nproc_per_node=8 \
  --rdzv-backend=c10d \
  --rdzv-endpoint=<host0-ip>:29500 \
  --rdzv-id=my-training-job \
  train.py

For NCCL tuning and hang avoidance in multi-node environments, you need additional environment variables beyond the timeout settings. The details are in the NCCL tuning guide.

For multi-node training setup without InfiniBand, the collective communication throughput is lower, which makes NCCL timeouts even more important to tune correctly.

Resuming on Heterogeneous GPUs After Preemption

Spot availability does not guarantee you will get the same GPU SKU when you re-provision. You might train on H100 SXM5 8-GPU nodes and get H100 PCIe or H200 nodes on restart.

Model weights are architecture-agnostic. A BF16 checkpoint loads identically on H100, H200, or any other GPU that supports BF16. No conversion needed.

Optimizer states are device-agnostic too. The Adam moment buffers are just tensors. They load cleanly regardless of the target hardware.

The issue is rank topology. FSDP sharded checkpoints are tied to the world size and rank assignment. If you provisioned 8 GPUs originally and the replacement also has 8 GPUs, sharded checkpoints load cleanly even on different GPU hardware. If the world size changes (e.g., you get a 4-GPU replacement), you need to consolidate first:

python
# Consolidate sharded FSDP checkpoints before resuming on different world size
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# On the original 8-GPU setup (or after loading sharded checkpoints):
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
# FSDP.state_dict_type is a collective - all ranks must enter it together.
# model.state_dict() gathers the full state across ranks; only rank 0 saves.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
    full_state = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(full_state, "/checkpoints/consolidated_model.pt")

# Then reshard on the new 4-GPU setup using the consolidated checkpoint

Mixed-precision compatibility: BF16 checkpoints load on H100 and H200 without conversion. FP16 checkpoints also load without conversion. FP32 checkpoints load but you will want to cast back to BF16 before the forward pass to avoid VRAM bloat.
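
Downcasting an fp32 checkpoint before loading takes a few lines; a minimal sketch, assuming the model was built with BF16 parameters:

python
# Minimal sketch: cast an fp32 checkpoint to BF16 on CPU before loading,
# so fp32 copies of the weights never land in GPU memory.
import torch

state = torch.load("/checkpoints/model_fp32.pt", map_location="cpu")
state_bf16 = {
    k: v.to(torch.bfloat16) if v.is_floating_point() else v
    for k, v in state.items()
}
model.load_state_dict(state_bf16)  # model built with BF16 params elsewhere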

The practical recommendation: always checkpoint in FULL_STATE_DICT mode for any checkpoint you designate as a recovery point (every 500+ steps). Use sharded checkpoints only for incremental saves where you know the hardware topology will not change.

Health Checks and Self-Healing Job Controllers

Three common orchestration patterns for spot training:

Kubernetes

yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: llm-training-job
spec:
  backoffLimit: 10
  template:
    spec:
      restartPolicy: OnFailure
      terminationGracePeriodSeconds: 120
      containers:
      - name: trainer
        image: your-training-image:latest
        resources:
          limits:
            nvidia.com/gpu: 8
        volumeMounts:
        - name: checkpoint-storage
          mountPath: /persistent-storage/checkpoints
        lifecycle:
          preStop:
            exec:
              command: ["/bin/sh", "-c", "kill -TERM $(cat /tmp/training.pid) && sleep 90"]
        env:
        - name: RESUME_FROM_CHECKPOINT
          value: "latest"
      volumes:
      - name: checkpoint-storage
        persistentVolumeClaim:
          claimName: checkpoint-pvc

The preStop hook triggers before the container is killed on preemption. Combined with terminationGracePeriodSeconds: 120, you get 90 seconds to write an emergency checkpoint. backoffLimit: 10 allows 10 restarts before marking the job as failed.

Slurm

bash
#!/bin/bash
#SBATCH --job-name=llm-training
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --time=168:00:00
#SBATCH --requeue
#SBATCH --open-mode=append

# With --requeue, Slurm automatically requeues the job when preempted.
# The training script must detect $SLURM_RESTART_COUNT and resume from checkpoint.

if [ "${SLURM_RESTART_COUNT:-0}" -gt 0 ]; then
    export RESUME_FROM_CHECKPOINT="/persistent-storage/checkpoints/latest"
fi

# One torchrun launcher per node; ranks rendezvous via c10d on the head node
head_node=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
srun torchrun \
    --nnodes="$SLURM_NNODES" \
    --nproc_per_node=8 \
    --rdzv-backend=c10d \
    --rdzv-endpoint="$head_node":29500 \
    --rdzv-id="$SLURM_JOB_ID" \
    train.py

For Slurm AI training workloads, --requeue is the simplest approach: Slurm handles re-submission automatically, and $SLURM_RESTART_COUNT tells your training script it is a resumed run.

Ray Train

Ray's fault-tolerance model wraps the training function and handles restarts transparently:

python
import ray
from ray import train
from ray.train.torch import TorchTrainer
from ray.train import CheckpointConfig, FailureConfig, RunConfig

trainer = TorchTrainer(
    train_loop_per_worker=training_function,
    scaling_config=train.ScalingConfig(num_workers=8, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="eval_loss",
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=5),
    ),
)
result = trainer.fit()

Ray re-runs training_function with the latest checkpoint when a worker fails. The checkpoint is loaded from the Ray object store automatically. For Kubernetes GPU job orchestration, Ray on Kubernetes gives you elastic scaling plus automatic fault recovery without writing your own restart logic.
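
Inside training_function, the restored checkpoint is available via ray.train.get_checkpoint(). A minimal sketch of a resumable loop, where the state.pt file name and the build/step helpers are illustrative assumptions:

python
# Minimal sketch of a train_loop_per_worker that restores Ray-managed
# checkpoints on restart. "state.pt" and the helper functions are illustrative.
import os
import tempfile
import torch
from ray import train
from ray.train import Checkpoint

def training_function(config):
    model, optimizer = build_model_and_optimizer()  # your setup logic (assumed)
    start_step = 0

    checkpoint = train.get_checkpoint()  # non-None when resuming after a failure
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "state.pt"), map_location="cpu")
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            start_step = state["step"]

    for step in range(start_step, config["max_steps"]):
        loss = train_one_step(model, optimizer)  # your step logic (assumed)
        if step % 100 == 0:
            # Report a checkpoint so Ray can hand it back after a worker failure
            with tempfile.TemporaryDirectory() as tmp:
                torch.save({"model": model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "step": step}, os.path.join(tmp, "state.pt"))
                train.report({"loss": float(loss)},
                             checkpoint=Checkpoint.from_directory(tmp))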

Worked Example: Fine-Tuning a 70B Model on Spheron Spot H100s

Here is a complete setup for a 70B Qwen 2.5 fine-tuning run using H100 SXM5 spot instances on Spheron:

Cluster: 8x H100 SXM5 spot nodes

Model: Qwen 2.5 72B

Storage: 2TB persistent NVMe volume (~$0.10/GB/month, about $200/month = ~$47/week)

Checkpoint cadence: Full checkpoint every 500 steps (approximately every 100 minutes), incremental every 100 steps

Cost for a 1-week run:

Component | Rate | Duration | Total
8x H100 SXM5 spot | $1.69/GPU/hr × 8 | 168 hrs | $2,271
Persistent NVMe storage (2TB) | ~$47/week | 1 week | $47
Total | | | $2,318
vs on-demand (8x H100 lowest) | $4.00/GPU/hr × 8 | 168 hrs | $5,376
Savings | | | ~$3,058 (57%)

PyTorch Lightning Setup with Preemption Hook

python
import os
import signal
import sys
import pytorch_lightning as pl
import torch

class FaultTolerantFinetuner(pl.LightningModule):
    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        return outputs.loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Save incremental checkpoint every 100 steps
        if self.global_step % 100 == 0:
            ckpt_dir = f"/persistent-storage/checkpoints/resume_step_{self.global_step}"
            self.trainer.save_checkpoint(ckpt_dir + ".ckpt")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-5)


def setup_preemption_handler(trainer):
    def handle_sigterm(signum, frame):
        print("SIGTERM received. Saving emergency checkpoint...")
        trainer.save_checkpoint("/persistent-storage/checkpoints/emergency.ckpt")
        sys.exit(0)
    signal.signal(signal.SIGTERM, handle_sigterm)


def main():
    checkpoint_path = None
    latest_ckpt = "/persistent-storage/checkpoints/last.ckpt"
    if os.path.exists(latest_ckpt):
        checkpoint_path = latest_ckpt
        print(f"Resuming from checkpoint: {checkpoint_path}")

    model = load_model()   # Your model loading logic
    tokenizer = load_tokenizer()
    finetuner = FaultTolerantFinetuner(model, tokenizer)

    trainer = pl.Trainer(
        max_steps=10000,
        accelerator="gpu",
        devices=8,
        strategy="fsdp",
        enable_checkpointing=True,
        default_root_dir="/persistent-storage/checkpoints",
        # save_last=True keeps last.ckpt updated, which is what main() resumes from
        callbacks=[pl.callbacks.ModelCheckpoint(
            dirpath="/persistent-storage/checkpoints",
            save_last=True,
            every_n_train_steps=500,
        )],
    )

    setup_preemption_handler(trainer)
    trainer.fit(finetuner, ckpt_path=checkpoint_path)


if __name__ == "__main__":
    main()

FSDP2 with Async Distributed Checkpoint

python
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import signal

def setup_model_fsdp(model):
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    # Shard each transformer block independently
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy, reshard_after_forward=True)
    return model


def train(model, optimizer, dataloader, checkpoint_dir):
    pending_checkpoint = None
    prev_checkpoint_path = None  # track last completed path for pointer update
    prev_global_step = None  # step number matching prev_checkpoint_path
    global_step = 0

    # Resume from checkpoint if available
    latest_path_file = os.path.join(checkpoint_dir, "latest_path.txt")
    if os.path.exists(latest_path_file):
        with open(latest_path_file) as f:
            resume_path = f.read().strip()
        model_state, optim_state = get_state_dict(model, optimizer)
        dcp.load(
            {"model": model_state, "optimizer": optim_state},
            checkpoint_id=resume_path,
        )
        set_state_dict(model, optimizer, model_state_dict=model_state, optim_state_dict=optim_state)
        with open(os.path.join(resume_path, "step.txt")) as f:
            global_step = int(f.read())
        if dist.get_rank() == 0:
            print(f"Resumed from step {global_step}")

    # SIGTERM handler: set a flag only. Calling distributed collectives directly
    # from a signal handler deadlocks when not all ranks receive the signal at once
    # (e.g. partial-preemption where only one node is reclaimed). The training loop
    # checks this flag at a safe point where all ranks are coordinated.
    _preemption_requested = False

    def handle_preemption(signum, frame):
        nonlocal _preemption_requested
        _preemption_requested = True

    signal.signal(signal.SIGTERM, handle_preemption)

    for batch in dataloader:
        loss = model(batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

        # Check preemption flag at a safe coordination point (all ranks are here)
        if _preemption_requested:
            if pending_checkpoint:
                pending_checkpoint.result()
            model_state, optim_state = get_state_dict(model, optimizer)
            emergency_path = os.path.join(checkpoint_dir, "emergency")
            dcp.save(
                {"model": model_state, "optimizer": optim_state},
                checkpoint_id=emergency_path,
            )
            if dist.get_rank() == 0:
                with open(os.path.join(emergency_path, "step.txt"), "w") as f:
                    f.write(str(global_step))
                with open(os.path.join(checkpoint_dir, "latest_path.txt"), "w") as f:
                    f.write(emergency_path)
            dist.destroy_process_group()
            exit(0)

        # Async incremental checkpoint every 200 steps
        # torch.distributed.checkpoint.async_save requires PyTorch 2.4+
        # For older PyTorch, fall back to synchronous dcp.save() in a threading.Thread
        if global_step % 200 == 0:
            if pending_checkpoint:
                pending_checkpoint.result()
                # Update pointer only after confirming the previous write succeeded -
                # writing it before .result() risks pointing to an incomplete
                # checkpoint if the process is preempted mid-write.
                if dist.get_rank() == 0 and prev_checkpoint_path is not None:
                    with open(os.path.join(prev_checkpoint_path, "step.txt"), "w") as f:
                        f.write(str(prev_global_step))
                    with open(os.path.join(checkpoint_dir, "latest_path.txt"), "w") as f:
                        f.write(prev_checkpoint_path)

            model_state, optim_state = get_state_dict(model, optimizer)
            checkpoint_path = os.path.join(checkpoint_dir, f"step_{global_step}")
            pending_checkpoint = dcp.async_save(
                {"model": model_state, "optimizer": optim_state},
                checkpoint_id=checkpoint_path,
            )
            prev_checkpoint_path = checkpoint_path
            prev_global_step = global_step

        # Full consolidated checkpoint every 1000 steps for portability.
        # FSDP2 (fully_shard) models use get_model_state_dict with
        # StateDictOptions rather than FSDP1's state_dict_type context manager.
        # This is a collective - all ranks must call it; only rank 0 writes.
        if global_step % 1000 == 0:
            from torch.distributed.checkpoint.state_dict import (
                StateDictOptions, get_model_state_dict,
            )
            full_state = get_model_state_dict(
                model,
                options=StateDictOptions(full_state_dict=True, cpu_offload=True),
            )
            if dist.get_rank() == 0:
                torch.save(full_state, os.path.join(checkpoint_dir, f"full_step_{global_step}.pt"))

For the preemption webhook integration, Spheron fires a notification before the instance is reclaimed. Check your provider's documentation for the specific webhook endpoint and payload format.

Spot Bidding Strategies and Capacity-Aware Scheduling

H100 SXM5 spot availability follows predictable patterns. Capacity is tightest during US East business hours (9am-6pm EST). Long runs starting at off-peak times (late evening, early morning, weekends) see fewer preemptions in the first few hours, which is when you want uninterrupted training to reach a stable checkpoint cadence.

Multi-region fallback: configure your provisioning to request spot capacity from a primary data center region and fall back to a secondary region if availability is low. This does not help if you need InfiniBand (InfiniBand is within a single data center), but for single-node runs or Ethernet-connected clusters, multi-region fallback significantly improves spot availability.
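
Provisioning-side fallback can be a simple ordered loop. provision_spot and CapacityError below are hypothetical stand-ins for your provider's SDK:

python
# Sketch of multi-region spot fallback. provision_spot() and CapacityError
# are hypothetical stand-ins for your provider's SDK.
REGIONS = ["us-east", "us-west", "eu-central"]  # primary region first

class CapacityError(Exception):
    """Hypothetical: raised when a region has no spot capacity."""

def provision_spot(region, gpus, gpu_type):
    """Hypothetical provider call; replace with your provider's SDK."""
    raise NotImplementedError

def provision_with_fallback(gpus=8, gpu_type="H100-SXM5"):
    for region in REGIONS:
        try:
            return provision_spot(region, gpus=gpus, gpu_type=gpu_type)
        except CapacityError:
            print(f"No spot capacity in {region}, trying next region")
    raise RuntimeError("No spot capacity in any configured region")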

Mixed on-demand/spot clusters: One pattern that reduces full-cluster preemption risk is running 1-2 on-demand "anchor" nodes that hold the NCCL rendezvous point and parameter server. The majority of compute nodes are spot. A partial preemption hits only the spot nodes; the anchor nodes stay up, allowing faster recovery. The anchor nodes cost more, but the recovery time savings often justify it for very long runs.

Bid sizing: Do not configure a maximum bid above 2x the current spot price. Setting a high bid limit does not give you priority in most spot markets, but it can accidentally result in on-demand charges if you misread the provider's pricing model. Set a hard cost cap at the spot tier.


Running a week-long 70B fine-tune on spot H100s can cut your compute bill by over $3,000 compared to on-demand rates. With the checkpointing patterns above, spot interruptions cost you under 15 minutes of recovery time on average.

Rent spot H100 → | View live spot pricing → | Start your training job →

STEPS / 05

Quick Setup Guide

  1. Set up persistent checkpoint storage

    Mount a persistent NVMe-backed network volume at /persistent-storage/checkpoints. This volume survives instance preemptions. On Spheron, attach a persistent volume before launching the spot instance.

  2. Configure incremental and full checkpoint intervals

    Set save_steps=500 for full checkpoints and add a custom TrainerCallback or FSDP checkpoint hook that saves optimizer state + RNG state every 100 steps to /persistent-storage/checkpoints/resume_latest.pt.

  3. Install the preemption signal handler

    Register a SIGTERM handler in your training script that triggers an emergency checkpoint flush before the process exits. On Kubernetes, configure a preStop lifecycle hook with terminationGracePeriodSeconds=120.

  4. Set up automated re-provisioning

    Write a job controller script (or use a Kubernetes Job with restartPolicy=OnFailure) that detects instance termination and re-requests the spot cluster, then resumes training from the latest checkpoint.

  5. Validate checkpoint integrity before resuming

    Before resuming from checkpoint, verify the checkpoint files are complete by checking file sizes and running a dry-load pass. Corrupted partial checkpoints from a mid-write preemption will cause cryptic errors on resume.

FAQ / 05

Frequently Asked Questions

How do I make a long fine-tuning run resilient to spot preemptions?

Save checkpoints to persistent network storage on a fixed step interval. Keep incremental 'resume checkpoints' (optimizer state + RNG state + step counter, no full weights) at a shorter interval (e.g., every 100 steps) so recovery restarts from at most 100 steps back. Configure automated re-provisioning to reload from the latest checkpoint without human intervention.

How much warning do I get before a spot instance is reclaimed?

Preemption signals vary by provider. Typically you receive a SIGTERM or a metadata endpoint status change 30-120 seconds before the instance is reclaimed. On Spheron, the preemption webhook fires before reclamation. That window is enough to flush an incremental checkpoint if your checkpoint write is under 30 seconds.

What is the difference between a full and an incremental checkpoint?

A full checkpoint saves model weights, optimizer states, and RNG state - typically 2-4x the model size in storage. An incremental (resume) checkpoint saves only optimizer states, learning rate scheduler state, step counter, and RNG state - about 1x the model size or less. Use full checkpoints every 500-1000 steps for disaster recovery; use incremental checkpoints every 100 steps for fast resumption after a preemption.

Can I resume an FSDP run from sharded checkpoints after preemption?

Yes, but only if the world size (total GPU count) stays the same. FSDP checkpoints are sharded per rank. If you need to resume on a different GPU count, use FSDP's consolidation API to convert sharded checkpoints to a full state dict first, then reshard for the new world size.

What is the most cost-effective setup for fine-tuning a 70B model?

Use an 8x H100 spot cluster with async checkpoint offload to persistent NVMe-backed storage, incremental checkpoints every 100 steps, and automated re-provisioning on preemption. At current spot pricing of $1.69/hr per H100 versus $4.00/hr on-demand (lowest available rate), a 1-week fine-tune saves several thousand dollars on the GPU bill alone.
