How to Use BERT's CLS Token for Classification
❓ Question #
How can we use the [CLS] token (i.e., h_cls) from the last layer of BERT for classification tasks? Given that the BERT output has shape [batch_size, sequence_length, hidden_size], how is it valid to pass only [batch_size, hidden_size] to an nn.Linear(hidden_size, num_classes) without flattening the sequence? And why don't we flatten the whole sequence? Wouldn't that destroy order?
✅ Answer #
🔹 BERT Output and the [CLS] Token #
BERT outputs a tensor of shape:
[batch_size, sequence_length, hidden_size]
But for classification tasks, we typically use only the [CLS] token, which is located at position 0 in the sequence:
h_cls = outputs.last_hidden_state[:, 0, :] # Shape: [batch_size, hidden_size]
This token is designed to act as a summary representation of the entire sequence, and its shape matches exactly what an nn.Linear(hidden_size, num_classes) expects, so no flattening is needed.
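For concreteness, here is a minimal sketch of a classification head built this way, assuming the Hugging Face transformers library. The checkpoint name, the number of classes, and the example sentence are placeholders, not part of the original answer:

```python
import torch.nn as nn
from transformers import BertModel, BertTokenizer

MODEL_NAME = "bert-base-uncased"   # illustrative checkpoint
NUM_CLASSES = 3                    # illustrative number of labels

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
bert = BertModel.from_pretrained(MODEL_NAME)
classifier = nn.Linear(bert.config.hidden_size, NUM_CLASSES)

inputs = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
outputs = bert(**inputs)                    # last_hidden_state: [batch, seq_len, hidden]
h_cls = outputs.last_hidden_state[:, 0, :]  # [batch, hidden] -- the [CLS] token
logits = classifier(h_cls)                  # [batch, NUM_CLASSES]
```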
🔹 Why Not Flatten? #
Flattening the whole sequence (e.g., [batch_size, sequence_length * hidden_size]) loses:
- Token order
- Positional embeddings
- Sequence structure
In NLP, this breaks the semantic and syntactic structure of the input; the quick shape sketch below illustrates the contrast. Instead, use one of the pooling strategies in the next section.
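A quick shape sketch of flattening versus the [CLS] slice (the sizes 512 and 768 are the usual BERT-base defaults, used here only for illustration):

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden = 4, 512, 768
last_hidden_state = torch.randn(batch_size, seq_len, hidden)

# Flattening: the classifier sees every position separately, so it is
# tied to one fixed sequence length and needs seq_len * hidden inputs.
flat = last_hidden_state.reshape(batch_size, seq_len * hidden)  # [4, 393216]
flat_head = nn.Linear(seq_len * hidden, 3)

# [CLS] slice: one hidden_size vector per example, independent of sequence length.
h_cls = last_hidden_state[:, 0, :]                              # [4, 768]
cls_head = nn.Linear(hidden, 3)
```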
🔸 Recommended Pooling Strategies #
| Strategy | Description |
|---|---|
| [CLS] Token | Use outputs[:, 0, :]; trained as a sequence summary |
| Mean Pooling | outputs.mean(dim=1); averages token embeddings |
| Max Pooling | outputs.max(dim=1).values; takes strongest signal |
| Attention Pooling | Learns weights to summarize tokens adaptively |
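A sketch of how the first three strategies might look in code, assuming a Hugging Face last_hidden_state and an attention_mask. Masking padding tokens before mean/max pooling is a common refinement not spelled out in the table:

```python
import torch

def cls_pooling(last_hidden_state: torch.Tensor) -> torch.Tensor:
    # [CLS] token: first position of every sequence -> [batch, hidden]
    return last_hidden_state[:, 0, :]

def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average only over real tokens, ignoring padding -> [batch, hidden]
    mask = attention_mask.unsqueeze(-1).float()          # [batch, seq_len, 1]
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def max_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Strongest activation per dimension; padding positions are set to -inf first
    masked = last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, float("-inf"))
    return masked.max(dim=1).values                      # [batch, hidden]
```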
📚 Sources and Justification #
- BERT Paper: Devlin et al. (2018) — [CLS] token for classification
- Sentence-BERT: Reimers & Gurevych (2019) — Mean pooling often better for embeddings
- Hugging Face Transformers: Practical implementation patterns
- NLP Community Practices: Kaggle, blogs, and tutorials
🧪 Summary #
- Use [CLS] or pooling (not flattening) for sequence-level tasks.
- Flattening destroys sequence information and is rarely appropriate in NLP.
- The linear layer works on [batch_size, hidden_size]; there is no need to flatten across tokens.