Notes on the Transformer architecture.
The transformer is a sequence-to-sequence neural architecture that revolutionized natural language processing by efficiently learning contextual representations through parallelized attention mechanisms. It was originally introduced for machine translation in the seminal paper “Attention Is All You Need” by Vaswani et al. (2017).
The motivation behind the transformer architecture lies in the limitations found in the previous state-of-the-art architectures for seq2seq tasks that used Recurrent Neural Networks (RNNs).
The Long Short-Term Memory (LSTM) architecture encodes the context from previous time steps in a single “state vector” that is continuously updated as information flows through each processing step. As a consequence, the state vector can change significantly from one step to the next, so previous context is lost quickly, especially context coming from far-away steps.
To account for this, Gated Recurrent Units (GRUs) were introduced by Chung et al. (2014).
There are two major shortcomings in these RNN approaches:
- Long-range dependencies are hard to capture: information from distant time steps must survive many state updates before it can influence the current step.
- Computation is inherently sequential: each step depends on the output of the previous one, so the sequence cannot be processed in parallel, which makes training slow.
The transformer architecture addresses these shortcomings by:
- Letting every token attend directly to every other token through self-attention, so no information has to be squeezed through a single recurrent state vector.
- Processing all positions of the sequence in parallel, since there is no recurrence between time steps.
However, there is one important shortcoming introduced in the transformer architecture with respect to previous RNN approaches: since self-attention treats the input as an unordered set of tokens, the model loses any built-in notion of token order, and positional information has to be injected explicitly (see the positional encodings discussed later).
The Multi-Head Self-Attention mechanism allows words in the sequence to “pay attention” (attend) to each other; through the learning procedure, they eventually become able to identify which words from previous time steps are relevant for generating the element at the current time step.
To learn this component, we define three weight matrices that project each input token into a query ($Q$), a key ($K$), and a value ($V$) representation. These matrices are multiplied together to obtain the attention output in the following way:
\[\begin{equation}\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\end{equation}\]We take the dot product between the queries and the keys, and divide the result by the square root of the dimensionality of the key vectors, $d_k$. This scaling keeps the variance of the dot products close to $1$. We then apply a softmax so that each row of scores sums to $1$, and finally multiply the result with the value matrix. In effect, the softmax output acts as a set of weights that scales the influence of each value vector.
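As a minimal sketch of the formula above (the function name and tensor sizes here are illustrative, not part of the original notes):

import torch
import torch.nn.functional as F

def attention(q, k, v):
    # q, k: (context_length, d_k), v: (context_length, d_v)
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_k**0.5  # (context_length, context_length)
    weights = F.softmax(scores, dim=-1)          # each row sums to 1
    return weights @ v                           # (context_length, d_v)

q, k = torch.randn(8, 16), torch.randn(8, 16)    # 8 tokens, d_k = 16
v = torch.randn(8, 32)                           # 8 tokens, d_v = 32
print(attention(q, k, v).shape)                  # torch.Size([8, 32])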
A good intuition for how this works is to compare the concept with a lookup table / hash table / dictionary.
In a lookup table, we have the following structure:
lookup = {
    "dog": "bark",
    "wolf": "howl",
    "cat": "meow"
}

# this key exists, so the query matches
# and returns its value, "bark"
query1 = "dog"
value1 = lookup[query1]

# this key doesn't exist, so the query won't
# match and a KeyError is raised
query2 = "bird"
value2 = lookup[query2]
Note how a query will either match a single key entirely (and then return its value), or not match any of the keys at all (and then raise an error or return a predefined default).
In the attention mechanism, we perform a similar operation, but we implement a “soft” match between queries and keys. In fact, every query matches every key, just with different intensities. The softmax computation introduces this notion of intensity, which weighs the values that are going to be returned.
lookup = {
    "dog": "bark",
    "wolf": "howl",
    "cat": "meow"
}
# query doesn't exist in the dictionary
# but still we will match it with the keys by their similarity
query = "husky"
# the query has high similarity with dog, less with wolf, almost none with cat
attention = 0.7 * lookup["dog"] + 0.29 * lookup["wolf"] + 0.01 * lookup["cat"]
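The snippet above is only schematic (you cannot multiply strings by floats). Here is a minimal runnable sketch of the same idea, where the vectors and similarity values are made up purely for illustration:

import torch
import torch.nn.functional as F

# made-up key and value vectors for "dog", "wolf", "cat"
keys = torch.tensor([[1.0, 0.0],    # dog
                     [0.8, 0.3],    # wolf
                     [0.0, 1.0]])   # cat
values = torch.tensor([[1.0, 0.0, 0.0],   # "bark"
                       [0.0, 1.0, 0.0],   # "howl"
                       [0.0, 0.0, 1.0]])  # "meow"

# made-up query vector for "husky": close to dog, somewhat close to wolf
query = torch.tensor([0.9, 0.2])

weights = F.softmax(query @ keys.T, dim=-1)  # soft match, sums to 1
out = weights @ values                       # weighted mixture of the values
print(weights, out)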
The output of the scaled dot product is a square matrix of dimensions (context_length, context_length), where entry $(i, j)$ is the attention score of token $i$ for token $j$. Note that the attention of token $i$ to token $j$ is generally not the same as that of token $j$ to token $i$ (i.e., the matrix is not symmetric). We then multiply this matrix with the value matrix to mix the values according to the attention scores: $V$ has dimensions (context_length, embedding_dim), so the product of the attention scores with $V$ yields a matrix of dimensions (context_length, embedding_dim).
Another important consideration is that the self-attention mechanism is slightly different in the encoder and in the decoder:
- In the encoder, every token can attend to every other token in the sequence, including tokens that come after it.
- In the decoder, a causal mask is applied so that each token can only attend to itself and to previous positions; otherwise the model could “cheat” by looking at the tokens it is supposed to predict (see the small sketch below).
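A minimal sketch of such a causal mask (the sizes here are arbitrary, chosen just for illustration):

import torch

T = 4  # a tiny context length, for illustration
scores = torch.randn(T, T)           # raw attention scores
mask = torch.tril(torch.ones(T, T))  # lower-triangular causal mask

# positions above the diagonal (future tokens) are set to -inf,
# so they receive zero weight after the softmax
masked = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(masked, dim=-1)
print(weights)  # each row sums to 1; row i only attends to tokens 0..i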
Finally, the last missing component for this mechanism is the composition of multiple SA units into a single Multi-Head Self-Attention (MHA) unit.
Motivation: a single attention mechanism might struggle to capture all the intricate patterns in complex sequences due to its limited capacity. Multi-head attention mitigates this by distributing the learning task across multiple heads, reducing the risk of bottlenecks where certain crucial information might be missed or underrepresented. In other words, each head will learn a different latent space, and encode its specialized features into a segment of the final vector. Hopefully each of these segments will capture specialized characteristics of the data, as opposed to having a single vector space encoding everything as in the single SA unit.
import torch
import torch.nn as nn
import torch.nn.functional as F

class HeadAttention(nn.Module):
    def __init__(self, block_size, dim_embedding, dim_head):
        super().__init__()
        self.dim_head = dim_head
        self.dim_embedding = dim_embedding
        # It is conventional to remove the bias for these projections
        self.key = nn.Linear(dim_embedding, dim_head, bias=False)
        self.query = nn.Linear(dim_embedding, dim_head, bias=False)
        self.value = nn.Linear(dim_embedding, dim_head, bias=False)
        # register_buffer stores this tensor with the model's state,
        # but it is not a trainable parameter.
        # It can be accessed with self.mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size))
        )

    def forward(self, x):
        # x: (batch, T, dim_embedding), with T <= block_size
        T = x.shape[1]
        k = self.key(x)    # (batch, T, dim_head)
        q = self.query(x)  # (batch, T, dim_head)
        v = self.value(x)  # (batch, T, dim_head)
        # scaled dot product between queries and keys: (batch, T, T)
        prod = q @ k.transpose(-2, -1) * self.dim_head**-0.5
        # causal mask: future positions get zero weight after the softmax
        prod = prod.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        prod = F.softmax(prod, dim=-1)
        out = prod @ v     # (batch, T, dim_head)
        return out
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, num_heads, dim_head, dim_embedding):
        super().__init__()
        self.head_list = nn.ModuleList(
            [
                HeadAttention(block_size, dim_embedding, dim_head)
                for _ in range(num_heads)
            ]
        )
        # conventionally dim_head = dim_embedding // num_heads, so the
        # concatenation of all heads has dimension dim_embedding again
        self.proj = nn.Linear(num_heads * dim_head, dim_embedding)

    def forward(self, x):
        # each head returns (batch, T, dim_head); concatenate along the last dim
        out = torch.cat([head(x) for head in self.head_list], dim=-1)
        out = self.proj(out)  # (batch, T, dim_embedding)
        return out
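A quick usage sketch of the classes above, with hyperparameter values chosen arbitrarily for illustration:

block_size, num_heads, dim_embedding = 32, 4, 64
dim_head = dim_embedding // num_heads

mha = MultiHeadAttention(block_size, num_heads, dim_head, dim_embedding)
x = torch.randn(2, 16, dim_embedding)  # batch of 2 sequences of 16 tokens
print(mha(x).shape)                    # torch.Size([2, 16, 64])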
Unstructured inputs (e.g., text, images) must be converted into a numerical representation before they can be processed. Mikolov et al. (2013) popularised this idea for text with word2vec, which learns a dense vector (embedding) for every word in the vocabulary such that semantically similar words lie close together.
In transformers, we also define an embedding matrix that will be learned jointly with the other components. The embedding matrix is initialised randomly, and is of dimension (vocabulary_size, embedding_dimension). Each input token will retrieve an embedding from the table. This embedding will then represent this token in input space.
Alongside the word embeddings, the transformer architecture defines a second type of embedding called positional encodings. Earlier in this document we discussed how the self-attention mechanism does not introduce any notion of order: on its own, the model does not know which token comes first, which comes second, and so on. Therefore, we must somehow encode this information.
There are many approaches to encode position into the word embeddings. Some approaches don't involve learning new weights: for example, we can use sine and cosine functions of different frequencies, so that each position receives a fixed, deterministic encoding and any relative offset between two positions can be expressed as a linear function of their encodings.
The benefits of this approach are that:
- it adds no trainable parameters;
- it can, in principle, generalise to sequence lengths longer than those seen during training;
- a fixed offset between two positions corresponds to a simple linear transformation of their encodings, which makes relative positions easy for the model to exploit.
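As a minimal sketch (not code from the original paper), the fixed sinusoidal encodings can be computed as follows; the context_size and dim_embedding values are arbitrary examples:

import torch

def sinusoidal_positional_encoding(context_size, dim_embedding):
    # pe[pos, 2i]   = sin(pos / 10000^(2i / dim_embedding))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i / dim_embedding))
    position = torch.arange(context_size).unsqueeze(1)  # (context_size, 1)
    div_term = 10000 ** (torch.arange(0, dim_embedding, 2) / dim_embedding)
    pe = torch.zeros(context_size, dim_embedding)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    return pe  # (context_size, dim_embedding)

pe = sinusoidal_positional_encoding(context_size=32, dim_embedding=64)
print(pe.shape)  # torch.Size([32, 64])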
Other approaches involve learning the positional embeddings jointly with the transformer block. We define an embedding matrix of size (context_size, embedding_dimension) and learn an embedding for each token position. This generally works well, at the cost of supporting only a fixed number of positions (up to context_size) and adding more trainable parameters.
The input to the transformer block is the element-wise sum of both the word embedding and the positional encodings.
class EmbeddingEncoder(nn.Module):
    def __init__(self, dim_embedding, vocab_size, context_size):
        super().__init__()
        self.context_size = context_size
        self.embedding_table = nn.Embedding(vocab_size, dim_embedding)
        self.positional_embedding_table = nn.Embedding(context_size, dim_embedding)

    def forward(self, x):
        # x: (batch, T) of token indices, with T <= context_size
        T = x.shape[1]
        x_emb = self.embedding_table(x)  # (batch, T, dim_embedding)
        pos_emb = self.positional_embedding_table(
            torch.arange(T, device=x.device)
        )                                # (T, dim_embedding)
        # broadcasted element-wise sum of token and positional embeddings
        x_emb = x_emb + pos_emb
        return x_emb
Up to now, the operations applied to the token representations are essentially linear: matrix multiplications and additions, with the softmax serving only to produce the weights of a weighted average over the value vectors. To add non-linearity, we simply put a feed-forward network with a non-linear activation after the MHA layer.
This is usually a very shallow network with only $1$ hidden layer with a ReLU / GeLU / etc. activation function. It is also common to do an up projection and a down projection in the hidden layer, meaning the input gets stretched to a higher dimensional space when entering the hidden unit, then gets down projected to the original size after leaving the hidden unit.
class MLP(nn.Module):
    def __init__(self, dim_embedding):
        super().__init__()
        # up-project to 4x the embedding dimension,
        # then bring it back to the original size
        self.net = nn.Sequential(
            nn.Linear(dim_embedding, 4 * dim_embedding),
            nn.ReLU(),
            nn.Linear(4 * dim_embedding, dim_embedding)
        )

    def forward(self, x):
        out = self.net(x)
        return out
Batch normalisation is a procedure that helps stabilise and accelerate the training process by reducing internal covariate shift, which refers to the change in the distribution of layer inputs during training. It consists in normalising the outputs of the layers so that their distribution has a mean of $0$ and a variance of $1$, then applying two learned parameters to scale and shift the distribution. This reduces the chance that gradients explode or vanish (grow or shrink exponentially), and it is a key engineering component in ensuring that large and deep neural networks can converge in a stable fashion. Note that the transformer block itself uses layer normalisation (as in the final code below), which applies the same idea but normalises across the embedding dimension of each token rather than across the batch; the mechanics described here carry over directly.
To normalise an input batch, we scale it in this fashion:
\[\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\]where $\mu_B$ is the mean of the mini-batch, $\sigma_B^2$ is the variance of the mini-batch, and $\epsilon$ is a small constant added for numerical stability.
Next, we define two learnable parameters, $\gamma$ and $\beta$, that allow the model to reconstruct the original (non-scaled) distribution if needed:
\[out = \gamma \hat{x}_i + \beta\]If the learned $\gamma$ is equal to $1$ and the learned $\beta$ is equal to $0$, this layer is simply normalising the input to mean $0$ and variance $1$. However, the model may learn other parameters that allow it to scale and shift the distribution as it pleases.
class BatchNorm(nn.Module):
    def __init__(self, embedding_dim, eps=1e-5, momentum=0.1):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(embedding_dim))
        self.beta = nn.Parameter(torch.zeros(embedding_dim))
        # Running mean and variance: stored with the model, but not trainable
        self.register_buffer('running_mean', torch.zeros(embedding_dim))
        self.register_buffer('running_var', torch.ones(embedding_dim))

    def forward(self, x):
        # x: (..., embedding_dim); statistics are computed over every
        # dimension except the last (feature) dimension
        reduce_dims = tuple(range(x.dim() - 1))
        # batch statistics (and the running statistics) are only updated in training
        if self.training:
            batch_mean = x.mean(dim=reduce_dims)
            batch_var = x.var(dim=reduce_dims, unbiased=False)
            # Update running statistics without tracking gradients
            with torch.no_grad():
                self.running_mean = (
                    (1 - self.momentum) * self.running_mean
                    + self.momentum * batch_mean
                )
                self.running_var = (
                    (1 - self.momentum) * self.running_var
                    + self.momentum * batch_var
                )
        # the stored running statistics are used at inference time
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var
        # Normalise to mean 0, variance 1
        x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
        # Scale and shift the distribution
        out = self.gamma * x_hat + self.beta
        return out
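As noted above, the transformer block itself uses layer normalisation rather than batch normalisation. A minimal sketch of the difference, with arbitrary tensor sizes:

import torch
import torch.nn as nn

x = torch.randn(2, 16, 64)  # (batch, tokens, embedding_dim)

# LayerNorm normalises each token's 64-dimensional vector independently,
# so it does not depend on the batch or on running statistics
ln = nn.LayerNorm(64)
y = ln(x)
print(y.mean(dim=-1)[0, 0].item(), y.var(dim=-1, unbiased=False)[0, 0].item())
# ~0.0 and ~1.0 for every token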
Residual connections were introduced by He et al. (2015) in the context of very deep convolutional networks (ResNets).
As neural networks become deeper, gradients can become very small during backpropagation, making it difficult to update the weights of earlier layers effectively. This can lead to very slow convergence, or even cause learning to stop altogether.
Residual connections address these issues by allowing the gradient to flow directly through the network, making it easier to train deep networks. They work by adding the input of a layer to its output, so the layer only has to learn a residual on top of the identity mapping.
For very deep neural nets, at some point the data transformations being applied to the input are not beneficial anymore, and performance starts decreasing. To fix this, residual connections allow some layers to simply learn “nothing”, i.e. learn to pass the input as the output without modifying it whatsoever (i.e., learn the identity function). With this approach, scaling the network should never hurt performance, as it can simply stop learning new functions if needed.
class GenericLayer(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
...
)
def forward(self, x):
out = x + self.net(x)
return out
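For example, the residual wrapper can be combined with the MLP defined earlier. ResidualMLP is a hypothetical name used only for this sketch, and the imports and the MLP class from the previous snippets are assumed:

class ResidualMLP(nn.Module):
    def __init__(self, dim_embedding):
        super().__init__()
        self.net = MLP(dim_embedding)

    def forward(self, x):
        # the input is added back to the layer's output (skip connection)
        return x + self.net(x)

x = torch.randn(2, 16, 64)
print(ResidualMLP(64)(x).shape)  # torch.Size([2, 16, 64])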
Check this Google Colab notebook to test the architecture with a practical example borrowed from Andrej Karpathy's lecture series.
# Source: adapted from https://youtu.be/kCc8FmEb1nY
class Head(nn.Module):
"""One head of self-attention."""
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # (B, T, T), scaled by sqrt(head_size)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B, T, head_size)
return out
class MultiHeadAttention(nn.Module):
"""Multiple heads of self-attention in parallel."""
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(n_embd, n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedForward(nn.Module):
"""A simple linear layer followed by a non-linearity."""
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Transformer(nn.Module):
"""A decoder-only transformer block."""
def __init__(self, n_embd, n_head):
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
self.ffwd = FeedForward(n_embd)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class BigramLanguageModel(nn.Module):
"""Bigram language model using transformers."""
def __init__(self):
super().__init__()
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Transformer(n_embd, n_head=n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd) # Final layer norm
self.lm_head = nn.Linear(n_embd, vocab_size)
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding_table(idx) # (B, T, C)
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
x = tok_emb + pos_emb # (B, T, C)
x = self.blocks(x) # (B, T, C)
x = self.ln_f(x) # (B, T, C)
logits = self.lm_head(x) # (B, T, vocab_size)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
for _ in range(max_new_tokens):
idx_cond = idx[:, -block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] # (B, C)
probs = F.softmax(logits, dim=-1) # (B, C)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
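The code above also relies on a handful of module-level globals from the lecture script (vocab_size, n_embd, n_head, n_layer, block_size, dropout, device). A minimal setup sketch with arbitrary example values, assuming the imports from the earlier snippets:

# arbitrary example hyperparameters (the lecture uses different values)
vocab_size = 65      # e.g., a character-level vocabulary
n_embd = 64
n_head = 4
n_layer = 4
block_size = 32      # maximum context length
dropout = 0.1
device = "cuda" if torch.cuda.is_available() else "cpu"

model = BigramLanguageModel().to(device)
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # a single start token
generated = model.generate(context, max_new_tokens=20)
print(generated.shape)  # torch.Size([1, 21])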