LLM - Part 1 : Basic Components

Introduction to Large Language Models

This blog post briefly summarizes the key components of LLMs.

In 2017, the paper Attention Is All You Need by Google and the University of Toronto brought a huge innovation to NLP and launched the era of transformers. This efficient, scalable, and parallelizable approach opened the Pandora's box of generative AI and ushered AI research into a whole new era.

What is Attention?

Traditional encoder-decoder models are widely used in NLP; however, their performance drops significantly on long sentences. One of the main reasons is that the encoder-decoder model compresses the entire input sequence into a single fixed-length vector, from which every output time step must be decoded. Attention is a technique that resolves this issue by focusing on the relevant parts of the input sequence at each decoding step.

The mechanism of attention includes the following steps:

1. Compute an alignment score between the current decoder hidden state and each encoder hidden state.
2. Apply softmax to the scores to obtain attention weights that sum to 1.
3. Compute a context vector as the weighted sum of the encoder hidden states.
4. Combine the context vector with the decoder hidden state to produce the output for this time step.

Note that the above steps are repeated for every decoder time step. The figure below illustrates the steps above, and a minimal sketch of one iteration follows it.

steps for attention on one decoder hidden state (one iteration)
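To make this concrete, here is a minimal NumPy sketch of one attention iteration using dot-product alignment scores; the variable names and sizes are illustrative, not taken from the figure.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Illustrative shapes: 5 encoder hidden states, hidden size 4.
h_enc = np.random.randn(5, 4)   # encoder hidden states, one per input token
s_dec = np.random.randn(4)      # current decoder hidden state

scores = h_enc @ s_dec          # step 1: alignment score for each encoder state
weights = softmax(scores)       # step 2: normalize scores into attention weights
context = weights @ h_enc       # step 3: weighted sum = context vector
# Step 4: the context vector is combined with s_dec to produce this step's output.
```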

Introduction to Transformer

The Transformer is a model that uses attention to boost training speed and improve generative capability. The power of the transformer lies in its ability to learn the relevance and context of all the words in a sentence. For example, the figure below illustrates the relationships among words in a sentence; this is called self-attention.

self-attention example in a sentence

Self Attention in Detail

Self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output.

As shown in the figure above, the word book is strongly connected with the words teacher and student. The ability to learn attention in this way across the whole input significantly improves the model's ability to encode language.

Now let's dive into the calculation of self-attention:

The first step is to create three vectors from each of the encoder's input vectors: a query, a key, and a value. They are obtained by multiplying the embedding by three matrices that are learned during training.

Get Queries, Keys, and Values

The second step is to score each word of the input sentence against every other input word. This score determines, when encoding a word at a certain position, how much focus to place on other parts of the input sentence. The score is calculated as the dot product of the query vector and the key vector of the respective word. For example, if we process the self-attention for the word Machine, the first score would be the dot product of $q_1$ and $k_1$, and the second score would be the dot product of $q_1$ and $k_2$.

The third step is to normalize the scores for more stable gradients. All scores are divided by $\sqrt{d_k}$, where $d_k$ is the dimension of the key vectors (8 in the original paper, where $d_k = 64$). Then softmax is applied so that the scores are all between 0 and 1 and sum up to 1.

The fourth step is to scale each value vector by multiplying it by its corresponding softmax score, and then sum up the weighted value vectors. For example, summing up the scaled $v_1$ and the scaled $v_2$ (scaled by the softmax scores computed for the word Machine in the figure), we get $z_1$, which is the self-attention output for the first word Machine.
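To make the four steps concrete, here is a minimal NumPy sketch of scaled dot-product self-attention; the matrix sizes and the random weights stand in for trained ones and are purely illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

n, d_model, d_k = 2, 8, 4              # e.g. the two words "Machine" and "Learning"
X = np.random.randn(n, d_model)        # input embeddings

# Random stand-ins for the three trained projection matrices.
W_Q = np.random.randn(d_model, d_k)
W_K = np.random.randn(d_model, d_k)
W_V = np.random.randn(d_model, d_k)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V    # step 1: queries, keys, values

scores = Q @ K.T                       # step 2: dot product q_i . k_j for every word pair
scores = scores / np.sqrt(d_k)         # step 3a: divide by sqrt(d_k)
weights = softmax(scores, axis=-1)     # step 3b: softmax, each row sums to 1

Z = weights @ V                        # step 4: weighted sum of the value vectors
# Z[0] is z_1, the self-attention output for the first word ("Machine").
```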

Multi-Headed Attention

Multi-head attention is a simple extension of the single-head attention above that gives us more representational power. The motivation of multi-head attention is to capture different aspects of a word, e.g., from a syntax perspective or a semantics perspective, so we need something more than a single set of attention weights.

Instead of producing one set of $W^Q, W^K, W^V$ matrices, multi-head attention produces $k$ sets of $W^{Q_i}, W^{K_i}, W^{V_i}$ matrices. The computation is the same for each set: \(Z_i = \mathrm{softmax}\left(\frac{Q_i K_i^{T}}{\sqrt{d_k}}\right)V_i, \quad Q_i = XW^{Q_i},\; K_i = XW^{K_i},\; V_i = XW^{V_i}, \qquad i=1,\ldots,k\) where $X$ is the matrix of input embeddings.

Once we get the $k$ outputs, we concatenate and project them by \(Z = [Z_1, Z_2, \ldots, Z_k]\,W\), where $W$ is a weight matrix trained jointly with the model.
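Putting the formula together, here is a rough NumPy sketch of multi-head attention with $k$ heads; again, the dimensions and random matrices are only placeholders for trained weights.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k), axis=-1) @ V

n, d_model, k = 4, 16, 2                # sequence length, model dimension, number of heads
d_k = d_model // k                      # per-head dimension
X = np.random.randn(n, d_model)

# k independent sets of projection matrices (random stand-ins for trained weights).
heads = [attention(X,
                   np.random.randn(d_model, d_k),
                   np.random.randn(d_model, d_k),
                   np.random.randn(d_model, d_k))
         for _ in range(k)]

Z_cat = np.concatenate(heads, axis=-1)  # concatenate Z_1, ..., Z_k
W = np.random.randn(k * d_k, d_model)   # output projection, trained with the model
Z = Z_cat @ W                           # final multi-head attention output, shape (n, d_model)
```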

Now let's investigate the time complexity of calculating multi-head attention: for a sequence of length $n$ and per-head dimension $d$, computing $Q_iK_i^{T}$ for one head takes $O(n^2 d)$, the softmax takes $O(n^2)$, and multiplying the attention weights by $V_i$ takes another $O(n^2 d)$; repeating this for $k$ heads gives $O(kdn^2)$.

So the overall time complexity is $O(kdn^2)$, and since $d$ and $k$ are constants, the complexity is $O(n^2)$ in the sequence length, which is expensive. How can we make the calculation more efficient? The high-level idea is to sparsify or approximate the attention matrices; methods include locality-sensitive hashing and low-rank decomposition.

Positional Encoding

As you may notice, self-attention is not sensitive to word ordering, but in NLP problems order matters a lot, so we need to represent the position of each token in the input sequence. To address this, the transformer adds a positional encoding to the input word embedding; the positional encoding has the same dimension as the word embedding, so the two are simply summed. The intuition behind positional encoding is binary counting, where the frequency of bit flips increases from left to right; the transformer's continuous analogue uses $\sin$ and $\cos$ functions whose frequencies vary across the embedding dimensions.
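A small NumPy sketch of the sinusoidal positional encoding described above; the sequence length and embedding size are arbitrary.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding as in 'Attention Is All You Need'."""
    pos = np.arange(max_len)[:, None]               # positions 0 .. max_len-1
    i = np.arange(d_model // 2)[None, :]            # index of each sin/cos pair
    angle = pos / np.power(10000, 2 * i / d_model)  # a different frequency per pair
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle)                     # even dimensions use sin
    pe[:, 1::2] = np.cos(angle)                     # odd dimensions use cos
    return pe

# Same dimension as the word embedding, so the two can simply be added.
word_emb = np.random.randn(10, 64)                  # 10 tokens, embedding size 64
x = word_emb + positional_encoding(10, 64)          # input to the first encoder layer
```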

Encoder Structure

The encoding component is a stack of encoders, and the decoding component is a stack of the same number of decoders.

Each sub-layer in an encoder (the self-attention layer and the feed-forward network) has a residual connection around it and is followed by a layer-normalization step.
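The wiring of one encoder layer can be sketched as follows; self_attention and feed_forward below are trivial stand-ins just to show where the residual connections and layer normalization sit.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def encoder_layer(x, self_attention, feed_forward):
    # Sub-layer 1: self-attention, wrapped in a residual connection + layer norm.
    x = layer_norm(x + self_attention(x))
    # Sub-layer 2: position-wise feed-forward network, wrapped the same way.
    x = layer_norm(x + feed_forward(x))
    return x

# Toy stand-ins so the wiring is runnable end to end.
x = np.random.randn(10, 64)
out = encoder_layer(x, self_attention=lambda h: h, feed_forward=lambda h: np.maximum(h, 0))
```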


Decoder Structure

The output of the final encoder layer is fed into every decoder layer as part of its input (it provides the keys and values for the encoder-decoder attention), and each decoder layer passes its output on to the next decoder layer.


Prompt Engineering and Generative AI project lifecycle

Providing examples inside the context window is called in-context learning. Passing only your input data within the prompt, with no examples, is called zero-shot inference; including a single example is known as one-shot inference; and extending that idea to multiple examples is known as few-shot inference. The largest models are good at zero-shot inference and can complete tasks they were not explicitly trained to perform, while smaller models are generally only good at tasks similar to the ones they were trained on.
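As a simple illustration of these prompt styles (the sentiment-classification task and the wording are made up for this example):

```python
# Zero-shot: only the instruction and the input data are in the prompt.
zero_shot = """Classify this review: 'I loved this movie!'
Sentiment:"""

# One-shot: a single worked example precedes the real input.
one_shot = """Classify this review: 'What a waste of time.'
Sentiment: Negative

Classify this review: 'I loved this movie!'
Sentiment:"""

# Few-shot: several examples, ideally covering the different labels.
few_shot = """Classify this review: 'What a waste of time.'
Sentiment: Negative

Classify this review: 'Best film of the year.'
Sentiment: Positive

Classify this review: 'I loved this movie!'
Sentiment:"""
```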

The Generative AI project lifecycle is as follows:

1. Scope: define the use case as narrowly and precisely as possible.
2. Select: choose an existing foundation model or pretrain your own.
3. Adapt and align: improve the model through prompt engineering, fine-tuning, and alignment with human feedback, and evaluate it, iterating as needed.
4. Application integration: optimize the model for inference and deployment, and build the LLM-powered application around it.

The three main classes of PEFT methods are as follows:

1. Selective methods: fine-tune only a subset of the original model parameters.
2. Reparameterization methods: reparameterize model weights with a low-rank representation (e.g., LoRA).
3. Additive methods: keep the original weights frozen and train small added components, such as adapter layers or soft prompts.

LoRA : Low Rank Adaption of LLMs

LoRA is a widely used reparameterization method that reparameterizes model weights using a low-rank representation.

The steps of LoRA are as follows:

1. Freeze most of the original model weights.
2. Inject two low-rank decomposition matrices $A$ and $B$ alongside the original weights.
3. Train only the smaller matrices $A$ and $B$.

Steps to update the model for inference:

1. Multiply the two low-rank matrices to get $B \times A$, which has the same dimensions as the frozen weights.
2. Add $B \times A$ to the original frozen weights, so the merged model has the same number of parameters as the original and incurs no extra inference latency.

Here's a concrete example of how LoRA reduces the number of trainable parameters: if a transformer weight matrix has dimensions $d\times k = 512\times 64$, the number of trainable parameters would be 32,768. With LoRA at rank $r=8$, $A$ has dimensions $r\times k = 8\times 64$, i.e., 512 parameters, and $B$ has dimensions $d\times r = 512\times 8$, i.e., 4,096 parameters, which is roughly an 86% reduction in the number of parameters to train.
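The parameter counts above, and the merge step used for inference, can be checked with a few lines of NumPy (the random matrices are placeholders for real trained weights):

```python
import numpy as np

d, k, r = 512, 64, 8

W = np.random.randn(d, k)          # frozen pretrained weight: 512*64 = 32,768 params
A = np.random.randn(r, k) * 0.01   # trainable low-rank matrix:   8*64 =    512 params
B = np.zeros((d, r))               # trainable low-rank matrix: 512*8  =  4,096 params

print(W.size, A.size + B.size)     # 32768 vs 4608, roughly an 86% reduction

# Inference: merge the low-rank update into the frozen weight once, so the
# merged model has the same shape as the original and no extra latency.
W_merged = W + B @ A               # shape (512, 64)
x = np.random.randn(k)
y = W_merged @ x
```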

Soft Prompts

Although it sounds similar, prompt tuning is different from prompt engineering. The goal of prompt engineering is to help the model understand the nature of the task you're asking it to carry out so that it produces a better completion; however, it requires a lot of manual effort and has significant limitations. With prompt tuning, you instead add additional trainable tokens to your prompt and leave it up to the supervised learning process to determine their optimal values. This set of trainable tokens is called a soft prompt.

Soft prompt vectors are of the same length as the embedding token vectors, and including somewhere between 20 and 100 virtual tokens can be sufficient for good performance.
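Mechanically, prompt tuning can be sketched as prepending a small trainable matrix to the (frozen) token embeddings; the sizes below are illustrative.

```python
import numpy as np

d_model = 64                    # embedding dimension of the underlying LLM
n_virtual = 20                  # 20-100 virtual tokens is usually sufficient

# Trainable soft prompt: n_virtual vectors with the same dimension as token embeddings.
soft_prompt = np.random.randn(n_virtual, d_model) * 0.5

def build_inputs(token_embeddings, soft_prompt):
    # Prepend the soft prompt to the embeddings of the real input tokens;
    # only soft_prompt is updated during training, the LLM stays frozen.
    return np.concatenate([soft_prompt, token_embeddings], axis=0)

tokens = np.random.randn(12, d_model)        # embeddings of a 12-token input
inputs = build_inputs(tokens, soft_prompt)   # shape (32, 64), fed to the frozen LLM

# Changing task at inference time just means swapping in a different trained soft_prompt.
```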

The figure on the left shows, intuitively, the tokens that represent natural language, each of which corresponds to a fixed location in the embedding vector space. The figure on the right shows that soft prompts can be seen as virtual tokens that can take any value within the continuous multi-dimensional embedding space; the words closest to the soft prompt tokens have similar meanings, i.e., they form tight semantic clusters.

soft prompt intuition

Compared to full fine-tuning, where millions to billions of parameters are updated, only about 10k to 100k parameters need to be updated in prompt tuning. Intuitively, prompt tuning for multiple tasks is like assembling blocks: you only need to swap in the soft prompt trained for a different task at inference time to change tasks!

soft prompt parameter trend

As Lester et al. show in The Power of Scale for Parameter-Efficient Prompt Tuning (the figure below), once models reach around 10 billion parameters, prompt tuning can be as effective as full fine-tuning, and it offers a significant boost in performance over prompt engineering alone.

prompt tuning performance versus model size (Lester et al.)

Reinforcement Learning from Human Feedback

Motivated by some existing bad behaviours of LLMs, such as toxic language, aggressive responses, and sharing dangerous information, fine-tuning a model based on human feedback is needed. RLHF is a popular technique for fine-tuning LLMs in this way: it helps maximize helpfulness, minimize harm, and avoid dangerous topics.

I'll omit the basic concepts of RL here.

The steps in fine-tuning an LLM with RLHF are as follows:

1. Start with an instruct LLM and a dataset of prompts, and generate one or more completions for each prompt.
2. Collect human feedback by having labelers rank the completions according to alignment criteria such as helpfulness or toxicity.
3. Train a reward model on the ranked completions so it can score new completions automatically.
4. Use a reinforcement learning algorithm (typically PPO) to update the LLM's weights so that its completions maximize the reward.

Here is the process flowchart :

RLHF fine-tuning process flowchart

Reward hack in RLHF

In LLMs, reward hacking can manifest as the addition of words or phrases to completions that score highly on the metric being aligned for but reduce the overall quality of the language. For example, suppose you have trained a reward model that performs sentiment analysis and classifies model completions as toxic or non-toxic; as you iterate, RLHF will update the LLM to create less toxic responses. However, as the policy tries to optimize the reward, the model might start generating completions that are exaggeratedly positive but low quality, simply because they score well.

In order to prevent reward hacking, we can use the initial instruct LLM as a performance reference. The weights of the reference model are frozen and are not updated during training. As shown in the flowchart below, you compare the completions of the two models and calculate the KL divergence between them, which is added to the reward calculation as a penalty. This penalizes the RL-updated model if it shifts too far from the reference LLM.

reward calculation with a KL-divergence penalty against the frozen reference model
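A toy sketch of how the KL penalty can enter the reward; the penalty weight beta and the two next-token distributions are made-up values for illustration.

```python
import numpy as np

def kl_divergence(p, q, eps=1e-10):
    # KL(p || q) between two token-probability distributions.
    return np.sum(p * np.log((p + eps) / (q + eps)))

def penalized_reward(reward, p_updated, p_reference, beta=0.1):
    # Subtract a KL penalty so the RL-updated model is discouraged from
    # drifting too far from the frozen reference (initial instruct) LLM.
    return reward - beta * kl_divergence(p_updated, p_reference)

# Toy next-token distributions from the updated policy and the frozen reference model.
p_updated = np.array([0.7, 0.2, 0.1])
p_reference = np.array([0.5, 0.3, 0.2])
print(penalized_reward(reward=0.8, p_updated=p_updated, p_reference=p_reference))
```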

References

Vaswani et al., Attention Is All You Need, 2017.
Lester et al., The Power of Scale for Parameter-Efficient Prompt Tuning, 2021.