Date: 2025-03-31
## ❓ Question

How can we use the `[CLS]` token (i.e., `h_cls`) from the last layer of BERT for classification tasks? Given that the BERT output has shape `[batch_size, sequence_length, hidden_size]`, how is it valid to pass only `[batch_size, hidden_size]` to an `nn.Linear(hidden_size, num_classes)` without flattening the sequence? And why don't we flatten the whole sequence? Wouldn't that destroy the order?
## ✅ Answer
### 🔹 BERT Output and the `[CLS]` Token

BERT outputs a tensor of shape `[batch_size, sequence_length, hidden_size]`. For classification tasks, we typically use only the `[CLS]` token, which sits at position 0 in the sequence:

```python
h_cls = outputs.last_hidden_state[:, 0, :]  # Shape: [batch_size, hidden_size]
```

This token is designed to act as a summary representation of the entire sequence, and its shape matches exactly what `nn.Linear(hidden_size, num_classes)` expects, so no flattening is needed.
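
To make the shapes concrete, here is a minimal sketch of wiring `h_cls` into a linear classifier with the Hugging Face `transformers` API. The model name, `num_classes`, and the example texts are placeholder assumptions, not part of the original question.

```python
# Minimal sketch: [CLS]-based classification head on top of a BERT encoder.
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

model_name = "bert-base-uncased"            # assumption: any BERT-style encoder works
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)

num_classes = 2                             # placeholder label count
classifier = nn.Linear(encoder.config.hidden_size, num_classes)

texts = ["great movie", "terrible plot"]    # placeholder batch
batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = encoder(**batch)              # last_hidden_state: [batch, seq_len, hidden]

h_cls = outputs.last_hidden_state[:, 0, :]  # [batch, hidden] -- the [CLS] token
logits = classifier(h_cls)                  # [batch, num_classes]
```

In training you would feed `logits` into a loss such as `nn.CrossEntropyLoss`; the point here is only that the classifier never sees the `sequence_length` dimension.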
### 🔹 Why Not Flatten?

Flattening the whole sequence (e.g., into `[batch_size, sequence_length * hidden_size]`) loses:

- Token order
- Positional embeddings
- Sequence structure

It also ties the classifier's input dimension to one fixed `sequence_length`, which breaks as soon as inputs have different lengths. In NLP, flattening breaks the semantic and syntactic structure of the input; a quick shape check below illustrates the point. Instead, use one of the pooling strategies in the next section.
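
As a quick sanity check on the shapes (illustrative tensors only, not outputs from a real model):

```python
# Flattening ties the linear layer to one specific sequence length; pooling does not.
import torch

hidden = torch.randn(4, 128, 768)   # [batch, seq_len, hidden], e.g. from BERT
flat = hidden.reshape(4, -1)        # [4, 98304] -- only valid when seq_len == 128
pooled = hidden[:, 0, :]            # [4, 768]   -- works for any seq_len
print(flat.shape, pooled.shape)
```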
### 🔸 Recommended Pooling Strategies

| Strategy | Description |
|---|---|
| `[CLS]` token | Use `outputs[:, 0, :]`; trained as a sequence summary |
| Mean pooling | `outputs.mean(dim=1)`; averages token embeddings |
| Max pooling | `outputs.max(dim=1).values`; takes the strongest signal per dimension |
| Attention pooling | Learns weights to summarize tokens adaptively |
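
For reference, here is a sketch of how the mean, max, and attention pooling rows above are commonly implemented in PyTorch. The attention-mask handling is an assumption based on common practice (padding tokens should not contribute to the pooled vector), and `AttentionPool` is a hypothetical minimal module, not a standard library class.

```python
# Pooling sketches over last_hidden_state: [batch, seq_len, hidden]
# with attention_mask: [batch, seq_len] (1 for real tokens, 0 for padding).
import torch
import torch.nn as nn

def mean_pool(last_hidden_state, attention_mask):
    # Zero out padding positions, then average over the real tokens only.
    mask = attention_mask.unsqueeze(-1).float()        # [batch, seq_len, 1]
    summed = (last_hidden_state * mask).sum(dim=1)     # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1e-9)           # [batch, 1]
    return summed / counts                             # [batch, hidden]

def max_pool(last_hidden_state, attention_mask):
    # Push padding positions to -inf so they never win the max.
    mask = attention_mask.unsqueeze(-1).bool()
    masked = last_hidden_state.masked_fill(~mask, float("-inf"))
    return masked.max(dim=1).values                    # [batch, hidden]

class AttentionPool(nn.Module):
    # Learns a scalar score per token and takes a softmax-weighted sum.
    def __init__(self, hidden_size):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask):
        scores = self.score(last_hidden_state).squeeze(-1)       # [batch, seq_len]
        scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)     # [batch, seq_len, 1]
        return (weights * last_hidden_state).sum(dim=1)          # [batch, hidden]
```

Each function returns a `[batch_size, hidden_size]` tensor, so any of them can feed the same `nn.Linear(hidden_size, num_classes)` head as the `[CLS]` approach.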
### 📚 Sources and Justification

- BERT paper: Devlin et al. (2018); introduces the `[CLS]` token for classification
- Sentence-BERT: Reimers & Gurevych (2019); mean pooling is often better for sentence embeddings
- Hugging Face Transformers: practical implementation patterns
- NLP community practice: Kaggle, blogs, and tutorials
### 🧪 Summary

- Use `[CLS]` or pooling (not flattening) for sequence-level tasks.
- Flattening destroys sequence information and is rarely appropriate in NLP.
- The linear layer operates on `[batch_size, hidden_size]`, so there is no need to flatten across tokens.