Understanding How to Use BERT's CLS Token for Classification

Date: 2025-03-31

❓ Question

How can we use the [CLS] token (i.e., h_cls) from the last layer of BERT for classification tasks? Given that the BERT output has shape [batch_size, sequence_length, hidden_size], how is it valid to pass only [batch_size, hidden_size] to an nn.Linear(hidden_size, num_classes) without flattening the sequence? And why don't we flatten the whole sequence? Wouldn't that destroy order?


✅ Answer

🔹 BERT Output and the [CLS] Token

BERT outputs a tensor of shape:

[batch_size, sequence_length, hidden_size]

But for classification tasks, we typically use only the [CLS] token, which is located at position 0 in the sequence:

h_cls = outputs.last_hidden_state[:, 0, :]  # Shape: [batch_size, hidden_size]

This token is trained to act as a summary representation of the entire sequence, and this output shape matches exactly what an nn.Linear(hidden_size, num_classes) expects, so no flattening is needed.
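For context, here is a minimal end-to-end sketch using the Hugging Face transformers library. The checkpoint name (bert-base-uncased), the example sentences, and the choice of 2 classes are illustrative assumptions, not part of the original question:

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
bert = AutoModel.from_pretrained("bert-base-uncased")
classifier = nn.Linear(bert.config.hidden_size, 2)  # 2 classes, for illustration

inputs = tokenizer(["a great movie", "a dull movie"], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = bert(**inputs)

# last_hidden_state: [batch_size, sequence_length, hidden_size]
h_cls = outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_size]
logits = classifier(h_cls)                  # [batch_size, num_classes]

During fine-tuning (without torch.no_grad()), the classification loss backpropagates through h_cls into all of BERT's layers, which is what trains [CLS] to behave as a sequence summary.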


🔹 Why Not Flatten?

Strictly speaking, flattening the whole sequence into [batch_size, sequence_length * hidden_size] does not destroy token order (each position keeps its own slice of the vector), but feeding that vector to a linear layer is still a poor choice because it:

  • Hard-codes a fixed sequence_length into the classifier's input dimension
  • Multiplies the head's parameter count by sequence_length (see the comparison after this list)
  • Ties every weight to an absolute token position and mixes padding embeddings into the decision
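For concreteness, a quick comparison of head sizes; the numbers assume BERT-base defaults (hidden_size = 768, maximum sequence length 512) and 2 classes, purely for illustration:

import torch.nn as nn

hidden_size, seq_len, num_classes = 768, 512, 2
cls_head = nn.Linear(hidden_size, num_classes)
flat_head = nn.Linear(seq_len * hidden_size, num_classes)

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(cls_head))   # 1,538 parameters
print(n_params(flat_head))  # 786,434 parameters: 512x larger, locked to seq_len = 512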

A classifier locked to absolute positions generalizes poorly in NLP, where the same phrase can appear anywhere in the input. Instead, use one of these strategies (sketched in code below):

  • [CLS] token: outputs.last_hidden_state[:, 0, :]; trained as a sequence-level summary
  • Mean pooling: last_hidden_state.mean(dim=1); averages token embeddings (mask out padding in practice)
  • Max pooling: last_hidden_state.max(dim=1).values; keeps the strongest signal per dimension
  • Attention pooling: learns weights to summarize tokens adaptively
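The snippets below sketch these strategies against a generic hidden-states tensor; the mask-aware mean and the small attention-pooling module are common community patterns, not a specific library API:

import torch
import torch.nn as nn

def mean_pool(hidden, mask):
    # hidden: [batch, seq_len, hidden_size]; mask: [batch, seq_len], 1 for real tokens
    mask = mask.unsqueeze(-1).float()
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

def max_pool(hidden, mask):
    # Exclude padding positions before taking the per-dimension max
    hidden = hidden.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
    return hidden.max(dim=1).values

class AttentionPool(nn.Module):
    # Learns one score per token, then returns the softmax-weighted sum
    def __init__(self, hidden_size):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1)

    def forward(self, hidden, mask):
        scores = self.score(hidden).squeeze(-1)            # [batch, seq_len]
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=1)
        return (hidden * weights.unsqueeze(-1)).sum(dim=1)

Each variant returns [batch_size, hidden_size], so the same nn.Linear(hidden_size, num_classes) head applies unchanged.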

📚 Sources and Justification

  • BERT paper: Devlin et al. (2018), which introduces the [CLS] token for classification
  • Sentence-BERT: Reimers & Gurevych (2019), which finds mean pooling often better for sentence embeddings
  • Hugging Face Transformers: practical implementation patterns
  • NLP community practice: Kaggle notebooks, blog posts, and tutorials

🧪 Summary

  • Use [CLS] or pooling (not flattening) for sequence-level tasks.
  • Flattening locks the classifier to a fixed sequence length and position-specific weights, and is rarely appropriate in NLP.
  • The linear layer operates on [batch_size, hidden_size], so there is no need to flatten across tokens.