Stop using torch.cat for your KV cache implementations
By Shashank Shekhar (@sshkhr16)
For a primer on the attention mechanism, transformer model, and KV cache optimizations, see my MLA blog on DeepSeek's Multi-Head Latent Attention.
KV caching is fundamental to efficient LLM inference. If you've implemented autoregressive decoding in PyTorch, you might have written something like this:
```python
if self.k_cache is None:
    self.k_cache = K
    self.v_cache = V
else:
    # dim=2 is the sequence axis
    self.k_cache = torch.cat([self.k_cache, K], dim=2)
    self.v_cache = torch.cat([self.v_cache, V], dim=2)
```
This pattern makes logical sense. At each decode step, concatenate the new key-value pair onto the existing cache.
But it goes against the spirit of caching.
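For reference, here is that pattern as a small self-contained class so it can be run in isolation. The class name CatKVCache and the (batch, n_heads, seq_len, head_dim) layout are assumptions for illustration. The layout is chosen so that dim=2 is the sequence axis, as in the snippet above.

```python
import torch

class CatKVCache:
    """Hypothetical minimal KV cache that grows with torch.cat.

    Assumed layout: (batch, n_heads, seq_len, head_dim), so dim=2 is the sequence axis.
    """

    def __init__(self):
        self.k_cache = None
        self.v_cache = None

    def update(self, K, V):
        if self.k_cache is None:
            # First call (e.g. the prompt): the cache is just the incoming K/V
            self.k_cache = K
            self.v_cache = V
        else:
            # Every later call allocates a new tensor and copies the old cache into it
            self.k_cache = torch.cat([self.k_cache, K], dim=2)
            self.v_cache = torch.cat([self.v_cache, V], dim=2)
        return self.k_cache, self.v_cache


cache = CatKVCache()
B, H, D = 1, 4, 8
k, v = cache.update(torch.randn(B, H, 5, D), torch.randn(B, H, 5, D))  # 5-token prompt
for _ in range(3):  # three decode steps, one token each
    k, v = cache.update(torch.randn(B, H, 1, D), torch.randn(B, H, 1, D))
print(k.shape)  # torch.Size([1, 4, 8, 8])
```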
Everyone uses this pattern
Search "KV cache PyTorch implementation" and you'll find torch.cat everywhere. Sebastian Raschka's tutorial and HuggingFace's KV caching blog post both use it, and they're far from alone. I found the same pattern in production codebases with over 100K stars on GitHub.
At this point, this usage pattern is so widespread that it has probably made its way into the training data for every coding model on the planet. It's logically correct. But it's also doing the opposite of what KV caching exists to do.
torch.cat does the opposite of caching
The whole point of a KV cache is to compute each token's key and value projections once and reuse them for the rest of decoding. Pay the cost once, reuse forever. That's what caching means.
But torch.cat re-copies every cached element at every single step. You're not caching; you're maintaining a copy log that gets rewritten from scratch each time. This isn't a new observation, either. There have been a number of posts on the PyTorch forums about this over the years, and the consensus has always been the same: torch.cat is not in-place; it allocates new memory and copies data. If you want to avoid that, one option is to pre-allocate a buffer and write into it directly.
The thread "Concatenate tensors without memory copying" (Sep 2019) asked about in-place concatenation. A PyTorch developer's response was essentially: you can't, but you can pre-allocate a buffer and write into it.
... Another solution is to pre-allocate the full tensor and compute t1 and t2 directly into it doing inplace operations. That way you don’t need the cat operation at all.
A year later (Sep 2020), the thread "Torch.cat() massive bottleneck?!" raised this again, and the legendary ptrblck responded:
If you are using torch.cat in a loop, you should instead ... try to preallocate the output tensor and fill it.
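The advice in both threads can be sketched in a few lines. One approach builds a tensor from a list of chunks by calling torch.cat inside a loop. The other allocates the final size once and fills slices in place. The function names below are hypothetical, and the sketch assumes all chunks match in every dimension except the first.

```python
import torch

def build_with_cat(chunks):
    # Each iteration allocates a new tensor and copies everything accumulated so far
    out = chunks[0]
    for c in chunks[1:]:
        out = torch.cat([out, c], dim=0)
    return out

def build_preallocated(chunks):
    # Allocate the final size once, then write each chunk into its slice in place
    total = sum(c.size(0) for c in chunks)
    out = torch.empty((total, *chunks[0].shape[1:]),
                      dtype=chunks[0].dtype, device=chunks[0].device)
    pos = 0
    for c in chunks:
        out[pos:pos + c.size(0)] = c
        pos += c.size(0)
    return out

chunks = [torch.randn(1, 16) for _ in range(100)]
assert torch.equal(build_with_cat(chunks), build_preallocated(chunks))
```

Both produce the same tensor. The difference is that the first version touches the accumulated data again on every iteration.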
Here's some pseudocode demonstrating what torch.cat is actually doing:
```python
import torch

def torch_cat_equivalent(existing_cache, new_tensor, dim):
    old_len = existing_cache.size(dim)
    new_len = new_tensor.size(dim)
    # Step 1: Allocate NEW memory for the combined size
    new_shape = list(existing_cache.shape)
    new_shape[dim] = old_len + new_len
    new_cache = torch.empty(new_shape, dtype=existing_cache.dtype,
                            device=existing_cache.device)  # O(N) allocation
    # Step 2: Copy ALL existing data to the new location
    new_cache.narrow(dim, 0, old_len).copy_(existing_cache)  # O(N) copy!
    # Step 3: Copy the new data
    new_cache.narrow(dim, old_len, new_len).copy_(new_tensor)  # O(1) copy
    # The old cache is still alive here; it is freed only after the caller drops it
    return new_cache
```
We've eliminated redundant computation but introduced redundant memory operations that grow quadratically, just like the recomputation we set out to avoid. Every step copies everything that came before. That's the opposite of caching. At step $t$, you copy $t-1$ elements just to add 1 new one. The cumulative cost over $T$ decode steps:

$$\sum_{t=1}^{T} (t-1) = \frac{T(T-1)}{2} = O(T^2)$$

This is in addition to the unavoidable cost of the attention computation itself. We've introduced a second quadratic term, purely from memory allocation and copying.
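To make the arithmetic concrete, the sketch below counts how many token positions each strategy writes over T single-token decode steps, ignoring the prompt. It is a counting model, not a measurement. It assumes torch.cat copies the full old cache plus the new token on every step, and that the pre-allocated buffer writes only the new token.

```python
def cat_copies(T):
    # Step t copies the t - 1 cached tokens plus the 1 new token
    return sum((t - 1) + 1 for t in range(1, T + 1))

def redundant_cat_copies(T):
    # Only the re-copied old tokens: sum of (t - 1) = T(T - 1) / 2
    return sum(t - 1 for t in range(1, T + 1))

def preallocated_copies(T):
    # Step t writes exactly 1 new token into its slot
    return T

for T in (1_000, 10_000):
    print(T, cat_copies(T), redundant_cat_copies(T), preallocated_copies(T))
# 1000 500500 499500 1000
# 10000 50005000 49995000 10000
```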
And it gets worse. Memory allocation is not free. Each new buffer goes through PyTorch's caching allocator, and when no suitable cached block exists, the allocator falls back to cudaMalloc, which carries driver overhead. Repeatedly allocating slightly larger buffers fragments the memory pool, potentially forcing the allocator to release cached blocks and allocate fresh ones. The cache also moves to a new address at every step, defeating hardware caching and memory prefetching. Every aspect of this pattern fights the GPU's memory subsystem.
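One of these effects is easy to observe directly: the address of the cache. The sketch below records Tensor.data_ptr() after each append, using arbitrary sizes on CPU. Each torch.cat result lands in storage distinct from the cache it replaces, because the old cache is still alive while the new one is allocated. Slice assignment into a pre-allocated buffer keeps the address fixed. The allocator may eventually reuse a block freed several steps earlier, so only the step-to-step comparison is meaningful.

```python
import torch

B, H, D, STEPS, MAX_LEN = 1, 4, 8, 5, 16

# Growing cache via torch.cat
k_cat = torch.randn(B, H, 1, D)
cat_ptrs = [k_cat.data_ptr()]
for _ in range(STEPS):
    k_cat = torch.cat([k_cat, torch.randn(B, H, 1, D)], dim=2)
    cat_ptrs.append(k_cat.data_ptr())

# Pre-allocated buffer written in place
k_buf = torch.zeros(B, H, MAX_LEN, D)
buf_ptrs = []
for pos in range(STEPS + 1):
    k_buf[:, :, pos:pos + 1] = torch.randn(B, H, 1, D)
    buf_ptrs.append(k_buf.data_ptr())

# Each torch.cat result lives in storage distinct from the cache it replaced
print(all(a != b for a, b in zip(cat_ptrs, cat_ptrs[1:])))  # True
# The pre-allocated buffer never moves
print(len(set(buf_ptrs)) == 1)  # True
```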
Every torch.cat doesn't just add one element to your cache tensor. It copies the entire existing cache to a new memory location. At decode step 1000, you're copying 999 elements just to add 1.
A naive solution: pre-allocated buffers
A simple solution (based on the PyTorch dev's comments): allocate the full buffer upfront and write to it in place. This lends itself well to transformer decoding, since the max_seq_len is typically set in advance.
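```python
# Instead of this (quadratic):
self.k_cache = torch.cat([self.k_cache, K], dim=2)
self.v_cache = torch.cat([self.v_cache, V], dim=2)

# Do this (linear):
seq_len = K.size(2)  # number of tokens added this step
self.k_cache[:, :, self.position:self.position + seq_len] = K
self.v_cache[:, :, self.position:self.position + seq_len] = V
self.position += seq_len
```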
By just writing the new KV pair to the next slot, each step performs only $O(1)$ memory operations (one token's worth of keys and values). Over $T$ decode steps, the cumulative cost drops from $O(T^2)$ to $O(T)$.
Trade-off: you must know max_seq_len upfront. In practice, this is rarely a problem as we typically have a known context limit (4K, 8K, 128K tokens), and over-allocating by a small margin is cheap* (see notes) compared to the quadratic penalty.
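The trade-off shows up directly in code. Below is a minimal sketch of a pre-allocated cache for one attention layer. The class name PreallocKVCache, its reset method, and the choice to raise ValueError on overflow are assumptions for illustration; this is not a library API. It assumes the same (batch, n_heads, seq_len, head_dim) layout as the snippet above.

```python
import torch

class PreallocKVCache:
    """Hypothetical pre-allocated KV cache for one attention layer.

    Assumed layout: (batch, n_heads, max_seq_len, head_dim).
    """

    def __init__(self, batch, n_heads, max_seq_len, head_dim,
                 dtype=torch.float32, device="cpu"):
        shape = (batch, n_heads, max_seq_len, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.max_seq_len = max_seq_len
        self.position = 0  # number of valid tokens written so far

    def update(self, K, V):
        seq_len = K.size(2)
        end = self.position + seq_len
        if end > self.max_seq_len:
            raise ValueError(f"cache full: {end} > max_seq_len={self.max_seq_len}")
        # In-place write into the next free slots; nothing already cached is copied
        self.k_cache[:, :, self.position:end] = K
        self.v_cache[:, :, self.position:end] = V
        self.position = end
        # Return views over the valid prefix (basic slicing does not copy)
        return self.k_cache[:, :, :end], self.v_cache[:, :, :end]

    def reset(self):
        # Reuse the same buffers for the next sequence
        self.position = 0


cache = PreallocKVCache(batch=1, n_heads=4, max_seq_len=16, head_dim=8)
k, v = cache.update(torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8))  # prompt
k, v = cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8))  # one decode step
print(k.shape)  # torch.Size([1, 4, 6, 8])
```

Returning a view of only the valid prefix matters. The tail of the buffer holds zeros or stale data from a previous sequence, so attention must read only the first `position` slots (or mask the rest).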
I benchmarked both implementations on a single attention layer with d_model=1024, n_heads=16, across a range of context lengths. Full benchmark code is available as a Google Colab notebook.
| Context Length | torch.cat (ms) | Pre-allocated (ms) | Speedup |
|---|---|---|---|
| 128 | 4.79 | 4.52 | 1.06x |
| 256 | 6.78 | 5.98 | 1.13x |
| 512 | 11.31 | 8.90 | 1.27x |
| 1024 | 23.32 | 15.25 | 1.53x |
| 2048 | 56.16 | 30.15 | 1.86x |
| 4096 | 158.10 | 71.42 | 2.21x |
| 8192 | 537.23 | 200.80 | 2.68x |
| 16384 | 2013.42 | 661.04 | 3.05x |
| 32768 | 7768.38 | 2389.15 | 3.25x |
| 65536 | 30799.31 | 9376.71 | 3.28x |

Below ~1K tokens, the difference is in the noise. Between 1K and 8K, the gap becomes real: a roughly 1.5-2.7x speedup translates to meaningful latency reduction. Above 8K tokens, the quadratic term dominates. At 64K context, we see over 3x speedup, and that's just one layer. A 32-layer model would pay this penalty in every one of its layers.
Many completions are short enough that the quadratic cost stays in the noise. But long-context use cases (document summarization, RAG, agentic workflows) accumulate substantial context, and that's where allocation patterns matter.
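The author's benchmark lives in the linked Colab notebook and is not reproduced here. For readers who want to try a comparison locally, a minimal timing harness might look like the sketch below. Everything in it is an assumption:

- a single attention layer with d_model=1024 and n_heads=16, as in the text
- batch size 1, random weights, no prompt
- torch.nn.functional.scaled_dot_product_attention for the attention itself (PyTorch 2.0+)

Its numbers will not match the table above.

```python
import time
import torch
import torch.nn.functional as F

D_MODEL, N_HEADS = 1024, 16
HEAD_DIM = D_MODEL // N_HEADS

def decode(steps, use_cat, device="cpu"):
    # Hypothetical single-layer decode loop: project one token, update the cache, attend
    torch.manual_seed(0)
    qkv = torch.nn.Linear(D_MODEL, 3 * D_MODEL, device=device)
    x = torch.randn(1, 1, D_MODEL, device=device)
    k_cache = v_cache = None
    if not use_cat:
        k_cache = torch.empty(1, N_HEADS, steps, HEAD_DIM, device=device)
        v_cache = torch.empty_like(k_cache)
    with torch.no_grad():
        for pos in range(steps):
            q, k, v = qkv(x).split(D_MODEL, dim=-1)
            # (1, 1, D_MODEL) -> (1, N_HEADS, 1, HEAD_DIM)
            q, k, v = (t.view(1, 1, N_HEADS, HEAD_DIM).transpose(1, 2) for t in (q, k, v))
            if use_cat:
                k_cache = k if k_cache is None else torch.cat([k_cache, k], dim=2)
                v_cache = v if v_cache is None else torch.cat([v_cache, v], dim=2)
                keys, values = k_cache, v_cache
            else:
                k_cache[:, :, pos:pos + 1] = k
                v_cache[:, :, pos:pos + 1] = v
                keys, values = k_cache[:, :, :pos + 1], v_cache[:, :, :pos + 1]
            out = F.scaled_dot_product_attention(q, keys, values)
            x = out.transpose(1, 2).reshape(1, 1, D_MODEL)
    return x

def time_ms(steps, use_cat, device="cpu"):
    # Synchronize on GPU so we time the kernels, not just their launch
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    decode(steps, use_cat, device)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for steps in (128, 256, 512):
        print(steps, f"cat={time_ms(steps, True, device):.1f}ms",
              f"prealloc={time_ms(steps, False, device):.1f}ms")
```

Both variants compute the same outputs; only the cache update differs. This makes the loop a fair comparison of the two allocation patterns.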
Where to learn proper KV cache management
The tutorials using torch.cat are optimizing for clarity, not production performance, but the pattern runs counter to the very idea of caching. The pre-allocated buffer strategy is the right trade-off for educational content: it avoids the quadratic copying without getting into too much detail. But if you're building production inference systems, you need better references.
If you want to see KV cache management done right, start here:
- The PagedAttention paper and vLLM's cache management docs show how to manage KV cache in fixed-size blocks with memory sharing across requests.
- Meta's torchtune KV cache is a clean example of pre-allocated buffers done right.
- HuggingFace Transformers' StaticCache (post-v4.36) is the pre-allocated alternative to their DynamicCache.
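The block-based approach in the first reference rests on one core idea. A shared pool of fixed-size blocks holds the data, and a per-sequence block table maps logical token positions to physical blocks. The toy sketch below illustrates that idea only; it is not vLLM's implementation or API. The class and method names are hypothetical. It stores a single layer's keys only (values would be handled the same way). It also omits block sharing, eviction, and the attention kernel that reads blocks directly.

```python
import torch

class ToyPagedKCache:
    """Hypothetical paged key cache: one layer, keys only, no block sharing."""

    def __init__(self, num_blocks, block_size, n_heads, head_dim):
        # Physical pool shared by all sequences: (num_blocks, n_heads, block_size, head_dim)
        self.pool = torch.zeros(num_blocks, n_heads, block_size, head_dim)
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block indices
        self.lengths = {}       # seq_id -> number of tokens stored

    def append(self, seq_id, k):
        """Write one token's key, k of shape (n_heads, head_dim)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        offset = length % self.block_size
        if offset == 0:
            # Current block is full (or there is none yet): take a new physical block
            if not self.free_blocks:
                raise RuntimeError("out of KV cache blocks")
            table.append(self.free_blocks.pop())
        self.pool[table[-1], :, offset] = k  # in-place write; nothing else moves
        self.lengths[seq_id] = length + 1

    def gather(self, seq_id):
        """Copy out the logical cache (n_heads, length, head_dim), for inspection only."""
        table, length = self.block_tables[seq_id], self.lengths[seq_id]
        blocks = self.pool[table]  # (n_blocks, n_heads, block_size, head_dim)
        flat = blocks.transpose(0, 1).reshape(self.pool.size(1), -1, self.pool.size(3))
        return flat[:, :length]

    def free(self, seq_id):
        # Return this sequence's blocks to the pool for reuse
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        self.lengths.pop(seq_id)


cache = ToyPagedKCache(num_blocks=8, block_size=4, n_heads=2, head_dim=3)
keys = torch.randn(10, 2, 3)
for t in range(10):
    cache.append("a", keys[t])
print(len(cache.block_tables["a"]))  # 3 blocks for 10 tokens
assert torch.equal(cache.gather("a"), keys.transpose(0, 1))
```

As with the pre-allocated buffer, appending never copies existing tokens. Memory is also claimed one block at a time instead of reserving max_seq_len per sequence up front.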
Benchmark code: Google Colab Notebook