Road to ML Engineer #28 - Transformers

Last Edited: 10/26/2024

This blog post discusses transformers in deep learning.


In the last few articles, we discussed the challenges RNN models face in remembering context, despite the numerous solutions introduced by researchers, as well as their critical lack of parallelizability. In this article, we will introduce a new, parallelizable algorithm that enables us to create nuanced word embeddings and has become dominant in deep learning today. It is the model I've been waiting forever to cover: the Transformer.

Positional Encoding

One important factor to address when developing a parallelizable model that captures context is representing the position of inputs in a sequence. The RNN's sequential nature naturally encodes position, but a standard dense layer cannot encode position without intervention. The positional encoding used in the Transformer model employs as many sine and cosine waves as the embedding dimension and uses the values of these waves at each position to encode location.

P(k, 2i) = \sin\left(\frac{k}{n^{\frac{2i}{m}}}\right) \\ P(k, 2i+1) = \cos\left(\frac{k}{n^{\frac{2i}{m}}}\right)

The above equations show the mathematical operations for applying positional encoding. Here, $m$ is the dimension of the embedding, $k$ is the position of the input, $i$ is the index of the sine/cosine wave pair used, ranging from 0 to $\frac{m}{2} - 1$, and $P(k, i)$ is the encoding function. The value $n$ is a hyperparameter (defaulting to 10,000 in the original paper), which influences the period of the waves. The values of all the waves at position $k$ are collected into a vector to form the positional encoding, which is then added to the embedding. The sine and cosine values range from -1 to 1, keeping the positional encoding within a normalized range.

import numpy as np

def positional_encoding(seq_len, embed_dim, n=10000):
    # Build a (seq_len, embed_dim) matrix of positional encodings.
    P = np.zeros((seq_len, embed_dim))
    for k in range(seq_len):                   # position in the sequence
        for i in np.arange(int(embed_dim/2)):  # index of the sine/cosine wave pair
            denominator = np.power(n, 2*i/embed_dim)
            P[k, 2*i] = np.sin(k/denominator)
            P[k, 2*i+1] = np.cos(k/denominator)
    return P

The code above is a Python implementation of positional encoding by Saeed, M. (2023). This implementation uses the equations we just covered to create a positional encoding matrix, which can be added to the word embedding. While this is not the only way to create positional encoding, the method we have here is simple and popular in NLP.
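
As a quick sanity check, the matrix returned by the function above can be added directly to a word embedding of the same shape. The sizes and the random embedding below are purely illustrative.

import numpy as np

seq_len, embed_dim = 10, 64                      # illustrative sizes
embeddings = np.random.rand(seq_len, embed_dim)  # stand-in for learned word embeddings
P = positional_encoding(seq_len, embed_dim)      # (10, 64) positional encoding matrix
encoded = embeddings + P                         # position information injected into the embeddings
print(encoded.shape)                             # (10, 64)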

Attention Mechanism

After encoding position, we can finally move on to creating nuanced word embeddings based on context. Let's start by understanding the mechanism conceptually, setting aside computation details for now. Here, we can introduce the concept of attention, which determines the degree to which one word should influence another. To obtain attention, we introduce the concepts of query and key: the query expresses which words a given word expects to be relevant to it, and the key expresses how relevant each word presents itself as being to the others. If the query and key are similar for a particular pair of words, those words are worthy of attention when updating the embedding for the word of interest.

After computing attention, we calculate the value, which determines how the word should actually be modified by the other words, and finally update the word embedding with attention and value. This might sound abstract, so let's look at an example. For the sentence "I confirmed the validity of the document, so I put a seal on it," if we analyze the word "seal," we’d want to pay more attention to words like "document," "official," "animal," and "marine" to adjust the word representation accordingly.

In this case, "seal" would expect highly relevant words to be "document," "official," etc. (query), and those words, in turn, present themselves as highly relevant to the word "seal" (key). When these expectations are aligned (i.e., when query and key are similar), we pay more attention to these relevant words than to other, less relevant words. Then, depending on context, "seal" modifies the word representation to mean either a stamp or a marine animal (value). Hence, we apply attention to the value and update the word representation. Note: these explanations are for conceptual understanding and may not reflect exactly what queries, keys, and values actually learn to represent.

Computing Attention

Computationally, we set up a shared matrix of size $(m, n)$ (where $m$ is the embedding dimension and $n$ is the latent dimension) for each of the query and key, $W_Q$ and $W_K$, and multiply the word embedding $E$ of size $(d, m)$ ($d$ being the number of words in a sentence) by these matrices to transform the words into queries and keys in a lower dimension, resulting in $Q$ and $K$ of size $(d, n)$. To compute similarity between queries and keys, we use dot-product similarity ($QK^T$). The resulting matrix of size $(d, d)$ represents the attention.

Next, we set up a shared value matrix $W_V$ of size $(m, m)$ to transform the words into values, resulting in $V$ of size $(d, m)$. Finally, we multiply the attention and value matrices to produce a matrix of size $(d, m)$, which we add via a residual connection to update the word representations to $\tilde{E}$. Before multiplying attention by value, we normalize the attention by dividing by $\sqrt{n}$ (the square root of the key dimension) and applying the softmax function.

Q, K, V = E W_Q, E W_K, E W_V \\ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{n}}\right)V \\ \tilde{E} = E + \text{Attention}(Q, K, V)

However, in practice, working directly with the value matrix $W_V$ of size $(m, m)$ can be inefficient, especially with large embedding dimensions. To reduce the size, we can use matrix factorization as discussed in Road to ML Engineer #24 - GloVe, creating matrices $W_{v1}$ and $W_{v2}$ with dimensions $(m, n)$, like $W_Q$ and $W_K$. The resulting computation then looks as follows.

Q, K, V = E W_Q, E W_K, E W_{v1} W_{v2}^T \\ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{n}}\right)V \\ \tilde{E} = E + \text{Attention}(Q, K, V)

This may seem complex, but it essentially involves a series of matrix multiplications and normalization steps, all of which are easily differentiable and fully parallelizable. A drawback of this approach is that it can only process a fixed number of words $d$, unlike RNNs that handle varying input sequence lengths. However, this isn't a major issue due to workarounds like using an EOS (end-of-sentence) tag, padding, or simply scaling the model. To predict sequence continuation, we can shift the input while including previous outputs.

import tensorflow as tf
from tensorflow.keras import layers

class SelfAttention(layers.Layer):
  def __init__(self, embedding_dim, hidden_dim):
    super(SelfAttention, self).__init__()
    self.wq = layers.Dense(hidden_dim)      # W_Q
    self.wk = layers.Dense(hidden_dim)      # W_K
    self.wv1 = layers.Dense(hidden_dim)     # W_v1
    self.wv2 = layers.Dense(embedding_dim)  # W_v2

  def call(self, x):
    Q = self.wq(x)
    K = self.wk(x)
    V = self.wv2(self.wv1(x))

    Aw = tf.matmul(Q, K, transpose_b=True)    # QK^T
    d = tf.cast(tf.shape(K)[-1], tf.float32)  # key dimension n
    Aw /= tf.math.sqrt(d)                     # scale by sqrt(n)
    Aw = tf.nn.softmax(Aw, axis=-1)
    A = tf.matmul(Aw, V)                      # apply attention to the values
    return A

The code above is a TensorFlow implementation of attention (self-attention). Here, the query, key, and value matrices are generated by dense layers and manipulated to compute attention. You can see that attention itself is quite simple by examining the code. As a challenge, I recommend trying to implement the above in PyTorch.
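
For reference, here is a minimal usage sketch of the layer above on a dummy batch of embeddings. The batch size, sequence length, and dimensions are illustrative assumptions, and the residual connection follows the equation for $\tilde{E}$.

# A batch of 2 sentences, each with 10 tokens embedded in 64 dimensions (illustrative shapes).
x = tf.random.normal((2, 10, 64))
attention = SelfAttention(embedding_dim=64, hidden_dim=16)
updated = x + attention(x)   # residual connection: E + Attention(Q, K, V)
print(updated.shape)         # (2, 10, 64)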

Types of Attention

Since the query, key, and value matrices are shared among words, using only one attention module, or single-headed attention, may not be sufficient for updating embeddings. Hence, by default, multiple attention modules, or multi-headed attention, are set up to run in parallel and to produce updates that reflect different aspects of the sequence.
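
To make this concrete, here is a minimal sketch of multi-headed attention that reuses the SelfAttention layer defined earlier, runs several heads in parallel, and mixes their outputs with a final dense projection. The class name, the num_heads parameter, and the concatenate-then-project scheme are illustrative choices rather than the exact formulation in the original paper.

class MultiHeadSelfAttention(layers.Layer):
  def __init__(self, embedding_dim, hidden_dim, num_heads):
    super(MultiHeadSelfAttention, self).__init__()
    # One independent attention head per aspect of the sequence we want to capture.
    self.heads = [SelfAttention(embedding_dim, hidden_dim) for _ in range(num_heads)]
    # Projection that mixes the concatenated head outputs back to the embedding dimension.
    self.wo = layers.Dense(embedding_dim)

  def call(self, x):
    head_outputs = [head(x) for head in self.heads]  # each of shape (batch, seq, embedding_dim)
    concat = tf.concat(head_outputs, axis=-1)        # (batch, seq, num_heads * embedding_dim)
    return self.wo(concat)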

In the example used to understand attention, we applied self-attention, where attention is computed within the same sequence. Another type is cross-attention, where keys and values come from a different sequence than the one for which we are updating embeddings. Self-attention is useful for extracting sentence meaning and predicting the next word, while cross-attention helps translate or generate sequences in different languages or modalities, like audio for transcription.
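
As a sketch of the difference, a cross-attention variant of the earlier SelfAttention layer might take two inputs, with queries computed from the sequence being updated and keys and values computed from the other sequence. The class name and the two-argument call signature below are assumptions for illustration.

class CrossAttention(layers.Layer):
  def __init__(self, embedding_dim, hidden_dim):
    super(CrossAttention, self).__init__()
    self.wq = layers.Dense(hidden_dim)
    self.wk = layers.Dense(hidden_dim)
    self.wv1 = layers.Dense(hidden_dim)
    self.wv2 = layers.Dense(embedding_dim)

  def call(self, x, context):
    Q = self.wq(x)                   # queries from the sequence we are updating
    K = self.wk(context)             # keys from the other sequence
    V = self.wv2(self.wv1(context))  # values from the other sequence

    Aw = tf.matmul(Q, K, transpose_b=True)
    d = tf.cast(tf.shape(K)[-1], tf.float32)
    Aw /= tf.math.sqrt(d)
    Aw = tf.nn.softmax(Aw, axis=-1)
    return tf.matmul(Aw, V)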

A transformer is a model architecture that leverages tokenizers, embedding layers, positional encoding, and transformer layers consisting of attention and feedforward layers, which are all parallelizable. Although component implementations may vary depending on task nature, transformers frequently use multi-headed self-attention modules.
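
As a rough sketch of how these pieces fit together, a single transformer layer can be expressed as attention followed by a position-wise feedforward network, each wrapped in a residual connection. This sketch builds on the SelfAttention layer above and omits details such as layer normalization and multi-headed attention that common implementations include; ff_dim is an illustrative hyperparameter.

class TransformerBlock(layers.Layer):
  def __init__(self, embedding_dim, hidden_dim, ff_dim):
    super(TransformerBlock, self).__init__()
    self.attention = SelfAttention(embedding_dim, hidden_dim)
    # Position-wise feedforward network applied after attention.
    self.ffn = tf.keras.Sequential([
        layers.Dense(ff_dim, activation="relu"),
        layers.Dense(embedding_dim),
    ])

  def call(self, x):
    x = x + self.attention(x)  # residual connection around attention
    x = x + self.ffn(x)        # residual connection around the feedforward layers
    return x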

Why Transformer?

After reading the above, some of you might be wondering why we needed to develop a new model with a fixed input length when we already had simpler models like FNN and CNN that could handle fixed input lengths similarly with some workarounds. Could we not simply introduce positional encoding and use the conventional models? The straightforward reason is that, empirically, they did not perform as well compared to other models, including RNNs. This raises the question: why didn’t the conventional models work as effectively, while Transformers did?

The answer may lie in the inductive bias of the models. The inductive bias of an algorithm refers to the set of assumptions injected into a model to help it learn to make reasonable predictions on new data. Models like CNNs and RNNs incorporate various assumptions such as translational invariance (i.e., absolute positions do not matter), local receptive fields (meaning patterns can be captured by combining local features), feature hierarchy, sequential dependency, and temporal ordering. These biases make sense for certain data types, like images and text, and are the reason these models can quickly learn to recognize complex patterns with fewer weights.

While helpful for learning on small datasets of moderate complexity, these inductive biases can hinder a model’s ability to learn complex relationships that don’t conform to them. As researchers gained more access to data and computational resources, they aimed to model complex data like natural languages more accurately. In such cases, these inductive biases could be counterproductive. Also, a simple FNN cannot attend to different activations depending on the input, unlike a Transformer. While it can approximate similar computations with enough neurons, it cannot efficiently learn complex relationships within a reasonable model size and dataset.

Thus, it can be argued that the Transformer has an optimal level of inductive bias and flexibility, well-suited to handle the complex and large datasets commonly encountered today. Given techniques like transfer learning and fine-tuning, which leverage pre-trained models for specific use cases, and the rapid advancement of tools surrounding Transformers, it’s understandable why they have captured such significant attention from researchers and engineers in recent years.

Conclusion

In this article, we covered positional encoding, the conceptual and mathematical descriptions of the attention mechanism, types of attention, and the potential reasons why Transformers excel at various tasks, including NLP. Transformers are highly parallelizable, unlike RNNs, while effectively capturing contextual information in word embeddings, which outweighs the drawback of a fixed input length (at least for now). In the next article, we’ll explore specific Transformer architectures in greater detail.

Resources