# RLHF Pipeline: Key Non-Default Settings

## Critical Configuration (Non-Defaults Only)

### Models & Architecture

| Setting | Value | Why Not Default? |
|---|---|---|
| Model Class | Step 1/3: CausalLM<br>Step 2: SequenceClassification | Step 2 needs scalar reward output, not text generation |
| Quantization | 4-bit QLoRA | Fits 7B model in 24GB VRAM (vs 28GB for fp16) |
| num_labels | Step 2: 1 | Reward model outputs single scalar score |
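A minimal loading sketch for the two head types, assuming the Hugging Face `transformers` Auto classes and the Llama-2-7b-hf checkpoint named in the flow summary below; the `device_map="auto"` placement is an assumption, and the full 4-bit settings appear in the Quantization Details sketch further down:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
)

# 4-bit loading; the full NF4/double-quant settings are spelled out in the
# Quantization Details sketch further down.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# Steps 1 and 3: causal LM head for text generation.
policy_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Step 2: sequence-classification head with num_labels=1 -> one scalar reward
# per (prompt, response) sequence. In practice this would start from the SFT
# checkpoint produced in Step 1 rather than the base model.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    num_labels=1,
    quantization_config=bnb_config,
    device_map="auto",
)
```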
### Dataset Configuration

| Setting | Value | Why Not Default? |
|---|---|---|
| Dataset Size | SFT: 20K, RM: 50K, PPO: 20K | RM needs more data for robust preference learning |
| Data Format | SFT: chosen only<br>RM: chosen+rejected pairs<br>PPO: prompts only | Each step requires a different supervision signal |
| Train/Eval Split | Step 2: 5% eval | Only RM needs validation to prevent reward hacking |
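A sketch of how the three views of a preference dataset might be prepared, assuming an Anthropic-HH-style schema with `chosen`/`rejected` text fields; the dataset name, index ranges, and the prompt-extraction convention are illustrative assumptions, not confirmed by this document:

```python
from datasets import load_dataset

# Illustrative: any preference dataset with "chosen"/"rejected" fields works the same way.
raw = load_dataset("Anthropic/hh-rlhf", split="train")

# Step 1 (SFT): train on the chosen responses only.
sft_data = raw.select(range(20_000)).map(lambda ex: {"text": ex["chosen"]})

# Step 2 (RM): keep both sides of each pair for the ranking loss.
rm_data = raw.select(range(20_000, 70_000)).map(
    lambda ex: {"chosen": ex["chosen"], "rejected": ex["rejected"]}
)

# Step 3 (PPO): prompts only; here the prompt is assumed to be everything up to
# the final "Assistant:" turn (an HH-style convention, not guaranteed elsewhere).
def to_prompt(ex):
    prompt = ex["chosen"].rsplit("Assistant:", 1)[0] + "Assistant:"
    return {"query": prompt}

ppo_data = raw.select(range(70_000, 90_000)).map(to_prompt)
```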
### Training Hyperparameters

| Setting | Value | Why Not Default? |
|---|---|---|
| Learning Rate | SFT: 2e-4<br>RM: 5e-5<br>PPO: 1e-6 | Decreasing LR prevents destabilizing previous training |
| Batch Size | 1 (SFT/RM)<br>8 (PPO) | Memory constraint; PPO needs multiple rollouts per update |
| Gradient Accumulation | SFT: 16<br>RM: 32 | Simulates larger batch sizes within memory limit |
| Effective Batch Size | SFT: 16<br>RM: 32<br>PPO: 8 | RM needs larger batches for stable ranking gradients |
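The effective batch size is just per-device batch size × gradient-accumulation steps; a bookkeeping sketch with the table's values (dict and variable names are my own, and PPO is assumed to use no accumulation since its effective batch equals its batch size):

```python
# Per-step hyperparameters from the tables above (names are illustrative).
HPARAMS = {
    "sft": {"lr": 2e-4, "per_device_batch": 1, "grad_accum": 16},
    "rm":  {"lr": 5e-5, "per_device_batch": 1, "grad_accum": 32},
    "ppo": {"lr": 1e-6, "per_device_batch": 8, "grad_accum": 1},
}

for step, hp in HPARAMS.items():
    effective = hp["per_device_batch"] * hp["grad_accum"]
    print(f"{step}: lr={hp['lr']}, effective batch size={effective}")
# sft: lr=0.0002, effective batch size=16
# rm: lr=5e-05, effective batch size=32
# ppo: lr=1e-06, effective batch size=8
```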
### LoRA Configuration

| Setting | Value | Why Not Default? |
|---|---|---|
| r (rank) | 16 | Balance between parameter efficiency and model capacity |
| alpha | 32 (2×r) | Standard scaling for LoRA updates |
| dropout | 0.05 | Mild regularization to prevent adapter overfitting |
| task_type | SFT/PPO: CAUSAL_LM<br>RM: SEQ_CLS | Matches the model head type for each step |
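A minimal `peft.LoraConfig` sketch for the two task types. The `target_modules` list is an assumption based on the "all attention + FFN" note in the LoRA Design section below, using the standard Llama-2 projection names; the `step` switch is illustrative:

```python
from peft import LoraConfig, TaskType

# Assumed target modules: all attention projections + FFN (standard Llama-2 names).
LLAMA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",   # attention
    "gate_proj", "up_proj", "down_proj",      # feed-forward
]

def make_lora_config(step: str) -> LoraConfig:
    """Build the LoRA config for a pipeline step ('sft', 'rm', or 'ppo')."""
    return LoraConfig(
        r=16,                # rank: parameter efficiency vs. capacity trade-off
        lora_alpha=32,       # 2 x r scaling
        lora_dropout=0.05,   # mild regularization
        target_modules=LLAMA_TARGET_MODULES,
        task_type=TaskType.SEQ_CLS if step == "rm" else TaskType.CAUSAL_LM,
    )
```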
### Quantization Details

| Setting | Value | Why Not Default? |
|---|---|---|
| load_in_4bit | True | Reduces memory by 75% vs fp16 |
| bnb_4bit_use_double_quant | True | Quantizes quantization constants (extra memory savings) |
| bnb_4bit_quant_type | "nf4" | Normal Float 4-bit optimal for weights (vs uniform) |
| bnb_4bit_compute_dtype | bfloat16 | Better numerical stability than fp16 for training |
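These four settings map one-to-one onto `transformers.BitsAndBytesConfig`; a sketch:

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit weights instead of 16-bit
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4, tuned for weight distributions
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls and backprop run in bf16
)
# Pass as quantization_config=bnb_config to from_pretrained
# (see the model-loading sketch in Models & Architecture above).
```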
### Optimization Settings

| Setting | Value | Why Not Default? |
|---|---|---|
| optim | paged_adamw_8bit | Memory-efficient optimizer for 4-bit training |
| bf16 | True | Better gradient stability than fp16 |
| gradient_checkpointing | True | Trades compute for memory (enables longer sequences) |
| lr_scheduler_type | "cosine" | Smooth LR decay prevents abrupt training disruption |
| warmup_ratio | 0.03 | Stabilizes initial training with 4-bit quantization |
| max_grad_norm | 0.3 | Prevents gradient explosion in LoRA training |
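These settings correspond directly to `transformers.TrainingArguments` fields. A sketch using the SFT-step learning rate and batch settings from the tables above; `output_dir`, epoch count, and logging cadence are placeholder assumptions:

```python
from transformers import TrainingArguments

# SFT-step example; the RM step would swap in learning_rate=5e-5 and
# gradient_accumulation_steps=32.
training_args = TrainingArguments(
    output_dir="sft-llama2-7b",           # placeholder
    num_train_epochs=1,                   # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",             # paged 8-bit AdamW from bitsandbytes
    bf16=True,                            # bfloat16 mixed precision
    gradient_checkpointing=True,          # recompute activations to save memory
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    logging_steps=10,                     # placeholder
)
```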
### PPO-Specific (Step 3 Only)

| Setting | Value | Why Not Default? |
|---|---|---|
| mini_batch_size | 1 | Memory constraint during on-policy generation |
| ppo_epochs | 4 | Multiple passes over collected experience |
| init_kl_coef | 0.1 | Prevents policy from diverging too far from SFT |
| adap_kl_ctrl | True | Dynamically adjusts KL penalty based on divergence |
| gamma | 1.0 | No discounting (language has no clear episode structure) |
| lam | 0.95 | GAE parameter balancing bias-variance in advantage |
| cliprange | 0.2 | Limits policy update size (PPO core mechanism) |
| vf_coef | 0.1 | Weight of value function loss vs policy loss |
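The parameter names above match the classic `trl.PPOConfig` fields (roughly TRL 0.4–0.8; newer releases renamed or removed several of them). A sketch under that assumption:

```python
from trl import PPOConfig

# Assumes the classic trl.PPOConfig API; field names follow the table above.
ppo_config = PPOConfig(
    model_name="meta-llama/Llama-2-7b-hf",  # in practice, the SFT checkpoint from Step 1
    learning_rate=1e-6,
    batch_size=8,              # rollouts collected per PPO update
    mini_batch_size=1,         # optimization micro-batch (memory constraint)
    ppo_epochs=4,              # passes over each batch of rollouts
    init_kl_coef=0.1,          # starting weight of the KL penalty vs the SFT policy
    adap_kl_ctrl=True,         # adapt the KL coefficient toward a target divergence
    gamma=1.0,                 # no temporal discounting
    lam=0.95,                  # GAE lambda
    cliprange=0.2,             # PPO policy clip
    vf_coef=0.1,               # value-loss weight
)
```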
## Training Flow Summary

Llama-2-7b-hf (4-bit quantized)
    ↓
[Step 1: SFT] ← 20K chosen examples, LR=2e-4, LoRA r=16
    ↓
├── [Step 2: RM] ← 50K preference pairs, LR=5e-5, outputs scalar
│       ↓
└── [Step 3: PPO] ← 20K prompts, LR=1e-6, KL=0.1
    ↓
Final RLHF Model
## Key Metrics to Monitor

| Step | Primary Metric | Danger Sign |
|---|---|---|
| SFT | Training loss ↓ | Eval loss ↑ (overfitting) |
| RM | Ranking accuracy ↑ | Reward always higher for longer text (length bias) |
| PPO | Mean reward ↑ | KL > 0.5 (policy collapse) |
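A sketch of the two reward-model checks, assuming the RM returns one scalar per sequence; the function names, variable names, and toy tensors are illustrative:

```python
import torch

def ranking_accuracy(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> float:
    """Fraction of pairs where the chosen response outscores the rejected one."""
    return (chosen_rewards > rejected_rewards).float().mean().item()

def length_reward_correlation(rewards: torch.Tensor, lengths: torch.Tensor) -> float:
    """Pearson correlation between reward and response length; values near 1.0
    suggest the RM is rewarding verbosity rather than quality (length bias)."""
    stacked = torch.stack([rewards.float(), lengths.float()])
    return torch.corrcoef(stacked)[0, 1].item()

# Illustrative usage with dummy rewards:
chosen = torch.tensor([1.2, 0.8, 2.1, 0.3])
rejected = torch.tensor([0.5, 1.0, 1.4, -0.2])
print(ranking_accuracy(chosen, rejected))  # 0.75 on this toy batch
```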
## Why These Specific Values?

### Learning Rate Decay Pattern
- SFT (2e-4): Highest LR for initial adaptation from base model
- RM (5e-5): Lower to preserve SFT knowledge while learning preferences
- PPO (1e-6): Tiny updates to avoid destroying alignment from RM
### Batch Size Strategy
- Small per-device (1): GPU memory constraint with 7B model
- Large accumulation (16-32): Stabilizes gradients for contrastive learning (RM)
- PPO (8 rollouts): Enough diversity for policy gradient estimation
### Quantization Choices
- 4-bit: Only option that fits 7B + optimizer states in 24GB
- NF4: Specifically designed for neural network weight distributions
- Double quant: Squeezes out roughly an extra 0.3GB on a 7B model (~0.37 bits/parameter) by also quantizing the quantization constants
- bfloat16 compute: Prevents underflow in gradients during backprop
### LoRA Design

- r=16: Sweet spot for 7B models (too low = capacity loss, too high = overfitting)
- alpha=32: Standard 2× scaling keeps update magnitudes reasonable
- All attention + FFN target_modules: Covers both information routing and transformation
### PPO Parameters
- KL penalty (0.1): Prevents catastrophic forgetting of SFT behavior
- Clip (0.2): Conservative updates reduce instability
- Gamma (1.0): No temporal discounting (each token equally important)