Transformers

Notes on the Transformer architecture.

Introduction

The transformer is a sequence-to-sequence neural architecture that revolutionized natural language processing by efficiently learning contextual representations through parallelized attention mechanisms. Originally introduced for machine translation in the seminal paper “Attention Is All You Need” by Vaswani et al. (2017), transformers have since become the foundation for most modern NLP systems.

Figure 1: Overview of the Transformer Architecture (Source).

Motivation

The motivation behind the transformer architecture lies in the limitations found in the previous state-of-the-art architectures for seq2seq tasks that used Recurrent Neural Networks (RNNs).

The Long Short-Term Memory (LSTM) architecture encoded the context from previous time steps in a single “state vector” that is continuously updated as information flows through each processing step. However, this results in the state vector suffering significant changes from one step to the next, and thus losing previous context quickly, especially from far-away previous steps.

Figure 2: Recurrent Neural Network: each time step $t$ processes an input $x_t$ and produces an output $o_t$. A state vector $V$ is carried over from previous time steps (initialised as zero). Source: Raj, Dinesh Samuel Sathia et al.

To account for this, the Gated Recurrent Unit (GRU) was introduced by Cho, Kyunghyun et al. (2014) and later evaluated empirically by Chung, Junyoung et al. (2014). The key idea behind the GRU is the use of gate mechanisms that open or close the information flow (i.e., control when and how much to update the context vector at a given time step). This allows context to flow across longer sequences, as unimportant states can be ignored.

Figure 3: The update gate $z$ selects whether the hidden state $h_t$ is to be updated with a new hidden state $\tilde{h}_t$. The reset gate $r$ decides whether the previous hidden state $h_{t-1}$ is ignored. Source: Cho, Kyunghyun, et al.

There are two major shortcomings in these RNN approaches:

  1. Processing is inherently sequential: each time step depends on the previous one, so computation cannot be parallelised across the sequence.
  2. Context from far-away time steps has to survive a long chain of state updates, so it tends to be diluted or lost in long sequences.

The transformer architecture addresses these shortcomings by:

  1. Replacing recurrence with self-attention, which processes all positions of the sequence in parallel.
  2. Letting every token attend directly to every other token, so long-range dependencies no longer depend on a single state vector being carried across many steps.

However, there is one important shortcoming introduced in the transformer architecture with respect to previous RNN approaches: by removing recurrence, the model loses any built-in notion of token order, so positional information has to be reintroduced explicitly (see Word Embeddings and Positional Encodings below).

Components

Multi-Head Self-Attention (MHA)

The Multi-Head Self-Attention mechanism allows words in the sequence to “pay attention” (attend) to each other; through the learning procedure, they eventually become able to identify which words from previous time steps are relevant for generating the element at the current time step.

Figure 4: Self-Attention: The red word is the word being generated at the current time step. Blue words are words with high attention scores at the current time step. Note that words from future time steps are not considered. For example, at time step $0$, only the first word "The" is considered. At time step $1$, only the words "The" and "FBI" are considered. Source: Allam, Tarek & McEwen, Jason. (2021)

To learn this component, we define three weight matrices that project the input into Queries ($Q$), Keys ($K$), and Values ($V$). These matrices are combined to obtain the attention output in the following way:

\[\begin{equation}\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\end{equation}\]

We apply a dot product between the queries and the keys, and then divide the result by the square root of the dimension of the key vectors ($d_k$). This keeps the variance of the dot products close to $1$. We then apply a softmax so that each row of scores sums to $1$, and finally multiply by the value matrix. In effect, the softmax output acts as a set of weights that scales the influence of each value vector.

Figure 5: Attention Mechanism: The input X is multiplied with matrices $\theta$, $\phi$, and $g$ to obtain the resulting matrices $Q$, $K$, and $V$. $Q$ and $K$ are multiplied to obtain the attention score matrix, that is finally multiplied with the $V$ matrix (Source).
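To make the equation concrete, here is a minimal sketch of scaled dot-product attention written directly from the formula above (tensor names and sizes are illustrative; the learned projections and masking appear in the full implementation later in this section):

import torch
import torch.nn.functional as F

def attention(q, k, v):
    # q, k: (T, d_k), v: (T, d_v)
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_k**0.5   # (T, T)
    weights = F.softmax(scores, dim=-1)           # each row sums to 1
    return weights @ v                            # (T, d_v)

# toy example: 4 tokens with 8-dimensional queries, keys, and values
q = torch.randn(4, 8)
k = torch.randn(4, 8)
v = torch.randn(4, 8)
out = attention(q, k, v)  # (4, 8)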

A good intuition for how this works is to compare the concept with a lookup table / hash table / dictionary.

In a lookup table, we have the following structure:

lookup = {
    "dog": "bark",
    "wolf": "howl",
    "cat": "meow"
}

# this key exists: the query matches
# and returns "bark"
query1 = "dog"
value1 = lookup[query1]

# this key doesn't exist: the query won't
# match, and an error is raised
query2 = "bird"
value2 = lookup[query2]

Note how a query either matches exactly one key (and returns its value), or matches no key at all (and raises an error or returns a predefined default).

In the attention mechanism, we perform a similar operation, but with a “soft” match between queries and keys. In fact, every query matches every key, just with different intensities. The softmax computation introduces this notion of intensity, which weighs the values that are returned.

lookup = {
    "dog": "bark",
    "wolf": "howl",
    "cat": "meow"
}

# the query doesn't exist in the dictionary,
# but we still match it against every key by similarity
query = "husky"

# the query is very similar to "dog", less similar to "wolf",
# and barely similar to "cat"; the result is a weighted mixture
# of the values (conceptual pseudocode, not valid Python for strings)
attention = 0.7 * lookup["dog"] + 0.29 * lookup["wolf"] + 0.01 * lookup["cat"]

The output of the scaled dot product is a square matrix of dimensions (context_size, context_size), where entry $(i, j)$ is the attention score of token $i$ for token $j$. Note that the attention of token $i$ for token $j$ is not the same as that of token $j$ for token $i$ (i.e., the matrix is not symmetric). We then multiply this matrix with the value matrix to weigh the values according to the attention scores. $V$ has dimensions (context_size, embedding_dim), so the product between the attention scores and $V$ yields a matrix of dimensions (context_size, embedding_dim).

Another important consideration is that the self-attention mechanism is slightly different in the encoder and in the decoder: in the encoder, every token can attend to every other token in the sequence, whereas in the decoder a causal mask is applied so that each token can only attend to itself and to previous positions (the scores for future positions are set to $-\infty$ before the softmax).

Figure 6: Attention Matrix Masking (Source).
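As a minimal sketch of the decoder-side masking shown in Figure 6 (sizes are illustrative):

import torch
import torch.nn.functional as F

T, d = 4, 8                       # context size and embedding dimension
scores = torch.randn(T, T)        # attention scores: one row per query token
v = torch.randn(T, d)             # value matrix

# causal mask: block attention to future positions
mask = torch.tril(torch.ones(T, T))                  # lower-triangular matrix of ones
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)                  # masked entries become 0
out = weights @ v                                    # (T, d): one output vector per token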

Finally, the last missing component for this mechanism is the composition of multiple SA units into a single Multi-Head Self-Attention (MHA) unit.

Figure 7: Multi-Head Self-Attention (MHA): Input sequence "I am a student" is processed through $4$ SA units with embedding_dim equals $2$. The attention values $a_0, a_1, a_2, a_3$ are concatenated to form the output with embedding_dim equals $8$ (Source).

Motivation: a single attention mechanism might struggle to capture all the intricate patterns in complex sequences due to its limited capacity. Multi-head attention mitigates this by distributing the learning task across multiple heads, reducing the risk of bottlenecks where certain crucial information might be missed or underrepresented. In other words, each head will learn a different latent space, and encode its specialized features into a segment of the final vector. Hopefully each of these segments will capture specialized characteristics of the data, as opposed to having a single vector space encoding everything as in the single SA unit.

Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class HeadAttention(nn.Module):
    def __init__(
        self,
        block_size,
        dim_embedding,
        dim_head
      ):
        super().__init__()

        self.dim_head = dim_head
        self.dim_embedding = dim_embedding
        
        # It is conventional to remove the bias for these
        self.key = nn.Linear(
            dim_embedding,
            dim_head,
            bias=False
          )
        self.query = nn.Linear(
            dim_embedding,
            dim_head,
            bias=False
          )
        self.value = nn.Linear(
            dim_embedding,
            dim_head,
            bias=False
          )
        
        # a register_buffer saves this variable in the model
        # params, but they are not considered trainable
        # parameters. can access the variable with self.mask
        self.register_buffer(
            "mask", torch.tril(
                torch.ones(block_size, block_size)
          ))
        
    def forward(self, x):
        # x has shape (batch, context, dim_embedding)
        T = x.shape[1]

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # scaled dot-product attention with causal masking,
        # cropping the mask to the actual sequence length
        prod = q @ k.transpose(-2, -1) * self.dim_head**-0.5
        prod = prod.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        prod = F.softmax(prod, dim=-1)

        out = prod @ v
        return out

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        block_size,
        num_heads,
        dim_head,
        dim_embedding
      ):
        super().__init__()

        self.head_list = nn.ModuleList(
            [
                HeadAttention(
                    block_size,
                    dim_embedding,
                    dim_head,
                  ) for _ in range(num_heads)]
        )
        # the concatenated heads have dimension num_heads * dim_head,
        # which is projected back to dim_embedding
        self.proj = nn.Linear(num_heads * dim_head, dim_embedding)

    def forward(self, x):
        out = torch.cat(
            [head(x) for head in self.head_list],
            dim=-1
          )
        out = self.proj(out)
        
        return out
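A quick usage example with illustrative sizes (here num_heads * dim_head equals dim_embedding, as is conventional):

mha = MultiHeadAttention(
    block_size=8,
    num_heads=4,
    dim_head=8,
    dim_embedding=32,
)
x = torch.randn(2, 8, 32)   # (batch, context, embedding)
out = mha(x)                # (2, 8, 32)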

Word Embeddings and Positional Encodings

Unstructured inputs (e.g., text, images) must be converted into a numerical representation that can be processed. Mikolov et al. (2013) popularised word embeddings: latent vectors that capture the semantics of words.

In transformers, we also define an embedding matrix that will be learned jointly with the other components. The embedding matrix is initialised randomly, and is of dimension (vocabulary_size, embedding_dimension). Each input token will retrieve an embedding from the table. This embedding will then represent this token in input space.

Figure 8: Word Embeddings (Source).

Alongside the word embeddings, the transformer architecture defines a second type of embedding called positional encodings. Earlier in this document we discussed how the SA mechanism does not introduce any notion of order: the model does not know which token is first, which is second, and so on. Therefore, we must encode this information somehow.

There are many approaches to encode position into the word embeddings. Some approaches don’t involve learning new weights. For example, we can use sine and cosine functions to encode relative positions as a combination of previous positions.

Figure 9: Positional Encoding with sine and cosine functions (Source).

The benefits of this approach are:

  1. We don’t have to learn the embeddings; they are calculated analytically.
  2. We can encode arbitrarily large positions, beyond any fixed context length.
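As a minimal sketch of this approach (following the sinusoidal formulation of Vaswani et al. (2017), where $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$; the function name is illustrative):

import torch

def sinusoidal_positional_encoding(context_size, dim_embedding):
    # assumes dim_embedding is even
    position = torch.arange(context_size).unsqueeze(1)                       # (context_size, 1)
    div_term = 10000 ** (torch.arange(0, dim_embedding, 2) / dim_embedding)  # (dim_embedding/2,)
    pe = torch.zeros(context_size, dim_embedding)
    pe[:, 0::2] = torch.sin(position / div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position / div_term)   # odd dimensions
    return pe                                      # (context_size, dim_embedding)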

Other approaches involve learning the embeddings jointly with the transformer block. We define an embedding matrix of size (context_size, embedding_dimension), and for each token position we learn an embedding. This generally works well, at the cost of having a fixed number of positions (up to context_size) and adding more trainable parameters.

The input to the transformer block is the element-wise sum of both the word embedding and the positional encodings.

Implementation

class EmbeddingEncoder(nn.Module):
    def __init__(self, dim_embedding, vocab_size, context_size):
        super().__init__()
        
        self.embedding_table = nn.Embedding(
            vocab_size,
            dim_embedding
          )
        self.positional_embedding_table = nn.Embedding(
            context_size,
            dim_embedding
          )

    def forward(self, x):
        # x has shape (batch, context)
        x_emb = self.embedding_table(x)
        pos_emb = self.positional_embedding_table(
            torch.arange(x.shape[1], device=x.device)
          )

        x_emb = x_emb + pos_emb
        
        return x_emb
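For example, with illustrative sizes, a batch of token indices is mapped to the summed word-plus-position representation:

encoder = EmbeddingEncoder(dim_embedding=32, vocab_size=100, context_size=8)
tokens = torch.randint(0, 100, (4, 8))   # (batch, context) of token ids
x_emb = encoder(tokens)                  # (4, 8, 32)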

Feed Forward / Multi-layer Perceptron (MLP)

Up to now, almost all the operations performed are linear transformations (matrix multiplications and additions); the only non-linearity is the softmax inside the attention mechanism. To give the model more non-linear processing capacity, we simply put a feed-forward network with a non-linear activation after the MHA layer.

This is usually a very shallow network with only $1$ hidden layer and a ReLU / GELU / etc. activation function. It is also common to do an up projection and a down projection in the hidden layer, meaning the input gets stretched to a higher-dimensional space when entering the hidden unit, then projected back down to the original size on the way out.

Figure 10: Multi-Layer Perceptron: Input gets up projected from $3$ to $4$ dimensions, then down projected back from $4$ to $3$ dimensions. Source: Garg, Siddhant & Ramakrishnan, Goutham. (2020).

Implementation

class MLP(nn.Module):
    def __init__(
        self,
        dim_embedding,
      ):
        super().__init__()
        # up project the dimension by a factor of 4,
        # then bring it back to the original size
        self.net = nn.Sequential(
            nn.Linear(dim_embedding, 4*dim_embedding),
            nn.ReLU(),
            nn.Linear(4*dim_embedding, dim_embedding)
        )

    def forward(self, x):
        out = self.net(x)

        return out

Batch Normalisation

Batch normalisation is a procedure that helps stabilise and accelerate the training process by reducing internal covariate shift, which refers to the change in the distribution of layer inputs during training. It consists of normalising the outputs of the layers so that their distributions have a mean of $0$ and a variance of $1$, and then applying two learned parameters to scale and shift the distribution. This reduces the chance that gradients explode or vanish (grow or shrink exponentially), and is a key engineering component in ensuring that large and deep neural networks converge in a stable fashion.

To normalise an input batch, we scale it in this fashion:

\[\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\]

where $\mu_B$ is the mean of the mini-batch, $\sigma_B^2$ is the variance of the mini-batch, and $\epsilon$ is a small constant added for numerical stability.

Next, we define two learnable parameters, $\gamma$ and $\beta$, that allow the model to reconstruct the original (non-scaled) distribution if needed:

\[out = \gamma \hat{x}_i + \beta\]

If the learned $\gamma$ is equal to $1$ and the learned $\beta$ is equal to $0$, this layer simply normalises the input to mean $0$ and variance $1$. However, the model may learn other parameters that allow it to scale and shift the distribution as it pleases.

Implementation

class BatchNorm(nn.Module):
    def __init__(
        self,
        embedding_dim,
        eps=1e-5,
        momentum=0.1
      ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(embedding_dim))
        self.beta = nn.Parameter(torch.zeros(embedding_dim))
        
        # Initialize running mean and variance
        self.register_buffer(
            'running_mean',
            torch.zeros(embedding_dim)
          )
        self.register_buffer(
            'running_var',
            torch.ones(embedding_dim)
          )

    def forward(self, x):
        # x has shape (batch, context, embedding_dim);
        # statistics are computed over all dimensions except the last
        # mean and var are only updated in training
        if self.training:
            batch_mean = x.mean(dim=(0, 1))
            batch_var = x.var(dim=(0, 1), unbiased=False)
            
            # Update running statistics (no gradients needed)
            with torch.no_grad():
                self.running_mean = (
                    (1 - self.momentum) * 
                    self.running_mean + 
                    self.momentum * 
                    batch_mean
                  )
                self.running_var = (
                    (1 - self.momentum) *
                    self.running_var + 
                    self.momentum * 
                    batch_var
                 )
        # stored mean and variance are used in inference
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var
        
        # Normalise to mean 0 variance 1
        x_hat = (
            (x - batch_mean) /
            torch.sqrt(batch_var + self.eps)
          )
        
        # Scale and shift the distribution
        out = self.gamma * x_hat + self.beta
        
        return out
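Note that transformer blocks, including the final implementation at the end of this document, typically use the closely related layer normalisation instead: statistics are computed over the embedding dimension of each token independently, so there is no dependence on the batch and no running statistics to track. A minimal usage sketch with PyTorch's nn.LayerNorm (sizes are illustrative):

import torch
import torch.nn as nn

layer_norm = nn.LayerNorm(32)        # normalise over the embedding dimension
x = torch.randn(4, 8, 32)            # (batch, context, embedding)
out = layer_norm(x)                  # same shape; each token normalised independently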

Residual Connections

Residual connections were introduced by He et al. (2015). They are a technique for addressing the problem of vanishing gradients and making it easier to train very deep neural networks.

As neural networks become deeper, gradients can become very small during backpropagation, making it difficult to update the weights of earlier layers effectively. This can lead to very slow convergence or even cause learning to stop altogether.

Residual connections address these issues by allowing the gradient to flow directly through the network, making it easier to train deep networks. They work by adding the input of a layer to its output, so that the layer only has to learn a residual on top of an identity mapping.

Figure 11: Residual Connection in a feed forward neural network. The input $X$ branches in two directions: the first direction goes into the MLP and is processed normally. The other direction skips the MLP completely (i.e., the input is not changed at all). Both branches are aggregated afterwards. Source: He, Kaiming, et al. (2016).

For very deep neural nets, at some point the transformations applied to the input stop being beneficial and performance starts decreasing. To fix this, residual connections allow some layers to simply learn “nothing”, i.e., to pass the input through as the output without modifying it whatsoever (the identity function). With this approach, scaling the network up should not hurt performance, as extra layers can effectively become identity mappings if needed.

Implementation

class GenericLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            ...
        )

    def forward(self, x):
        # residual connection: the input skips the sub-layer and is added to its output
        out = x + self.net(x)

        return out

Implementation

Check this Google Colab to test the architecture with a practical example borrowed from Andrej Karpathy’s lecture series.

# Source: adapted from https://youtu.be/kCc8FmEb1nY
# Assumes the usual imports (torch, torch.nn as nn, torch.nn.functional as F)
# and globally defined hyperparameters: vocab_size, block_size, n_embd,
# n_head, n_layer, dropout, and device, as in the original lecture code.
class Head(nn.Module):
    """One head of self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B, T, C)
        q = self.query(x) # (B, T, C)
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """A simple linear layer followed by a non-linearity."""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Transformer(nn.Module):
    """A decoder-only transformer block."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class BigramLanguageModel(nn.Module):
    """Bigram language model using transformers."""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Transformer(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # Final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        x = tok_emb + pos_emb # (B, T, C)
        x = self.blocks(x) # (B, T, C)
        x = self.ln_f(x) # (B, T, C)
        logits = self.lm_head(x) # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
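As a usage sketch (the hyperparameters vocab_size, block_size, n_embd, n_head, n_layer, dropout, and device, and the decode helper, are assumed to be defined as in the original lecture code; training is omitted):

model = BigramLanguageModel().to(device)
# ... training loop omitted ...
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # start from a single token
generated = model.generate(context, max_new_tokens=100)
print(decode(generated[0].tolist()))  # decode: index-to-character mapping from the lecture code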

Variations

Figure 12: Evolutionary Tree of Language Models (Source).

Encoder-Decoder (suitable for seq2seq tasks)

Encoder-only (suitable for input understanding tasks, e.g., classification)

Decoder-only (suitable for generative tasks)