Single GPUT (RTX4090) RLHF Training Pipeline w/ TRL
┌─────────────────────────────────────────────────────────────────┐
│ Anthropic/hh-rlhf Dataset │
│ 160k examples with "chosen" and "rejected" responses │
└──────────────┬────────────────────────┬─────────────────────────┘
│ │
│ "chosen" only │ preference pairs
│ (20k subset) │ (50k subset)
↓ │
┌──────────────────────────────────┐ │
│ meta-llama/Llama-2-7b-hf │ │
│ (4-bit quantized base) │ │
└──────────────┬───────────────────┘ │
↓ │
┌──────────────────────────────────────────────────┐
│ STEP 1: SFT Training │
│ ┌────────────────────────────────────────────┐ │
│ │ Input: "chosen" responses (20k) │ │
│ │ Loss: Cross-entropy on completions │ │
│ │ Metric: Perplexity → 3.3-4.5 │ │
│ │ Memory: 18-20 GB │ │
│ │ Time: 2-4 hours │ │
│ └────────────────────────────────────────────┘ │
└──────────────┬───────────────────────────────────┘
↓
┌───────────────┐
│ SFT Model │ (~50 MB LoRA adapters)
└───────┬───────┘
│
┌───────┴────────────────┬─────────────────┐
│ │ │
↓ ↓ ↓
┌──────────────────┐ ┌─────────────────┐ ┌──────────────┐
│ Policy (PPO) │ │ Reference (PPO) │ │ RM Base │
│ +value head │ │ frozen copy │ │ +reward head │
│ trainable │ │ for KL penalty │ │ │
└──────┬───────────┘ └────────┬────────┘ └──────┬───────┘
│ │ │ + pairs (50k)
│ │ ↓
│ │ ┌─────────────────────────────┐
│ │ │ STEP 2: Reward Model │
│ │ │ ┌────────────────────────┐ │
│ │ │ │ Input: Pairs (50k) │ │
│ │ │ │ Loss: Ranking loss │ │
│ │ │ │ Metric: Accuracy >70% │ │
│ │ │ │ Memory: 20-22 GB │ │
│ │ │ │ Time: 3-6 hours │ │
│ │ │ └────────────────────────┘ │
│ │ └──────────┬──────────────────┘
│ │ ↓
│ │ ┌──────────────┐
│ │ │ Reward Model │ (frozen scorer)
│ │ └──────┬───────┘
│ │ │
└────────────────────────┴─────────────────────┘
│
+ prompts (20k) │
↓
┌────────────────────────────────────────────────┐
│ STEP 3: PPO RLHF Optimization │
│ ┌──────────────────────────────────────────┐ │
│ │ Input: Prompts (20k) │ │
│ │ Policy + Reference + RM │ │
│ │ Loss: PPO clipped + KL penalty │ │
│ │ Metric: Mean reward ↑, KL 0.1-0.3 │ │
│ │ Memory: 22-24 GB │ │
│ │ Time: 6-12 hours │ │
│ │ Loop: 1000 PPO steps │ │
│ └──────────────────────────────────────────┘ │
└───────────────────────┬────────────────────────┘
↓
┌─────────────────────┐
│ Final RLHF Model │
│ (LoRA adapters) │
└─────────────────────┘
Memory Usage Breakdown (RTX 4090 - 24 GB) #
Component Memory Notes
────────────────────────────────────────────────────────────
Base Model (4-bit) ~3.5 GB Llama-2-7B in NF4
LoRA Adapters ~0.05 GB r=16, small footprint
Optimizer States ~8-10 GB Paged AdamW 8-bit
Activations ~6-8 GB With gradient checkpointing
KV Cache (PPO generation) ~2-4 GB During generation phase
────────────────────────────────────────────────────────────
Total 18-24 GB Fits on RTX 4090!
Training Metrics Timeline #
SFT RM PPO
Time: 0h────4h 4h────10h 10h────22h
│ │ │
Loss: 2.5 → 1.2 0.69 → 0.35 -
Perplexity: 12 → 3.3 - -
Accuracy: - 50% → 72% -
Reward: - - +2 → +8
KL: - - 0.05 → 0.25
│ │ │
Output: SFT Model ────────────→ Reward Model ─────→ RLHF Model
Key Relationships #
- SFT → RM: Warm start (better initialization than random)
- SFT → Policy: Direct inheritance (then optimized)
- SFT → Reference: Frozen copy (anchor for KL penalty)
- RM → PPO: Provides reward signal (quality score)
Loss Functions #
SFT Loss:
L_SFT = -log P(y_completion | x_prompt)
RM Loss:
L_RM = -log σ(r(x, y_chosen) - r(x, y_rejected))
PPO Loss:
L_PPO = min(r_t(θ)·A_t, clip(r_t(θ), 1±ε)·A_t) + β·D_KL(π||π_ref)
where r_t(θ) = π_θ(a|s) / π_θ_old(a|s) (probability ratio)