
Understanding DeepSeek's Multi-Head Latent Attention (MLA)

Series Overview

This is Part 1 of a three-part series on DeepSeek's FlashMLA, based on my talk at the Toronto Machine Learning Summit 2025.

All code samples from this particular post are available in this GitHub repository: github.com/sshkhr/mla-pytorch

A Brief Recap of Multi-Head Scaled Dot Product Attention

"Attention Is All You Need" (Vaswani et al., 2017) introduced the now-ubiquitous transformer architecture, along with scaled dot-product attention and the multi-head attention mechanism, which is the de facto form of attention one thinks of in transformers.

Scaled Dot-Product Attention and Multi-Head Attention

In scaled dot-product attention (SDPA), given a set of queries Q, keys K, and values V (vectors derived from the input tokens), attention computes a weighted sum of the values, where each query-key pair's weight is proportional to their similarity. A simple way to measure similarity between tokens is the dot product of the query and key vectors. The product is then scaled by the square root of the key dimension (to preserve an appropriate variance in the attention scores). Optionally, a mask is applied, e.g. so that in an autoregressive setting a token's scores are computed only over previous tokens, not future ones. Finally, a softmax produces a probability distribution over the tokens, and these attention weights are used to form a weighted sum of the values as the output:

\text{Scaled Dot-Product Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V

If you prefer code to math, here is a minimal PyTorch implementation of scaled dot-product attention (optimized for reading clarity, to match the diagram above):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention as described in "Attention Is All You Need"
    """
    def __init__(self):
        super().__init__()
    
    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Query tensor of shape (batch_size, seq_len_q, d_k)
            K: Key tensor of shape (batch_size, seq_len_k, d_k) 
            V: Value tensor of shape (batch_size, seq_len_v, d_v)
            mask: Optional mask tensor of shape (batch_size, seq_len_q, seq_len_k)
        
        Returns:
            output: Attention output of shape (batch_size, seq_len_q, d_v)
            attention_weights: Attention weights of shape (batch_size, seq_len_q, seq_len_k)
        """
        batch_size, seq_len_q, d_k = Q.shape
        seq_len_k = K.shape[1]
        
        # Step 1: Compute raw attention scores Q @ K^T
        # (batch_size, seq_len_q, d_k) @ (batch_size, d_k, seq_len_k) -> (batch_size, seq_len_q, seq_len_k)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1))
        
        # Step 2: Scale by sqrt(d_k)
        attention_scores = attention_scores / math.sqrt(d_k)
        
        # Step 3: Apply mask (if provided)
        if mask is not None:
            attention_scores = attention_scores + mask  # mask should contain -inf for masked positions
        
        # Step 4: Apply softmax to get attention weights
        attention_weights = F.softmax(attention_scores, dim=-1)  # (batch_size, seq_len_q, seq_len_k)
        
        # Step 5: Apply attention weights to values
        # (batch_size, seq_len_q, seq_len_k) @ (batch_size, seq_len_k, d_v) -> (batch_size, seq_len_q, d_v)
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

We will see more details on the exact shapes and sizes of these tensors in a bit, as they are fundamental to understanding the bottlenecks in the standard attention mechanism that necessitated MLA. But first, note that scaled dot-product attention is usually not used in isolation. Each layer has multiple attention heads, and the resulting setup is called multi-head attention. The different attention "heads" are simply several scaled dot-product attention modules running in parallel, each operating on a different learned projection of the input. Instead of one single attention computation with a very large d_k, we split the queries, keys, and values into (for example) 8 or 16 smaller sets (heads), perform attention in each subspace, and then concatenate the results. Here is a simple PyTorch implementation of multi-head attention that utilizes the scaled dot-product attention we defined above:

Multi-Head Attention Implementation
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism that applies multiple scaled dot-product attention heads in parallel
    """
    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: Model dimension (embedding dimension)
            n_heads: Number of attention heads
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head (d_k = d_v in standard implementation)

        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model, bias=False)  # Query projection
        self.W_k = nn.Linear(d_model, d_model, bias=False)  # Key projection
        self.W_v = nn.Linear(d_model, d_model, bias=False)  # Value projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)  # Output projection

        # Scaled dot-product attention module
        self.attention = ScaledDotProductAttention()

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional mask tensor of shape (batch_size, seq_len, seq_len)

        Returns:
            output: Multi-head attention output of shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape

        # Step 1: Linear projections for Q, K, V
        # Each projection: (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        Q = self.W_q(x)  # (batch_size, seq_len, d_model)
        K = self.W_k(x)  # (batch_size, seq_len, d_model)
        V = self.W_v(x)  # (batch_size, seq_len, d_model)

        # Step 2: Reshape and split into multiple heads
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads, d_k) -> (batch_size, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Apply scaled dot-product attention to each head
        # Reshape to apply attention: (batch_size * n_heads, seq_len, d_k)
        Q_reshaped = Q.contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
        K_reshaped = K.contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
        V_reshaped = V.contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)

        # Apply mask to all heads if provided
        if mask is not None:
            # Expand mask for all heads: (batch_size, seq_len, seq_len) -> (batch_size * n_heads, seq_len, seq_len)
            mask_expanded = mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1).view(batch_size * self.n_heads, seq_len, seq_len)
        else:
            mask_expanded = None

        # Apply attention: (batch_size * n_heads, seq_len, d_k)
        attention_output, _ = self.attention(Q_reshaped, K_reshaped, V_reshaped, mask_expanded)

        # Step 4: Concatenate heads
        # (batch_size * n_heads, seq_len, d_k) -> (batch_size, n_heads, seq_len, d_k) -> (batch_size, seq_len, d_model)
        attention_output = attention_output.view(batch_size, self.n_heads, seq_len, self.d_k)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # Step 5: Final linear projection
        output = self.W_o(attention_output)  # (batch_size, seq_len, d_model)

        return output

Multi-head attention allows the model to focus on different positions or features in different heads. For example, one head might attend strongly to very nearby tokens while another attends to tokens far back, etc. The output of MHA is the same shape as the input (d_model), so it can be added back to the input (residual connection) and fed into the next layer.
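As a quick sanity check, the module can be used like this (hypothetical sizes; the causal mask construction shown is just one common choice):

# Quick shape check (hypothetical sizes): the output shape matches the input shape,
# so it can be added back to the input via a residual connection.
mha = MultiHeadAttention(d_model=768, n_heads=12)
x = torch.randn(2, 16, 768)  # (batch_size, seq_len, d_model)
causal_mask = torch.triu(torch.full((16, 16), float('-inf')), diagonal=1).expand(2, 16, 16)
out = mha(x, mask=causal_mask)
print(out.shape)  # torch.Size([2, 16, 768])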

Why Attention Scales O(N^2) with Sequence Length

Now that we have the MHA machinery in place, let's examine where the computational bottleneck arises. The culprit is the attention score computation:

S = \frac{QK^T}{\sqrt{d_k}}

Consider the shapes involved (per batch, per head):

  • Q \in \mathbb{R}^{N \times d_h}
  • K \in \mathbb{R}^{N \times d_h}
  • QK^T \in \mathbb{R}^{N \times N}

The matrix multiplication (N \times d_h) \cdot (d_h \times N) requires computing N^2 dot products, each of dimension d_h. That's O(N^2 \cdot d_h) operations per head, per layer. Since d_h is typically fixed (64-256), the dominant term is the quadratic dependence on sequence length.
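As a rough back-of-the-envelope check (a sketch; exact FLOP counts depend on implementation details), counting both the score computation QK^T and the weighted sum with V:

# Rough FLOP count per head, per layer: QK^T is N*N dot products of length d_h
# (~2*d_h FLOPs each), and the weighted sum attn @ V costs about the same again.
def attention_flops(seq_len, d_h):
    return 2 * (2 * seq_len * seq_len * d_h)

print(attention_flops(4_096, 64) / 1e9)   # ~4.3 GFLOPs
print(attention_flops(65_536, 64) / 1e9)  # ~1100 GFLOPs: 16x the length, 256x the work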

Multi-Head Attention Operations
The N \times N attention matrix is the computational bottleneck. For typical head dimensions (d_h = 64-256) and modern context lengths (N = 4K-1M+), N \gg d_h by orders of magnitude.

Let's make this concrete with some numbers:

import torch

def attention_tensor_sizes(batch_size, seq_len, d_model, n_heads, dtype=torch.float16):
    """Calculate tensor sizes for MHA computation."""
    d_h = d_model // n_heads
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    
    # Per-head shapes
    Q_shape = (batch_size, n_heads, seq_len, d_h)
    K_shape = (batch_size, n_heads, seq_len, d_h)
    V_shape = (batch_size, n_heads, seq_len, d_h)
    attn_scores_shape = (batch_size, n_heads, seq_len, seq_len)  # The N×N matrix
    
    # Memory in bytes
    qkv_memory = 3 * batch_size * n_heads * seq_len * d_h * bytes_per_elem
    attn_matrix_memory = batch_size * n_heads * seq_len * seq_len * bytes_per_elem
    
    return {
        "Q/K/V each": Q_shape,
        "Attention scores": attn_scores_shape,
        "QKV memory (MB)": qkv_memory / (1024**2),
        "Attention matrix memory (MB)": attn_matrix_memory / (1024**2),
    }
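
For example, at GPT-2 scale (assuming d_model = 768 and 12 heads) with a 4K context:

# Reproducing one row of the table below (GPT-2-scale sizes are an assumption here)
sizes = attention_tensor_sizes(batch_size=1, seq_len=4096, d_model=768, n_heads=12)
print(sizes["Attention matrix memory (MB)"])  # 384.0 -- the N x N scores alone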

Let's compare what the attention matrix memory looks like as we scale up the sequence length for GPT-2 (d_model = 768, 12 heads) and DeepSeek-V2 (d_model = 5120, 128 heads):

| Sequence Length | GPT-2 Attention | DeepSeek-V2 Attention |
|---|---|---|
| 512 | 6 MB | 64 MB |
| 1024 | 24 MB | 256 MB |
| 2048 | 96 MB | 1.0 GB |
| 4096 | 384 MB | 4.0 GB |
| 8192 | 1.5 GB | 16.0 GB |
| 16384 | 6.0 GB | 64.0 GB |
| 32768 | 24.0 GB | 256.0 GB |
| 65536 | 96.0 GB | 1,024.0 GB |

Attention matrix memory as we scale up the sequence length for GPT-2 and DeepSeek-V2. Red indicates attention matrices that won't fit in an H100 GPU's 80 GB of memory.

The quadratic scaling is brutal. At DeepSeek-V2 scale with 64K context, the attention matrix alone requires 1 TB of memory—per layer, per forward pass. This is clearly untenable.

Attention Memory Scaling

KV Caching: Trading Memory for Compute

Let's take a closer look at the matrix operations during the token generation process at inference time for the autoregressive transformer:

Decoding

Notice anything interesting 👀? Not yet?

How about if we focus on the step-by-step decoding process?

Decoding Step-Wise

Here's a key insight: during autoregressive generation, we decode one token at a time. At step t, we only need to compute attention for the new query q_t against all previous keys and values. The attention scores for previous tokens don't change; we're just appending a new row.

Decoding vs Prefill

During prefill (processing the initial input context), we compute attention for all tokens simultaneously, resulting in the full N \times N attention matrix.

During decoding (generating new tokens), we only compute attention for the new token's query against all previous keys and values, resulting in a single row of the attention matrix.

This observation leads us to KV caching: instead of recomputing K and V for all previously generated tokens at each step during inference, we cache them and only compute the projections for the new token:

Decoding with KV Cache

We cache K and V because the keys and values for previous tokens remain unchanged; only the new token's query Q needs to attend to them.

Here's a PyTorch implementation that makes the caching logic concrete:

class MultiHeadAttentionWithKVCache(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len=4096):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_h = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # Pre-allocated KV cache buffers (no allocation during decode)
        self.register_buffer('k_cache', torch.zeros(1, n_heads, max_seq_len, self.d_h))
        self.register_buffer('v_cache', torch.zeros(1, n_heads, max_seq_len, self.d_h))
        self.cache_position = 0

    def forward(self, x, use_cache=False):
        """
        Args:
            x: Input tensor. During prefill: (B, N, d_model). During decode: (B, 1, d_model)
            use_cache: Whether to use/update the KV cache
        """
        B, seq_len, _ = x.shape

        # Compute Q, K, V for current input
        Q = self.W_q(x).view(B, seq_len, self.n_heads, self.d_h).transpose(1, 2)
        K = self.W_k(x).view(B, seq_len, self.n_heads, self.d_h).transpose(1, 2)
        V = self.W_v(x).view(B, seq_len, self.n_heads, self.d_h).transpose(1, 2)

        if use_cache:
            # Write to pre-allocated buffer (no new allocation)
            start = self.cache_position
            end = start + seq_len
            self.k_cache[:B, :, start:end, :] = K
            self.v_cache[:B, :, start:end, :] = V
            self.cache_position = end

            # Attend over full cached sequence
            K = self.k_cache[:B, :, :end, :]
            V = self.v_cache[:B, :, :end, :]

        # Standard attention computation
        # Q: (B, n_heads, seq_len, d_h) - seq_len=1 during decode
        # K: (B, n_heads, cache_len, d_h)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, seq_len, self.d_model)
        return self.W_o(output)

    def reset_cache(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.cache_position = 0
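
Here is a sketch of how this might be used during generation (illustrative random tensors stand in for real token embeddings): the prompt is processed once as a prefill, then tokens are decoded one at a time against the cache.

# Illustrative usage: prefill once, then decode one token at a time.
attn = MultiHeadAttentionWithKVCache(d_model=768, n_heads=12)
prompt = torch.randn(1, 32, 768)            # (B=1, prompt_len=32, d_model)

out = attn(prompt, use_cache=True)          # prefill: caches K, V for all 32 tokens
next_token = torch.randn(1, 1, 768)         # stand-in for the next token's embedding

for _ in range(4):                          # decode loop
    out = attn(next_token, use_cache=True)  # Q has seq_len=1; K, V come from the cache
    next_token = torch.randn(1, 1, 768)     # (in a real model: embed the sampled token)

print(attn.cache_position)                  # 36 = 32 prompt tokens + 4 decoded tokens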

The complexity tradeoff is stark:

| | Without KV Cache | With KV Cache |
|---|---|---|
| QK^T compute per step | O(N^2) | O(N) |
| Memory per step | O(1) | O(N) |

We've converted a compute-bound problem into a memory-bound one... but how memory-bound? Let's quantify it.

The KV Cache Memory Wall

How much memory does the KV cache actually require? Each decoding step now requires only O(N) operations (the new query attending to N cached keys), but we need to store all those keys and values. Let's work through the math for a small model (GPT-2):

def measure_kv_cache_memory(d_model, n_heads, seq_len, dtype=torch.float16):
    """Calculate KV cache memory in MB."""
    d_h = d_model // n_heads
    bytes_per_elem = torch.tensor([], dtype=dtype).element_size()
    kv_cache_bytes = 2 * n_heads * seq_len * d_h * bytes_per_elem
    return kv_cache_bytes / (1024**2)
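
For instance, at GPT-2 scale (assumed here: d_model = 768, 12 heads) with a 4096-token context, this reproduces the cache size in the benchmark below:

# Reproducing the 4096-token row of the benchmark below (GPT-2-scale sizes assumed)
print(measure_kv_cache_memory(d_model=768, n_heads=12, seq_len=4096))  # 12.0 MB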

| Tokens | With Cache (s) | No Cache (s) | Speedup | Cache (MB) |
|---|---|---|---|---|
| 128 | 0.0293 | 0.0298 | 1.0x | 0.4 |
| 256 | 0.0585 | 0.0599 | 1.0x | 0.8 |
| 512 | 0.1176 | 0.1543 | 1.3x | 1.5 |
| 1024 | 0.2351 | 0.6065 | 2.6x | 3.0 |
| 2048 | 0.4851 | 3.5300 | 7.3x | 6.0 |
| 4096 | 0.9691 | 24.4924 | 25.3x | 12.0 |

On an NVIDIA 5060 Ti, decoding with a context window of 4096 tokens:

  • KV caching provides a 25.3x speedup
  • At the cost of 12.0 MB of cache memory
KV Caching benchmark

At GPT-2 scale (768 dim, 12 heads), 12 MB for 4K tokens seems manageable. But GPT-2 is a relic by modern standards—its 1.5B parameters and 1024-token context window are dwarfed by today's frontier models. The field has moved aggressively toward longer contexts, and for good reason.

Side Note: Why Long Context Matters

The push toward longer context windows isn't just about bragging rights. Several concrete capabilities unlock as context grows:

  • In-context learning improves with more examples. The landmark GPT-3 paper (Brown et al., 2020) demonstrated that few-shot performance scales with the number of demonstrations. More context means more examples, which means better task adaptation without fine-tuning. Recent work on many-shot in-context learning (Agarwal et al., 2024) shows this scaling continues well into hundreds or thousands of examples—if you have the context window to fit them.

  • Complex reasoning requires extended working memory. Chain-of-thought prompting (Wei et al., 2022) and its descendants rely on the model "thinking out loud" across many tokens. For difficult problems—multi-step mathematics, code debugging, legal analysis—the reasoning trace itself can consume tens of thousands of tokens. Models like OpenAI's o1 and DeepSeek-R1 (DeepSeek-AI, 2025) generate extensive internal reasoning chains that would be impossible with short contexts.

  • Agentic workflows are inherently multi-turn. When LLMs operate as agents—browsing the web, writing and executing code, interacting with APIs—the conversation history accumulates rapidly. A single agentic session with tool calls, observations, and corrections can easily span 50K+ tokens. Projects like AutoGPT, OpenDevin, and Claude's computer use all stress-test context limits.

Long context in other modalities

Vision Transformers process images as sequences of patches—a single 1024×1024 image at 16×16 patch size yields 4,096 tokens. Video understanding multiplies this by frame count. Gemini 1.5 demonstrated 1M+ token contexts partly to handle hour-long videos. Robotics applications processing continuous sensor streams face similar challenges.

Now that we know why long-context inference matters, let's consider what KV caching looks like at the scale of a modern frontier model. DeepSeek-V2 (DeepSeek-AI, 2024) uses 60 attention layers with 128 heads per layer and a head dimension of 128, and it supports 128K-token contexts.

For standard multi-head attention, the KV cache per token requires storing the K and V vectors for every head in every layer:

DeepSeek-V2 KV Cache (hypothetical standard MHA) with 60 layers, 128 heads, d_h=128, fp16:

| Context Length | KV Cache Memory |
|---|---|
| 1K tokens | 3.7 GB |
| 4K tokens | 14.6 GB |
| 32K tokens | 117.2 GB |
| 128K tokens | 468.8 GB |
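
These figures follow directly from a per-token cost of 2 · n_h · d_h · l elements in fp16. A quick check (taking 128K as 128,000 tokens, which matches the table above):

# Per-token KV cache for hypothetical standard MHA at DeepSeek-V2 scale
bytes_per_token = 2 * 128 * 128 * 60 * 2  # K and V, 128 heads, d_h=128, 60 layers, fp16
print(bytes_per_token / 2**20)            # ~3.75 MB per token
print(bytes_per_token * 128_000 / 2**30)  # ~468.8 GB at 128K context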
The Memory Wall

At 128K context, the KV cache alone would require ~470 GB—far exceeding even the largest single GPU (H100 has 80GB). This is a fundamental blocker for long-context inference.

The situation compounds when we consider production realities of inference:

  • Batch size > 1: KV cache scales linearly with batch size. Serving 8 concurrent requests at 32K context would require nearly 1 TB just for the cache.
  • Multiple users: Production systems serve thousands of concurrent sessions. Even with sophisticated memory management (paging, offloading), the aggregate memory pressure is immense.
  • Model weights: The parameters themselves still need ~30-60GB in fp16. The KV cache competes for the same limited HBM.

We've traded one impossible problem (quadratic compute) for another (linear but massive memory). The field has developed several approaches to address this memory wall. Each makes different tradeoffs between memory efficiency and model capacity. Let's examine them.

Attention Variants

The memory wall has motivated several architectural modifications that reduce KV cache size by sharing key-value heads across multiple query heads.

MHA Variants
In MHA, each query head has its own K and V. GQA groups multiple query heads to share a single K/V pair. MQA takes this to the extreme—all query heads share one K/V pair.

Multi-Query Attention (MQA) (Shazeer, 2019) was the first major attempt to address KV cache bloat. The idea is simple: instead of each attention head having its own K and V projections, all heads share a single K and a single V. The queries remain independent (hence "multi-query"), but the keys and values are computed once and broadcast to all heads.

Grouped-Query Attention (GQA) (Ainslie et al., 2023) is a middle ground. Instead of all heads sharing one KV pair (MQA) or each head having its own (MHA), GQA divides the query heads into n_g groups, where each group shares a single KV pair. This interpolates between MHA (n_g = n_h) and MQA (n_g = 1).

You can compare the three attention variants using their minimal implementations below:
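Here is a minimal GQA sketch for that comparison (illustrative, reusing the imports from earlier; not the repository code). Setting n_kv_heads = n_heads recovers MHA, and n_kv_heads = 1 recovers MQA; only the smaller K and V projections produced here would need to be cached.

class GroupedQueryAttention(nn.Module):
    """
    Minimal sketch of grouped-query attention (illustrative, not the repository code).
    n_heads query heads share n_kv_heads key/value heads:
    n_kv_heads = n_heads recovers MHA, n_kv_heads = 1 recovers MQA.
    """
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        assert d_model % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.d_h = d_model // n_heads

        self.W_q = nn.Linear(d_model, n_heads * self.d_h, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)  # shared K heads
        self.W_v = nn.Linear(d_model, n_kv_heads * self.d_h, bias=False)  # shared V heads
        self.W_o = nn.Linear(n_heads * self.d_h, d_model, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_h).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_kv_heads, self.d_h).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_kv_heads, self.d_h).transpose(1, 2)

        # Broadcast each K/V head to the group of query heads that shares it
        group_size = self.n_heads // self.n_kv_heads
        K = K.repeat_interleave(group_size, dim=1)  # (B, n_heads, N, d_h)
        V = V.repeat_interleave(group_size, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)
        if mask is not None:
            scores = scores + mask.unsqueeze(1)  # mask: (B, N, N), -inf at masked positions
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(out)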

MHA Variants KV Cache
KV cache memory usage comparison for Multi-Head Attention (MHA), Grouped-Query Attention (GQA), and Multi-Query Attention (MQA), for one layer at DeepSeek-V2 scale (128 heads, head dim 128, fp16).

The memory savings are substantial. We show the savings per layer in the figure above. But the real benefit comes when we scale to full models and long contexts. Over a full 128K token context at DeepSeek-V2 scale, the KV cache sizes are:

| | Multi-Head (MHA) | Grouped-Query (GQA) | Multi-Query (MQA) |
|---|---|---|---|
| KV cache per token | 2 n_h d_h l | 2 n_g d_h l | 2 d_h l |
| KV cache size per token | 4 MB | 244 KB (n_g = 8) | 31 KB |
| KV cache for 128K context | 512 GB | 32 GB | ~4 GB |
| Model quality | High | Medium | Low |

Cache sizes computed for DeepSeek-V2 scale: n_h = 128 heads, d_h = 128 head dimension, l = 60 layers, fp16 precision. GQA assumes 8 KV groups.

where:

  • n_h = number of query heads
  • n_g = number of KV head groups (GQA parameter)
  • d_h = dimension per head
  • l = number of layers

MQA reduces the 128K cache from 512 GB to just 4 GB! But the memory savings come at the cost of reduced model capacity.

The Capacity Tradeoff

The multiple heads in MHA aren't just redundancy. Each head (in theory) learns to attend to different aspects of the input—one head might track syntactic dependencies, another semantic similarity, another positional relationships. When we force all heads to share the same keys and values, we're compressing this representational diversity.

Empirically, MQA degrades model quality. The original MQA paper showed modest perplexity increases, but subsequent work found larger gaps on downstream tasks, particularly for complex reasoning. GQA recovers much of the quality while retaining most of the memory savings, which is why models like Llama 2 70B and Mistral adopted it.

But we're still making a tradeoff: memory efficiency versus representational capacity. What if we could have both?

Multi-Head Latent Attention: Low-Rank Decomposition of KV Matrices

GQA and MQA reduce cache size by sharing KV heads, but they sacrifice model capacity in the process. DeepSeek's Multi-Head Latent Attention takes a fundamentally different approach: instead of sharing heads, it compresses the KV representations into a low-dimensional latent space.

Let's take a look again at standard multi-head attention.

MHA KV Cache

The KV cache stores the projected keys and values for each token:

K = XW_K \quad \text{and} \quad V = XW_V

where X \in \mathbb{R}^{N \times d} is the input sequence and W_K, W_V \in \mathbb{R}^{d \times (n_h \cdot d_h)} are the projection matrices. The cache stores K and V, which have dimension n_h \cdot d_h per token.

But here's the question: do we really need to store the full projected representations?

The KV projections map from d-dimensional inputs to (n_h \cdot d_h)-dimensional outputs. For DeepSeek-V2, that's mapping from 5120 dimensions to 128 \times 128 = 16384 dimensions. What if the "useful" information in those 16K dimensions actually lives in a much smaller subspace?

This idea—that high-dimensional representations can be approximated by low-dimensional ones—is a recurring theme in machine learning:

Low-Rank Structure in Deep Learning

The principle of low-rank factorization appears throughout ML:

  • LoRA (Hu et al., 2021): Fine-tunes large models by learning low-rank updates \Delta W = BA where B \in \mathbb{R}^{d \times r} and A \in \mathbb{R}^{r \times d} with r \ll d
  • Bottleneck layers: ResNet and Inception use 1×1 convolutions to compress channels before expensive 3×3 operations
  • Autoencoders: Compress inputs to a latent space, then reconstruct—the latent captures the "essential" information
  • Matrix factorization: Recommendation systems approximate user-item matrices as products of low-rank factors

The common insight: real-world data often has intrinsic dimensionality much lower than its ambient dimensionality.

For a weight matrix W \in \mathbb{R}^{d_{in} \times d_{out}}, we can approximate it as:

W \approx W_{\text{down}} \cdot W_{\text{up}}

where W_{\text{down}} \in \mathbb{R}^{d_{in} \times d_c} and W_{\text{up}} \in \mathbb{R}^{d_c \times d_{out}}, with compression dimension d_c \ll \min(d_{in}, d_{out}).
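
To get a feel for the savings at DeepSeek-V2-like sizes (d_in = 5120, d_out = n_h · d_h = 16384, d_c = 512, chosen here purely for illustration):

# Parameter count: full-rank projection vs. low-rank factorization (illustrative sizes)
d_in, d_out, d_c = 5120, 16384, 512
full_params = d_in * d_out                  # ~83.9M parameters for a single full-rank W
factored_params = d_in * d_c + d_c * d_out  # ~11.0M parameters for W_down plus W_up
print(full_params / 1e6, factored_params / 1e6)  # 83.88608 11.010048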

Low-rank factorization

The intermediate representation after W_{\text{down}} lives in a compressed d_c-dimensional space. If we're strategic about what we cache, we can store this compressed representation instead of the full projection.

Multi-Head Latent Attention applies this principle directly to the KV projections. Instead of:

K = XW_K, \quad V = XW_V

MLA first projects to a shared low-dimensional latent representation, then expands back:

L_{KV} = XW_{\text{DKV}} \quad \text{(down-projection to latent)}

K = L_{KV}W_{\text{UK}}, \quad V = L_{KV}W_{\text{UV}} \quad \text{(up-projection to full K, V)}

where:

  • W_{\text{DKV}} \in \mathbb{R}^{d \times d_c} projects to the compressed latent space
  • W_{\text{UK}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)} expands the latent to keys
  • W_{\text{UV}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)} expands the latent to values
  • d_c is the latent/compression dimension (e.g., 512 in DeepSeek-V2)

The crucial insight: we only need to cache L_{KV}, not the full K and V. The up-projections W_{\text{UK}} and W_{\text{UV}} are static weights, so we can apply them on the fly during attention computation.

MLA

The cache size per token drops from 2 \cdot n_h \cdot d_h (for K and V) to just d_c. For DeepSeek-V2 with n_h = 128, d_h = 128, and d_c = 512:

\text{Compression ratio} = \frac{2 \times 128 \times 128}{512} = \frac{32768}{512} = 64

For DeepSeek-V2's parameters, this is a 64X reduction in KV cache size—comparable to MQA's savings, but without sacrificing the representational capacity of having many independent heads.

Here is a minimal sketch of this naive version of MLA, caching only the d_c-dimensional latent per token (queries keep a standard projection for now):
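
class MultiHeadLatentAttentionNaive(nn.Module):
    """
    Illustrative sketch (not the reference implementation) of naive MLA:
    only the d_c-dimensional KV latent L_kv is cached; K and V are re-expanded
    on the fly. Queries use a plain projection W_q. Reuses the imports from above.
    """
    def __init__(self, d_model, n_heads, d_c, max_seq_len=4096):
        super().__init__()
        self.n_heads = n_heads
        self.d_h = d_model // n_heads
        self.d_c = d_c

        self.W_q = nn.Linear(d_model, n_heads * self.d_h, bias=False)
        self.W_dkv = nn.Linear(d_model, d_c, bias=False)            # down-projection to latent
        self.W_uk = nn.Linear(d_c, n_heads * self.d_h, bias=False)  # latent -> keys
        self.W_uv = nn.Linear(d_c, n_heads * self.d_h, bias=False)  # latent -> values
        self.W_o = nn.Linear(n_heads * self.d_h, d_model, bias=False)

        # The cache holds only the d_c-dimensional latent per token
        self.register_buffer('latent_cache', torch.zeros(1, max_seq_len, d_c))
        self.cache_position = 0

    def forward(self, x, use_cache=False):
        B, N, _ = x.shape
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_h).transpose(1, 2)

        L_kv = self.W_dkv(x)  # (B, N, d_c)
        if use_cache:
            start, end = self.cache_position, self.cache_position + N
            self.latent_cache[:B, start:end, :] = L_kv
            self.cache_position = end
            L_kv = self.latent_cache[:B, :end, :]

        # Expand the latent to full K, V on the fly (these are never cached)
        K = self.W_uk(L_kv).view(B, -1, self.n_heads, self.d_h).transpose(1, 2)
        V = self.W_uv(L_kv).view(B, -1, self.n_heads, self.d_h).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(out)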

Query Compression and RoPE in MLA

The implementation above uses a standard query projection (W_q) for simplicity. However, DeepSeek-V2/V3 also compress the queries via a low-rank decomposition, mirroring the KV compression:

L_Q = \text{RMSNorm}(XW_{\text{DQ}}), \quad Q = L_Q W_{\text{UQ}}

where W_{\text{DQ}} \in \mathbb{R}^{d \times d_{cq}} down-projects to a query latent, RMSNorm normalizes it, and W_{\text{UQ}} \in \mathbb{R}^{d_{cq} \times (n_h \cdot d_h)} up-projects back to the full query dimension. DeepSeek-V2 uses d_{cq} = 1536.

Unlike the KV latent, the query latent is not cached; it's only needed for the current token. The purpose is parameter efficiency: the two smaller matrices W_{\text{DQ}} and W_{\text{UQ}} have fewer total parameters than a single full-rank W_Q.

The RoPE embedding is also handled separately to preserve the positional information, since a standard low-rank projection would mix positional and content information.

This is the full architecture of MLA with query compression and RoPE (taken from the DeepSeek-V3 paper):

Full MLA Architecture

Here is the full MLA implementation with query compression. We will skip the RoPE details for now, as they are not important for understanding the core idea of latent KV compression:

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class MultiHeadLatentAttention(nn.Module):
    """
    MLA with both query and KV compression.

    Query: x -> W_dq -> RMSNorm -> W_uq -> Q  (not cached)
    KV:    x -> W_dkv -> L_kv (cached) -> W_uk/W_uv -> K, V
    """
    def __init__(self, d_model, n_heads, d_c, d_cq, max_seq_len=4096):
        super().__init__()
        self.n_heads = n_heads
        self.d_h = d_model // n_heads
        self.d_c = d_c
        self.d_cq = d_cq

        # Query compression: down-project, normalize, up-project
        self.W_dq = nn.Linear(d_model, d_cq, bias=False)
        self.q_norm = RMSNorm(d_cq)
        self.W_uq = nn.Linear(d_cq, n_heads * self.d_h, bias=False)

        # KV compression
        self.W_dkv = nn.Linear(d_model, d_c, bias=False)
        self.W_uk = nn.Linear(d_c, n_heads * self.d_h, bias=False)
        self.W_uv = nn.Linear(d_c, n_heads * self.d_h, bias=False)

        self.W_o = nn.Linear(n_heads * self.d_h, d_model, bias=False)

        # Pre-allocated latent cache
        self.register_buffer('latent_cache', torch.zeros(1, max_seq_len, d_c))
        self.cache_position = 0

    def forward(self, x, use_cache=False):
        B, N, _ = x.shape

        # Query: down-project -> normalize -> up-project
        L_q = self.q_norm(self.W_dq(x))
        Q = self.W_uq(L_q).view(B, N, self.n_heads, self.d_h).transpose(1, 2)

        # KV: compress to latent
        L_kv = self.W_dkv(x)

        if use_cache:
            start = self.cache_position
            end = start + N
            self.latent_cache[:B, start:end, :] = L_kv
            self.cache_position = end
            L_kv = self.latent_cache[:B, :end, :]

        # Expand latent to full K, V (on-the-fly, not cached)
        K = self.W_uk(L_kv).view(B, -1, self.n_heads, self.d_h).transpose(1, 2)
        V = self.W_uv(L_kv).view(B, -1, self.n_heads, self.d_h).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_h)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, N, -1)
        return self.W_o(out)

    def cache_size_per_token(self):
        return self.d_c  # Query compression doesn't affect cache size

Note that query compression does not affect the cache size; only the KV latent is cached. With the absorption trick (below), the absorbed query projection becomes W_Q' = W_{\text{UQ}} \cdot W_{\text{UK}}^T, applied after the query compression and normalization.

The Absorption Trick: Making MLA Hardware-Friendly

The naive MLA implementation above has a problem: during decoding, we need to expand the entire cached latent sequence through W_{\text{UK}} and W_{\text{UV}} at every step. For a 100K-token context, that's a lot of extra compute.

But there's a clever mathematical trick that eliminates this overhead. Let's trace through the attention computation:

\text{Attention scores} = QK^T = (XW_Q)(L_{KV}W_{\text{UK}})^T = XW_Q W_{\text{UK}}^T L_{KV}^T

We can absorb W_{\text{UK}}^T into the query projection by defining:

W_Q' = W_Q W_{\text{UK}}^T

Now the attention scores become:

QK^T = (XW_Q')L_{KV}^T

The queries are projected directly into the latent space, and we compute attention against the cached latents without ever materializing the full keys!

Similarly, for the value side, instead of computing \text{softmax}(QK^T) \cdot V and then applying W_O, we can absorb W_{\text{UV}} into the output projection:

W_O' = W_{\text{UV}} W_O

The full absorbed computation becomes:

\text{Output} = \text{softmax}\left(\frac{(XW_Q')L_{KV}^T}{\sqrt{d}}\right) \cdot L_{KV} \cdot W_O'

MLA with absorption

This is remarkable: the attention operation itself is unchanged. It's still QK^T followed by softmax and multiplication with the values; we've just redefined what Q, K, and V mean. The latent L_{KV} plays the role of both keys and values!

The absorbed version has a crucial advantage: attention is computed entirely in the latent space. We never materialize the full n_h \cdot d_h-dimensional keys and values. This means:

  1. Memory: The cache remains at d_c per token (same as naive MLA)
  2. Compute: No per-step expansion of cached latents through W_{\text{UK}}, W_{\text{UV}}
  3. Hardware compatibility: The attention pattern (QK^T \rightarrow \text{softmax} \rightarrow \times V) is identical to standard attention

That last point is critical. Because the absorbed MLA has the same computational structure as standard attention, we can leverage highly optimized attention kernels like FlashAttention. The latent simply plays the role of a compressed key/value representation.
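To make the algebra concrete, here is a small numerical check (a sketch reusing the MultiHeadLatentAttention class and imports from above; helper names are illustrative, not from any DeepSeek code). It folds the up-projections into per-head absorbed weights, computes attention entirely in the latent space, and matches the naive forward pass:

# Numerical check of the absorption trick: per head i, fold W_UQ_i @ W_UK_i^T into a
# query map into the latent space, and W_UV_i @ W_O_i into a map from the latent
# space back to d_model, then compare against the naive forward pass.
@torch.no_grad()
def mla_forward_absorbed(mla, x):
    B, N, _ = x.shape
    H, d_h = mla.n_heads, mla.d_h

    L_q = mla.q_norm(mla.W_dq(x))   # (B, N, d_cq) query latent (not cached)
    L_kv = mla.W_dkv(x)             # (B, N, d_c)  KV latent (the only thing cached)

    # nn.Linear stores weight as (out_features, in_features); recover per-head matrices
    W_uq = mla.W_uq.weight.T.view(-1, H, d_h)  # (d_cq, H, d_h)
    W_uk = mla.W_uk.weight.T.view(-1, H, d_h)  # (d_c,  H, d_h)
    W_uv = mla.W_uv.weight.T.view(-1, H, d_h)  # (d_c,  H, d_h)
    W_o = mla.W_o.weight.T.view(H, d_h, -1)    # (H, d_h, d_model)

    # Absorbed weights: W_UQ_i @ W_UK_i^T and W_UV_i @ W_O_i for each head i
    absorbed_q = torch.einsum('qhd,chd->hqc', W_uq, W_uk)  # (H, d_cq, d_c)
    absorbed_o = torch.einsum('chd,hdm->hcm', W_uv, W_o)   # (H, d_c, d_model)

    # Attention computed entirely in the latent space
    q_lat = torch.einsum('bnq,hqc->bhnc', L_q, absorbed_q)     # queries mapped into latent space
    scores = torch.einsum('bhnc,bmc->bhnm', q_lat, L_kv)       # q' @ L_kv^T
    attn = torch.softmax(scores / d_h ** 0.5, dim=-1)          # same sqrt(d_h) scaling as before
    out_lat = torch.einsum('bhnm,bmc->bhnc', attn, L_kv)       # weighted sum of latents
    return torch.einsum('bhnc,hcm->bnm', out_lat, absorbed_o)  # project back to d_model

mla = MultiHeadLatentAttention(d_model=256, n_heads=8, d_c=64, d_cq=96)
x = torch.randn(2, 10, 256)
print(torch.allclose(mla(x), mla_forward_absorbed(mla, x), atol=1e-4))  # True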

Looking Ahead: FlashMLA

In the next post, we'll explore how FlashAttention's memory-efficient tiling strategy applies directly to MLA's latent attention. The absorption trick isn't just mathematically elegant—it's the key to achieving both memory efficiency and computational efficiency on modern GPUs.

Summary

Let's compare the cache sizes across all the approaches we've covered:

| | MHA | GQA (n_g = 8) | MQA | MLA (d_c = 512) |
|---|---|---|---|---|
| KV cache per token | 2 n_h d_h l | 2 n_g d_h l | 2 d_h l | d_c l |
| Cache size per token (DeepSeek-V2) | 3.9 MB | 244 KB | ~31 KB | 60 KB |
| 128K context | 512 GB | 32 GB | ~4 GB | 8 GB |
| Relative to MHA | 1× | 16× | 128× | 64× |
| Model quality | Baseline | Slight degradation | Significant degradation | Matches baseline |

DeepSeek-V2: n_h = 128 heads, d_h = 128 head dim, l = 60 layers, d_c = 512 latent dim, fp16.

MLA achieves compression comparable to MQA while maintaining the representational capacity of full MHA. The key insight is that compression happens in a learned latent space, not by arbitrarily sharing heads.

The DeepSeek team reports that MLA matches or exceeds MHA performance on their benchmarks while reducing KV cache by ~57X. This is the breakthrough that makes 128K+ context windows practical on commodity hardware.

References


Citation

If you found this post helpful, please consider citing it:

@article{shekhar2026mla,
  title   = {Understanding DeepSeek's Multi-Head Latent Attention (MLA)},
  author  = {Shekhar, Shashank},
  journal = {shashankshekhar.com},
  year    = {2026},
  month   = {February},
  url     = {https://shashankshekhar.com/blog/flashmla/flashmla-1-mla}
}