Posts

Multi-head Latent Attention (In A Nutshell!)

In this post, I will dive into the Multi-head Latent Attention (MLA) mechanism, one of the innovations presented by the DeepSeek Team! This post assumes prior knowledge of the attention mechanism and the key-value cache. For a quick refresher on these topics, refer to my previous post on self-attention! One of the main problems with multi-head self-attention is the memory cost associated with the size of the key-value cache. MLA reduces the size of the key-value cache and speeds up LLM inference. The core idea is to cache latent embeddings that are shared across all heads (and for both keys and values), instead of caching separate key and value embeddings for each head as in standard multi-head self-attention (Figure 1). The latent embeddings are multiplied by per-head key and value up-projection matrices to produce key and value embeddings unique to each head. Having unique key and value embeddings for each head maintains the expressivity of the attention mechanism...
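To make the idea concrete, here is a minimal PyTorch sketch of the caching scheme, with made-up toy dimensions and with the causal mask and positional-encoding details that the full MLA design handles left out:

```python
import torch

# Toy dimensions for illustration only (not DeepSeek's actual sizes)
d_model, n_heads, d_head, d_latent = 64, 4, 16, 8
seq_len = 5

W_q   = torch.randn(n_heads, d_model, d_head)    # per-head query projections
W_dkv = torch.randn(d_model, d_latent)           # shared down-projection to the latent
W_uk  = torch.randn(n_heads, d_latent, d_head)   # per-head key up-projections
W_uv  = torch.randn(n_heads, d_latent, d_head)   # per-head value up-projections

x = torch.randn(seq_len, d_model)                # token embeddings

# Only the latent embeddings are cached: (seq_len, d_latent),
# versus (seq_len, 2 * n_heads * d_head) for a standard KV cache.
latent_cache = x @ W_dkv

# At attention time, each head reconstructs its own keys and values
# from the shared latent via its up-projection matrices.
q = torch.einsum('sd,hde->hse', x, W_q)              # (n_heads, seq_len, d_head)
k = torch.einsum('sl,hle->hse', latent_cache, W_uk)  # (n_heads, seq_len, d_head)
v = torch.einsum('sl,hle->hse', latent_cache, W_uv)  # (n_heads, seq_len, d_head)

attn = torch.softmax(q @ k.transpose(-1, -2) / d_head**0.5, dim=-1)
out = attn @ v                                        # (n_heads, seq_len, d_head)
```

In this toy setup, the cache holds 8 numbers per token instead of 2 × 4 × 16 = 128, which is where the memory saving comes from, while each head still attends with its own keys and values.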

Self-Attention and the Key-Value Cache (In A Nutshell!)

The Transformer architecture underpins most modern large language models (LLMs). In the seminal paper "Attention is All You Need", Vaswani et al. propose a Transformer architecture that relies solely on the multi-head self-attention mechanism to learn global dependencies (i.e. relationships) between words in a sentence. In this post, I will first explain the multi-head self-attention mechanism used in LLMs such as the original ChatGPT model (which was derived from GPT-3.5), and then explain why a key-value cache is needed for efficient inference. For illustrative purposes, I use words to represent tokens. ChatGPT uses a decoder-only Transformer architecture and is trained to predict the next token (e.g. sub-word, punctuation) given a context (i.e. message). I start by illustrating the multi-head self-attention mechanism using a single sentence as an example. Assume that we are training an LLM with a context window of 10 words; a 10-word sentence in the trai...
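As a rough illustration of why the key-value cache helps during inference, here is a toy single-head sketch (hypothetical dimensions, no batching, masking, or multi-head split) in which each decoding step reuses the keys and values computed for earlier tokens instead of recomputing them:

```python
import torch

# Toy single-head self-attention with a key-value cache (illustrative sizes only)
d_model = 16
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

k_cache, v_cache = [], []

def decode_step(x_t):
    """Attend the newest token to all cached tokens, then append its K/V."""
    q = x_t @ W_q
    k_cache.append(x_t @ W_k)    # cached: keys of past tokens are never recomputed
    v_cache.append(x_t @ W_v)
    K = torch.stack(k_cache)     # (t, d_model)
    V = torch.stack(v_cache)
    scores = torch.softmax(q @ K.T / d_model**0.5, dim=-1)
    return scores @ V            # context vector for the newest token

# Generate token-by-token: each step only processes the newest embedding.
for _ in range(5):
    decode_step(torch.randn(d_model))
```

The cache grows by one key and one value per generated token, so the cost of each step is attention over the tokens seen so far rather than a full recomputation of the whole sequence.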

Training an LLM (In a Nutshell!)

A large language model (LLM) learns how to reply to a conversation by learning to predict the next token (akin to a word) given the preceding conversation. This is framed as a multi-class classification problem, where each token represents a different class. The LLM outputs the likelihood of each token in the vocabulary (i.e. the set of all possible tokens) being the next token; these likelihoods are the probability parameters of a categorical distribution. The next token is predicted by sampling from this categorical distribution. After a token is predicted, it is appended to the preceding conversation and used to predict the token after it. Training an LLM typically consists of three main stages (different LLMs may have different training schemes):

1. Unsupervised pre-training
2. Supervised fine-tuning
3. Reinforcement learning

Stages 2 and 3 are commonly referred to collectively as the fine-tuning stage.

Unsupervised pre-training

In the unsupervised pre-tra...
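To make the next-token-prediction framing concrete, here is a minimal PyTorch sketch (toy vocabulary size and an arbitrary target id, purely for illustration): the language-model head produces one logit per vocabulary token, training minimises the cross-entropy against the true next token, and generation samples from the resulting categorical distribution.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration: the LM head maps a hidden state to one logit
# per token in the vocabulary, i.e. a multi-class classification problem.
vocab_size, d_model = 1000, 32
lm_head = torch.nn.Linear(d_model, vocab_size)

hidden = torch.randn(1, d_model)         # hidden state at the current position
logits = lm_head(hidden)                  # (1, vocab_size)

# Training: cross-entropy against the actual next token in the corpus
# (token id 42 is an arbitrary placeholder for this example).
true_next_token = torch.tensor([42])
loss = F.cross_entropy(logits, true_next_token)

# Inference: the softmax gives the parameters of a categorical distribution,
# and the next token is drawn by sampling from it.
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```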

Upcoming blog posts

Here is a list of topics that I plan to post about in the future! They will be part of the “In a nutshell” series introducing important topics on AI!

[x] Training an LLM (In A Nutshell!)
[ ] Deep Reinforcement Learning (In A Nutshell!)
[ ] AlphaFold 3: Triangle Attention (In A Nutshell!)
[ ] Diffusion Models (In A Nutshell!)
[ ] Large Multimodal Models (In A Nutshell!)
[ ] Importance of tool calling in LLMs!
[ ] DeepSeekV3: Multi-head latent attention!

Do let me know if you would like me to post about any topic!
