Serving a 70B model in production costs roughly 3-4x more per token than serving an 8B model. If your 70B model is doing 90% of its heavy lifting on tasks an 8B could handle, you are burning real money. Model distillation is how you fix that: train a smaller student model to replicate the teacher's reasoning, then retire the expensive one. As part of a broader GPU cost optimization strategy, distillation typically delivers the largest single cost reduction of any technique, because it permanently reduces model size rather than just squeezing existing weights harder.
What Is Model Distillation
Model distillation trains a small student model to mimic the output probability distribution of a large teacher model, not just to predict the correct label. The key insight is that a teacher's full output vector across the vocabulary (its "soft labels") contains much richer information than a simple correct/wrong signal.
When a teacher assigns 60% probability to "Paris", 30% to "Lyon", and 5% to "Marseille" for a geography question, the student doesn't just learn "Paris is right." It learns something about the relative closeness of French cities, the structure of the question, and how confident to be. That gradient of information is what makes distilled models punch above their parameter count.
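A toy softmax makes the difference concrete. The logits below are hypothetical, chosen to mirror the geography example:

```python
import math

def softmax(logits, temperature=1.0):
    """Convert logits to probabilities; higher temperature flattens them."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                        # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for ["Paris", "Lyon", "Marseille", "Berlin"]
hard_label = [1, 0, 0, 0]          # all a hard label carries: "Paris is right"
soft = softmax([5.0, 4.3, 2.5, 0.1])
print([round(p, 3) for p in soft]) # Lyon gets real mass; Berlin almost none
```

The hard label is a single bit per example; the soft distribution carries a graded similarity signal ("Lyon is a plausible answer, Berlin isn't") at every training step.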
Here's how distillation compares to the other main techniques for reducing VRAM and inference cost:
| Technique | What Changes | VRAM Reduction | Accuracy Trade-off |
|---|---|---|---|
| Quantization | Same model, lower precision per weight | 2-4x | 1-3% on benchmarks |
| Fine-tuning | Same model size, adapted to new tasks | None | Varies by domain |
| Distillation | New, smaller model trained on teacher outputs | 4-10x (from 70B to 8B) | 3-8% on general benchmarks |
For a closer look at how quantization interacts with GPU memory requirements, including exact VRAM numbers by model size and precision format, see that post's calculator tables.
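The VRAM columns in these tables follow from a simple rule: weights-only memory is parameter count times bytes per parameter. A minimal sketch (the byte-widths per precision are standard; real usage adds KV cache and activation overhead on top):

```python
# Weights-only VRAM estimate: params x bytes per parameter. Actual usage
# adds KV cache and activation overhead, which depend on batch size and
# context length.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "fp8": 1, "int4": 0.5}

def weights_vram_gb(params_billions: float, precision: str) -> float:
    return params_billions * BYTES_PER_PARAM[precision]

print(weights_vram_gb(70, "fp8"))   # 70.0 -- a 70B FP8 teacher nearly fills an 80GB H100
print(weights_vram_gb(8, "bf16"))   # 16.0 -- an 8B BF16 student fits on much smaller cards
```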
When to Distill vs Quantize vs Fine-Tune
The right technique depends on what you're actually trying to achieve:
| Goal | Recommended Approach | Complexity | Cost Reduction |
|---|---|---|---|
| Reduce inference cost at scale | Distillation | High | 80-95% |
| Reduce VRAM on existing model | Quantization | Low | 50-75% |
| Adapt model to domain | Fine-tuning | Medium | None |
| Reduce latency, keep accuracy | Distillation or quantization | Medium-High | 60-90% |
Distillation is the right call when you're spending real money on 70B inference and your task doesn't genuinely require 70B-scale reasoning. Quantization gets you partway there; distillation goes the rest of the way.
Don't distill when:
- Your dataset is under 5,000 examples. The student needs enough signal to generalize; thin data produces a model that memorizes teacher outputs on the training distribution and collapses on anything else.
- Your task genuinely requires the teacher's full breadth. Coding, legal reasoning, and complex multi-step math often benefit from 70B-scale capacity. A narrow benchmark score won't tell you this; your production task distribution will.
- Teacher and student have incompatible tokenizers. This is solvable (you can project logits across vocabularies), but it adds significant complexity.
When fine-tuning is the right choice instead, see the fine-tuning guide for a complete workflow with cost breakdowns.
How Model Distillation Works: The Mechanics
The distillation loss combines two signals:
L_total = alpha * L_KL + (1 - alpha) * L_CE

L_KL is the KL divergence between the teacher's softened probability distribution and the student's, computed token by token. L_CE is the standard cross-entropy loss against the ground truth label. Alpha is typically set between 0.5 and 0.7; values above 0.5 weight the distillation signal more heavily than the task signal.
The temperature parameter T controls how "soft" the teacher's distribution is. At T=1, you get the raw softmax probabilities. At T=2 or T=4, the distribution flattens, making smaller probabilities more visible to the student. Most practitioners use T=2 for the distillation step and T=1 for the task loss.
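The flattening effect is easy to verify numerically. A minimal sketch with hypothetical three-token logits:

```python
import math

def softmax_t(logits, T):
    """Softmax at temperature T; larger T flattens the distribution."""
    exps = [math.exp(z / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

logits = [4.0, 2.0, 0.0]  # hypothetical teacher logits for three tokens
for T in (1, 2, 4):
    # At T=1 the top token dominates; at T=4 the tail tokens become
    # visible enough for the student to learn from them.
    print(T, [round(p, 3) for p in softmax_t(logits, T)])
```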
When teacher and student share the same tokenizer (for example, both are Llama-family models), logit alignment is automatic. If they don't share a vocabulary, you have two options: subset the teacher's logits to the student's vocabulary (loses some signal) or train a projection layer to map across the full vocab space (adds parameters and complexity). Use the same model family when possible and skip this problem entirely.
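The subsetting option can be sketched with toy vocabularies (the vocab dicts and probabilities here are hypothetical, not real tokenizer output): keep only teacher top-K tokens that also exist in the student's vocabulary, remap them to student ids, and renormalize the lost probability mass.

```python
# Toy sketch of the "subset the teacher's logits" option for mismatched
# vocabularies. Real code would build these maps from the two tokenizers.
teacher_vocab = {"Paris": 0, "Lyon": 1, "Marseille": 2, "Berlin": 3}
student_vocab = {"Paris": 10, "Lyon": 11, "Berlin": 12}   # no "Marseille"

teacher_topk = [("Paris", 0.60), ("Lyon", 0.30), ("Marseille", 0.05)]

# Keep tokens present in both vocabs, remapped to student ids
kept = [(student_vocab[tok], p) for tok, p in teacher_topk if tok in student_vocab]
total = sum(p for _, p in kept)
projected = [(tid, p / total) for tid, p in kept]  # renormalize lost mass
print(projected)  # Paris and Lyon survive with ids 10/11; mass sums to 1.0
```

The renormalization step is where signal is lost: everything the teacher knew about "Marseille" simply disappears, which is why same-family teacher/student pairs are preferable.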
Sequence-level distillation is a simpler variant: the teacher generates complete output sequences, and the student is trained on those sequences with standard cross-entropy. It's less data-efficient than token-level KL distillation but easier to implement and still significantly better than training from scratch on hard labels alone.
GPU Requirements for 70B-to-8B Distillation
| Component | Model | GPU | VRAM | Hourly Cost (Spheron) |
|---|---|---|---|---|
| Teacher inference | Llama 3.3 70B at FP8 | 1x H100 PCIe 80GB | ~70GB | $2.01/hr on-demand |
| Student training | Llama 3.1 8B at BF16 | 1x H100 PCIe 80GB | ~30-40GB (weights + activations + optimizer) | $2.01/hr on-demand |
| Student deployment | Llama 3.1 8B at BF16 | 1x L40S 48GB | ~14GB weights | $0.72/hr on-demand |
Pricing fluctuates based on GPU availability. The prices above are based on 26 Mar 2026 and may have changed. Check current GPU pricing → for live rates.
The reason you want teacher inference separate from student training is memory. Running a 70B FP8 teacher and an 8B BF16 student simultaneously would require ~86GB for the weights alone (~70GB + ~16GB), exceeding a single H100 80GB before activations and optimizer state even enter the picture. Generating soft labels offline onto disk first lets each GPU do one job cleanly.
For a typical distillation project with 10,000 training examples and 3 epochs:
- Soft label generation: ~3 hours (one H100 PCIe at $2.01/hr = $6.03)
- Student training: ~6-10 hours (one H100 PCIe at $2.01/hr = $12.06-20.10)
- Total project cost: ~$18-26
That's one-time training cost. The permanent saving comes from serving the 8B student at $0.72/hr on an L40S instead of the 70B teacher at $2.01/hr on an H100.
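A back-of-envelope check using the rates above shows how quickly the one-time cost is recovered:

```python
# Break-even sketch: one-time training cost vs hourly serving savings,
# using the on-demand rates quoted above (subject to change).
h100_rate, l40s_rate = 2.01, 0.72              # $/hr
training_cost_low, training_cost_high = 18.0, 26.0

hourly_saving = h100_rate - l40s_rate          # $1.29/hr saved on serving
payback_low = training_cost_low / hourly_saving
payback_high = training_cost_high / hourly_saving
print(f"{payback_low:.0f}-{payback_high:.0f} hours")  # 14-20 hours of serving
```

In other words, the project pays for itself in under a day of continuous serving; every hour after that is pure saving.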
Provision H100 GPU rental on Spheron for the training phase and check current GPU pricing for the latest rates across all GPU types. For account setup and connecting to your first GPU instance, see the Spheron getting started guide.
Step-by-Step: Distilling a 70B Teacher into an 8B Student
Step 1: Set Up the Environment
pip install torch==2.5.1 transformers==4.45.0 "trl>=0.12.0" datasets accelerate bitsandbytes

Requires CUDA 12.4 and PyTorch 2.5+. The bitsandbytes library handles 8-bit loading of the 70B teacher. The trl 0.12+ API is required for the SFTTrainer compute_loss signature used in Step 3; newer versions (0.13+, 0.14+, etc.) are also compatible.
Step 2: Load the Teacher and Generate Soft Labels
This is the most GPU-memory-intensive step. A 70B model at FP8 needs ~70GB. One H100 80GB is the minimum. For detailed VRAM numbers by precision format, see the GPU memory requirements guide.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
TEACHER_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
DATASET_NAME = "your-org/your-dataset" # replace with your dataset
TOP_K = 50
TEMPERATURE = 2.0
OUTPUT_PATH = "./soft_labels"
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)
teacher = AutoModelForCausalLM.from_pretrained(
TEACHER_MODEL,
load_in_8bit=True, # requires bitsandbytes; use load_in_4bit for tighter VRAM
device_map="auto",
torch_dtype=torch.float16,
)
teacher.eval()
dataset = load_dataset(DATASET_NAME, split="train")
import os
os.makedirs(OUTPUT_PATH, exist_ok=True)
for idx, example in enumerate(dataset):
inputs = tokenizer(
example["text"],
return_tensors="pt",
truncation=True,
max_length=512,
).to("cuda")
with torch.no_grad():
outputs = teacher(**inputs)
# Apply temperature scaling to soften the distribution before saving.
# The saved logits are already temperature-scaled (not raw logits).
logits = outputs.logits / TEMPERATURE
# Save only top-K logits per position to reduce disk space
top_k_logits, top_k_indices = torch.topk(logits, k=TOP_K, dim=-1)
torch.save(
{"logits": top_k_logits.cpu(), "indices": top_k_indices.cpu()},
f"{OUTPUT_PATH}/sample_{idx:06d}.pt",
)
if idx % 500 == 0:
        print(f"Processed {idx}/{len(dataset)}")

Note on vocabulary alignment: This example uses a Llama 3.3 70B teacher and will be paired with a Llama 3.1 8B student in Step 3. Both use the same tokenizer (Llama tokenizer with 128,256 vocab size), so logit indices align directly. If you use a different student family (e.g., Qwen or Mistral), you must handle vocabulary mismatch: either subset the top-K indices to the student's vocabulary or add a projection layer. Silently ignoring this produces incorrect KL divergence and a broken distillation loss.
Step 3: Define the Distillation Loss
import torch
import torch.nn.functional as F
from trl import SFTTrainer
from transformers import TrainingArguments
ALPHA = 0.7 # weight for distillation loss vs task loss
TEMPERATURE = 2.0
SOFT_LABELS_PATH = "./soft_labels"
class DistillationTrainer(SFTTrainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
# Pop sample_indices BEFORE passing inputs to the parent class.
# The model's forward() does not accept this field and will raise
# a TypeError if it is present when super().compute_loss() is called.
batch_indices = inputs.pop("sample_indices", None)
# Get student outputs via the parent class (handles labels internally)
loss, outputs = super().compute_loss(
model, inputs, return_outputs=True, **kwargs
)
if batch_indices is None:
# Fall back to standard task loss if indices not provided
return (loss, outputs) if return_outputs else loss
student_logits = outputs.logits # [batch, seq_len, vocab]
distill_losses = []
for i, sample_idx in enumerate(batch_indices):
saved = torch.load(
f"{SOFT_LABELS_PATH}/sample_{int(sample_idx):06d}.pt",
map_location=student_logits.device,
)
teacher_indices = saved["indices"][0] # [seq_len, top_k]
teacher_logits_topk = saved["logits"][0] # [seq_len, top_k]
seq_len = min(student_logits.shape[1], teacher_indices.shape[0])
# Scatter teacher top-K logits back to full vocabulary space
teacher_full = torch.full(
(seq_len, student_logits.shape[-1]),
float("-inf"),
device=student_logits.device,
)
teacher_full.scatter_(
dim=-1,
index=teacher_indices[:seq_len],
src=teacher_logits_topk[:seq_len].to(student_logits.dtype),
)
# The saved teacher logits are already temperature-scaled (divided by
# TEMPERATURE in Step 2). Do NOT divide again here — that would apply
# temperature twice (effective T^2), producing a severely flattened
# distribution that does not reflect the teacher's actual predictions.
teacher_probs = F.softmax(teacher_full, dim=-1)
student_log_probs = F.log_softmax(
student_logits[i, :seq_len] / TEMPERATURE, dim=-1
)
# KL divergence: sum over vocab, mean over sequence positions
kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
distill_losses.append(kl)
distill_loss = torch.stack(distill_losses).mean()
# Combined loss: alpha * KL + (1 - alpha) * cross-entropy
# Scale distillation loss by T^2 as per Hinton et al.
total_loss = ALPHA * (TEMPERATURE ** 2) * distill_loss + (1 - ALPHA) * loss
            return (total_loss, outputs) if return_outputs else total_loss

Step 4: Train the Student
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTConfig
from datasets import load_dataset
STUDENT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_NAME = "your-org/your-dataset"
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
if student_tokenizer.pad_token is None:
student_tokenizer.pad_token = student_tokenizer.eos_token
student = AutoModelForCausalLM.from_pretrained(
STUDENT_MODEL,
torch_dtype=torch.bfloat16,
device_map="auto",
)
dataset = load_dataset(DATASET_NAME, split="train")
# Add integer sample indices so DistillationTrainer can look up the
# pre-computed soft labels saved in Step 2. Without this column,
# batch_indices is always None and distillation is silently skipped.
dataset = dataset.map(lambda example, idx: {"sample_indices": idx}, with_indices=True)
training_args = SFTConfig(
    output_dir="./student-distilled",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    bf16=True,
    logging_steps=25,
    save_strategy="epoch",
    dataloader_num_workers=4,
    remove_unused_columns=False,  # keep sample_indices; the default (True) strips
                                  # columns the model's forward() doesn't accept,
                                  # which would silently disable distillation
    report_to="none",
)
student.config.pad_token_id = student_tokenizer.pad_token_id
trainer = DistillationTrainer(
model=student,
args=training_args,
train_dataset=dataset,
tokenizer=student_tokenizer,
)
trainer.train()
trainer.save_model("./student-distilled-final")

Step 5: Evaluate the Student
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
DATASET_NAME = "your-org/your-dataset" # replace with your dataset
def evaluate_model(model_path, dataset, num_samples=500):
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.eval()
correct = 0
for example in dataset.select(range(num_samples)):
inputs = tokenizer(example["question"], return_tensors="pt").to("cuda")
with torch.no_grad():
out = model.generate(**inputs, max_new_tokens=64, do_sample=False)
prediction = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
if example["answer"].strip().lower() in prediction.strip().lower():
correct += 1
return correct / num_samples
eval_dataset = load_dataset(DATASET_NAME, split="test")
teacher_acc = evaluate_model("meta-llama/Llama-3.3-70B-Instruct", eval_dataset)
student_acc = evaluate_model("./student-distilled-final", eval_dataset)
print(f"Teacher accuracy: {teacher_acc:.1%}")
print(f"Student accuracy: {student_acc:.1%}")
print(f"Gap: {(teacher_acc - student_acc):.1%}")

Deploying Your Distilled 8B Model with vLLM on Spheron
Once training completes, deploy the student on an L40S or RTX 4090. The 8B model at BF16 uses ~14GB of VRAM, well within the L40S's 48GB. The RTX 4090 (24GB) works for low-concurrency workloads at --max-model-len 4096.
docker run --gpus all --ipc=host -p 8000:8000 \
-v /path/to/student-distilled-final:/model \
vllm/vllm-openai:latest \
--model /model \
--dtype bfloat16 \
--gpu-memory-utilization 0.90 \
  --max-num-seqs 256

Call the server using the OpenAI SDK with a base_url override:
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
response = client.chat.completions.create(
model="/model",
messages=[{"role": "user", "content": "Explain transformer attention in one paragraph."}],
max_tokens=256,
)
print(response.choices[0].message.content)

For a complete vLLM production setup covering multi-GPU tensor parallelism, FP8 quantization, and load balancing, see the vLLM production deployment guide. For a broader overview of LLM inference options on Spheron GPUs, including Ollama and SGLang setups, see the Spheron LLM inference guide. For the full workflow from model weights to a self-hosted OpenAI-compatible endpoint, see the self-hosted OpenAI-compatible API guide.
Benchmarks: Distilled 8B vs Original 70B
Quality comparison (representative scores based on published DeepSeek R1 distillation results and Llama distillation literature; figures are approximate composites, not from a single benchmark run):
| Benchmark | 70B Teacher | 8B Student | Gap |
|---|---|---|---|
| MMLU | 88.0% | 83.2% | -4.8pp |
| HellaSwag | 89.4% | 85.1% | -4.3pp |
| GSM8K | 91.2% | 84.7% | -6.5pp |
| Domain task accuracy | 87.5% | 85.9% | -1.6pp |
| ROUGE-L (summarization) | 0.412 | 0.389 | -0.023 (-5.6% rel.) |
The domain task row shows why distilling with in-domain data matters: when the student learns teacher soft labels on the same distribution it will face in production, the gap collapses to under 2%.
Cost and performance comparison:
| Model | GPU | On-demand ($/hr) | Tokens/sec (vLLM) | Cost per 1M tokens |
|---|---|---|---|---|
| Llama 3.3 70B | H100 PCIe | $2.01/hr | ~180 tok/s | ~$3.10 |
| Llama 3.1 8B (distilled) | L40S 48GB | $0.72/hr | ~1,200 tok/s | ~$0.17 |
| Llama 3.1 8B (distilled) | RTX 4090 | $0.50/hr | ~900 tok/s | ~$0.15 |
Pricing fluctuates based on GPU availability. The prices above are based on 26 Mar 2026 and may have changed. Check current GPU pricing → for live rates.
The distilled 8B model costs roughly 18x less per token to serve on Spheron, at comparable quality on most real-world tasks. For a broader GPU inference cost comparison across more GPU types and model sizes, see that guide's cost-per-token tables.
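The cost-per-1M-token column is derived directly from hourly rate and measured throughput:

```python
# How the cost-per-1M-token column is computed from the table's inputs.
def cost_per_million_tokens(hourly_rate: float, tokens_per_sec: float) -> float:
    tokens_per_hour = tokens_per_sec * 3600
    return hourly_rate / tokens_per_hour * 1_000_000

print(round(cost_per_million_tokens(2.01, 180), 2))   # ~3.10  (70B on H100)
print(round(cost_per_million_tokens(0.72, 1200), 2))  # ~0.17  (8B on L40S)
```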
Production Checklist: Evaluation, Monitoring, and When to Re-Distill
- Run held-out eval before promoting to production. Cover domain-specific tasks, not just general benchmarks like MMLU. A 4% MMLU gap might be acceptable; a 15% gap on your actual customer queries is not.
- Set up latency and quality monitoring in production. An 8B student can drift if you re-distill from a different teacher checkpoint or if task distribution shifts. Use GPU monitoring for ML to track GPU utilization, token throughput, and latency percentiles alongside application-level quality metrics.
- Log model-scored quality metrics. Send 1% of live traffic to the teacher model and compare outputs with the student. This gives you a continuous quality signal without human labeling.
- Re-distill when:
- Task distribution shifts (new product features, new user queries)
- The teacher is updated with significant capability improvements
- Student error rate on monitored tasks exceeds a defined threshold
- Consider distilling specialized students. One student per domain (code, math, customer support) often beats one general student. The teacher gives different signal quality on different task types. A student trained only on coding tasks can match a 70B teacher on code at 8B scale, where a general student falls short.
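The shadow-traffic comparison from the checklist can be sketched as follows; the sampling and scoring helpers are hypothetical placeholders for your serving stack, and the token-overlap score is a crude stand-in for an embedding-based or judge-model quality metric:

```python
import random

SHADOW_RATE = 0.01  # fraction of live traffic also sent to the teacher

def should_shadow(rng: random.Random) -> bool:
    """Decide per-request whether to duplicate it to the teacher."""
    return rng.random() < SHADOW_RATE

def agreement(student_out: str, teacher_out: str) -> float:
    """Crude token-overlap (Jaccard) score between the two outputs."""
    s, t = set(student_out.split()), set(teacher_out.split())
    return len(s & t) / max(len(s | t), 1)

rng = random.Random(0)
shadowed = sum(should_shadow(rng) for _ in range(100_000))
print(shadowed)  # roughly 1,000 of 100,000 requests go to the teacher
```

Aggregating the agreement score over shadowed requests gives a continuous drift signal; a sustained drop below a threshold you set is the trigger for re-distillation.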
Model distillation cuts serving costs by 80-95% on tasks where an 8B student matches the 70B teacher's quality. Spheron provides on-demand H100 access for the training phase and affordable L40S or RTX 4090 instances for serving the distilled model.
Rent H100 for distillation training → | View all GPU pricing → | Get started on Spheron →
