Single GPU (RTX 4090) RLHF Training Pipeline w/ TRL

┌─────────────────────────────────────────────────────────────────┐
│                    Anthropic/hh-rlhf Dataset                    │
│  160k examples with "chosen" and "rejected" responses           │
└──────────────┬────────────────────────┬─────────────────────────┘
               │                        │
               │ "chosen" only          │ preference pairs
               │ (20k subset)           │ (50k subset)
               ↓                        │
┌──────────────────────────────────┐    │
│   meta-llama/Llama-2-7b-hf       │    │
│   (4-bit quantized base)         │    │
└──────────────┬───────────────────┘    │
               ↓                        │
┌──────────────────────────────────────────────────┐
│              STEP 1: SFT Training                │
│  ┌────────────────────────────────────────────┐  │
│  │ Input:  "chosen" responses (20k)           │  │
│  │ Loss:   Cross-entropy on completions       │  │
│  │ Metric: Perplexity → 3.3-4.5               │  │
│  │ Memory: 18-20 GB                           │  │
│  │ Time:   2-4 hours                          │  │
│  └────────────────────────────────────────────┘  │
└──────────────┬───────────────────────────────────┘
               ↓
       ┌───────────────┐
       │  SFT Model    │ (~50 MB LoRA adapters)
       └───────┬───────┘
               │
       ┌───────┴────────────────┬─────────────────┐
       │                        │                 │
       ↓                        ↓                 ↓
┌──────────────────┐   ┌─────────────────┐   ┌──────────────┐
│ Policy (PPO)     │   │ Reference (PPO) │   │ RM Base      │
│ +value head      │   │ frozen copy     │   │ +reward head │
│ trainable        │   │ for KL penalty  │   │              │
└──────┬───────────┘   └────────┬────────┘   └──────┬───────┘
       │                        │                   │ + pairs (50k)
       │                        │                   ↓
       │                        │         ┌─────────────────────────────┐
       │                        │         │  STEP 2: Reward Model       │
       │                        │         │  ┌────────────────────────┐ │
       │                        │         │  │ Input:  Pairs (50k)    │ │
       │                        │         │  │ Loss:   Ranking loss   │ │
       │                        │         │  │ Metric: Accuracy >70%  │ │
       │                        │         │  │ Memory: 20-22 GB       │ │
       │                        │         │  │ Time:   3-6 hours      │ │
       │                        │         │  └────────────────────────┘ │
       │                        │         └──────────┬──────────────────┘
       │                        │                    ↓
       │                        │              ┌──────────────┐
       │                        │              │ Reward Model │ (frozen scorer)
       │                        │              └──────┬───────┘
       │                        │                     │
       └────────────────────────┴─────────────────────┘
                                │
                + prompts (20k) │
                                ↓
       ┌────────────────────────────────────────────────┐
       │         STEP 3: PPO RLHF Optimization          │
       │  ┌──────────────────────────────────────────┐  │
       │  │ Input:  Prompts (20k)                    │  │
       │  │         Policy + Reference + RM          │  │
       │  │ Loss:   PPO clipped + KL penalty         │  │
       │  │ Metric: Mean reward ↑, KL 0.1-0.3        │  │
       │  │ Memory: 22-24 GB                         │  │
       │  │ Time:   6-12 hours                       │  │
       │  │ Loop:   1000 PPO steps                   │  │
       │  └──────────────────────────────────────────┘  │
       └───────────────────────┬────────────────────────┘
                               ↓
                    ┌─────────────────────┐
                    │  Final RLHF Model   │
                    │  (LoRA adapters)    │
                    └─────────────────────┘
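The dataset split at the top of the diagram can be sketched with 🤗 Datasets. The 20k/50k subset sizes come from the diagram; taking the first N rows is an assumption for brevity (a real run would shuffle first):

```python
from datasets import load_dataset

# Anthropic/hh-rlhf: ~160k train examples, each with a "chosen" and a
# "rejected" Human/Assistant transcript differing in the final turn.
ds = load_dataset("Anthropic/hh-rlhf", split="train")

sft_subset = ds.select(range(20_000))  # Step 1 trains on example["chosen"] only
rm_subset = ds.select(range(50_000))   # Step 2 compares "chosen" vs. "rejected"
```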

Memory Usage Breakdown (RTX 4090 - 24 GB) #

Component                  Memory      Notes
────────────────────────────────────────────────────────────
Base Model (4-bit)         ~3.5 GB    Llama-2-7B in NF4
LoRA Adapters              ~0.05 GB   r=16, small footprint
Optimizer States           ~8-10 GB   Paged AdamW 8-bit
Activations                ~6-8 GB    With gradient checkpointing
KV Cache (PPO generation)  ~2-4 GB    During generation phase
────────────────────────────────────────────────────────────
Total                      18-24 GB   Fits on RTX 4090!
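The first two rows of the table correspond roughly to a quantization-plus-LoRA setup along these lines. This is a sketch: the `target_modules` list and the LoRA hyperparameters beyond r=16 are assumptions, not values stated above.

```python
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 base weights: the ~3.5 GB row.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# r=16 LoRA adapters: the ~0.05 GB row (target modules are an assumption).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
model.gradient_checkpointing_enable()  # the "with gradient checkpointing" row
# The optimizer-states row assumes TrainingArguments(optim="paged_adamw_8bit").
```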

Training Metrics Timeline #

                SFT                    RM                    PPO
Time:       0h────4h              4h────10h            10h────22h
            │                     │                    │
Loss:       2.5 → 1.2            0.69 → 0.35           -
Perplexity: 12 → 3.3             -                     -
Accuracy:   -                    50% → 72%             -
Reward:     -                    -                     +2 → +8
KL:         -                    -                     0.05 → 0.25
            │                     │                    │
Output:     SFT Model ────────────→ Reward Model ─────→ RLHF Model

Key Relationships #

  1. SFT → RM: Warm start (better initialization than random)
  2. SFT → Policy: Direct inheritance (then optimized)
  3. SFT → Reference: Frozen copy (anchor for KL penalty)
  4. RM → PPO: Provides reward signal (quality score)
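Relationships 3 and 4 combine into the reward the PPO step actually optimizes: the RM score minus a KL drift penalty toward the frozen reference. A toy sketch with made-up log-probabilities:

```python
def kl_shaped_reward(rm_score, logp_policy, logp_ref, beta=0.1):
    # Per-token KL estimate between policy and frozen reference:
    # log pi_theta(token) - log pi_ref(token), summed over the response.
    kl = sum(p - r for p, r in zip(logp_policy, logp_ref))
    # PPO maximizes the RM score minus this drift penalty.
    return rm_score - beta * kl

# Drifting away from the reference costs reward ...
kl_shaped_reward(3.0, [-1.0, -0.5], [-2.0, -1.5])  # 3.0 - 0.1*2.0 = 2.8
# ... while matching it exactly leaves the RM score untouched.
kl_shaped_reward(3.0, [-1.0, -0.5], [-1.0, -0.5])  # 3.0
```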

Loss Functions #

SFT Loss:

L_SFT = -log P(y_completion | x_prompt)
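In token terms this is mean cross-entropy over the completion only, and its exponential is the perplexity metric from Step 1. A minimal numeric sketch with made-up token probabilities:

```python
import math

def sft_loss(completion_token_probs):
    # Cross-entropy averaged over completion tokens only; the prompt
    # tokens are masked out of the loss.
    return sum(-math.log(p) for p in completion_token_probs) / len(completion_token_probs)

loss = sft_loss([0.30, 0.25, 0.40, 0.35])  # hypothetical next-token probabilities
perplexity = math.exp(loss)                # the perplexity metric from Step 1
```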

RM Loss:

L_RM = -log σ(r(x, y_chosen) - r(x, y_rejected))
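This Bradley-Terry ranking loss starts at ln 2 ≈ 0.69 when the model cannot separate the pair (the RM starting point in the timeline above) and falls toward 0 as the margin grows. A numeric sketch with made-up reward scores:

```python
import math

def rm_ranking_loss(r_chosen, r_rejected):
    # -log sigmoid(margin): driven toward 0 as the reward model learns
    # to score the chosen response above the rejected one.
    margin = r_chosen - r_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))

rm_ranking_loss(0.0, 0.0)   # no separation yet: ln 2 ≈ 0.69
rm_ranking_loss(2.0, -1.0)  # confident correct ranking: loss ≈ 0.05
```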

PPO Loss:

L_PPO = -min(r_t(θ)·A_t, clip(r_t(θ), 1-ε, 1+ε)·A_t) + β·D_KL(π_θ || π_ref)
where r_t(θ) = π_θ(a|s) / π_θ_old(a|s)  (probability ratio)
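A minimal numeric check of the clipped surrogate term, with made-up ratio and advantage values (the KL penalty toward π_ref is added separately):

```python
def ppo_clipped_term(ratio, advantage, eps=0.2):
    # min(r*A, clip(r, 1-eps, 1+eps)*A): caps how much credit a single
    # update can take, limiting how far the policy moves per step.
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped * advantage)

ppo_clipped_term(1.5, 2.0)   # clipped: 1.2 * 2.0 = 2.4, not 3.0
ppo_clipped_term(0.5, -2.0)  # clipped: 0.8 * -2.0 = -1.6, not -1.0
```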