Core Components and Fusion Strategies in Multimodal LLMs

Core Components of a Multimodal LLM #

  1. Visual Encoder
    Converts the input image into a sequence of feature embeddings. Common choices are ViT-based encoders such as CLIP and EVA.

  2. Modality Adapter (Aligner)
    Projects or transforms the visual features into the language model’s embedding space (e.g., via an MLP or cross-attention).

  3. Language Model (LLM)
    A large pretrained language model (e.g., LLaMA, GPT) that consumes both text and the aligned visual inputs to generate or classify responses. The sketch after this list shows how the three parts fit together.
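
A minimal sketch of the three components wired together, assuming PyTorch; the module names, dimensions, and the externally supplied vision_encoder and llm are illustrative, not a specific library’s API.

# Minimal sketch: visual encoder -> adapter -> LLM (illustrative names and sizes)
import torch
import torch.nn as nn

class MultimodalLM(nn.Module):
    def __init__(self, vision_encoder, llm, vision_dim=1024, llm_dim=4096):
        super().__init__()
        self.vision_encoder = vision_encoder           # 1. Visual encoder (e.g., a ViT)
        self.adapter = nn.Linear(vision_dim, llm_dim)  # 2. Modality adapter
        self.llm = llm                                 # 3. Pretrained language model

    def forward(self, image, text_token_embeds):
        image_feats = self.vision_encoder(image)       # (batch, num_patches, vision_dim)
        visual_tokens = self.adapter(image_feats)      # (batch, num_patches, llm_dim)
        inputs_embeds = torch.cat([visual_tokens, text_token_embeds], dim=1)
        return self.llm(inputs_embeds=inputs_embeds)   # the LLM sees one mixed token sequence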


Fusion Strategies in Multimodal LLMs #

1. Projection + Token Injection #

Models: BLIP-2, LLaVA, MiniGPT-4
How it works:

  • Visual features are extracted with a frozen image encoder (e.g., a ViT or CLIP vision tower).
  • These features are projected into the LLM’s token embedding space, either with a small MLP/linear layer (LLaVA, MiniGPT-4) or after being compressed by a querying transformer such as BLIP-2’s Q-Former.
  • The projected visual tokens are prepended to or interleaved with the text tokens, and the LLM processes them as ordinary input embeddings.
# Hugging Face-style pseudocode: project visual features and inject them as tokens
image_embeds = vision_encoder(image)          # (batch, num_patches, vision_hidden_dim)
projected_embeds = visual_proj(image_embeds)  # MLP/linear projection to the LLM hidden size
inputs_embeds = torch.cat([projected_embeds, text_token_embeds], dim=1)  # visual tokens first
output = llm(inputs_embeds=inputs_embeds)     # the LLM treats them like ordinary token embeddings
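
For a concrete, runnable version of this pattern, the hedged sketch below assumes the Hugging Face transformers library (4.36+) and the public llava-hf/llava-1.5-7b-hf checkpoint; the prompt template and image path are illustrative and checkpoint-specific.

# End-to-end example with transformers (assumed checkpoint: llava-hf/llava-1.5-7b-hf)
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

image = Image.open("example.jpg")  # any local RGB image
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

inputs = processor(text=prompt, images=image, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])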

2. Cross-Attention Adapters #

Models: Flamingo, OpenFlamingo, IDEFICS
How it works:

  • Visual tokens are kept separate from the text tokens rather than being mixed into one sequence.
  • Cross-attention layers are inserted into the (typically frozen) LLM so that text tokens can attend to the visual context; in Flamingo these layers are gated, so the pretrained language model’s behavior is unchanged at initialization.
# Pseudocode with cross-attention (residuals, layer norms, and feed-forward
# sublayers omitted for brevity)
text_embeds = llm.text_embeddings(text_input)  # (batch, text_len, hidden_dim)
visual_context = vision_encoder(image)         # (batch, num_visual_tokens, hidden_dim)

for block in llm.transformer_blocks:
    # Standard (frozen) causal self-attention over the text tokens
    text_embeds = block.self_attn(text_embeds)
    # Inserted adapter: text queries attend to visual keys/values
    text_embeds = block.cross_attn(text_embeds, context=visual_context)
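
As a more concrete (but still illustrative) version, the sketch below implements a single gated cross-attention adapter block in PyTorch. The class name, dimensions, and single-gate design are assumptions that follow the general idea of Flamingo’s gated cross-attention, not any official implementation.

# Illustrative gated cross-attention adapter block (assumed names and sizes)
import torch
import torch.nn as nn

class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        # Gate starts at zero, so the adapter is a no-op when training begins
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_embeds, visual_context):
        # Text tokens are the queries; visual tokens supply keys and values
        attn_out, _ = self.cross_attn(self.norm(text_embeds), visual_context, visual_context)
        return text_embeds + torch.tanh(self.gate) * attn_out

# Usage: one such block would be interleaved with each (frozen) LLM block
block = GatedCrossAttentionBlock(hidden_dim=512, num_heads=8)
text_embeds = torch.randn(2, 16, 512)      # (batch, text_len, hidden_dim)
visual_context = torch.randn(2, 64, 512)   # (batch, num_visual_tokens, hidden_dim)
out = block(text_embeds, visual_context)   # same shape as text_embeds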

3. Joint Pretraining (Early Fusion) #

Models: Unified-IO, GIT, PaLI
How it works:

  • Images are tokenized, either as continuous patch embeddings or as discrete codebook tokens.
  • Image and text tokens are concatenated into a single sequence and passed through one unified transformer, so every layer attends across both modalities.
# Pseudocode for a joint vision-text transformer
image_tokens = patch_embed(image)                     # ViT-style patch embeddings: (batch, num_patches, hidden_dim)
text_ids = tokenizer(text, return_tensors="pt").input_ids
text_tokens = token_embed(text_ids)                   # embed token IDs into the same hidden_dim
all_tokens = torch.cat([image_tokens, text_tokens], dim=1)  # one mixed sequence
output = joint_transformer(all_tokens)                # a single transformer attends over both modalities
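
To make early fusion concrete, here is a self-contained toy version in PyTorch; the class name, vocabulary size, patch size, and all dimensions are made up for illustration, and positional/modality embeddings are omitted for brevity.

# Toy early-fusion transformer (illustrative sizes; not a real model’s architecture)
import torch
import torch.nn as nn

class ToyJointTransformer(nn.Module):
    def __init__(self, vocab_size=32000, hidden_dim=256, patch_dim=3 * 16 * 16, num_layers=4):
        super().__init__()
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)      # flattened 16x16 RGB patches
        self.token_embed = nn.Embedding(vocab_size, hidden_dim)  # text token embeddings
        layer = nn.TransformerEncoderLayer(hidden_dim, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, patches, text_ids):
        image_tokens = self.patch_embed(patches)                 # (batch, num_patches, hidden_dim)
        text_tokens = self.token_embed(text_ids)                 # (batch, text_len, hidden_dim)
        all_tokens = torch.cat([image_tokens, text_tokens], dim=1)
        return self.encoder(all_tokens)                          # joint attention over both modalities

model = ToyJointTransformer()
patches = torch.randn(2, 196, 3 * 16 * 16)    # a 224x224 image split into 196 flattened patches
text_ids = torch.randint(0, 32000, (2, 12))   # a short tokenized caption
fused = model(patches, text_ids)              # (2, 196 + 12, 256)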