Attention in the age of ADHD

In short, I explain what attention and self-attention are.

Simon Li

When I first got into machine learning, I always heard Transformer this and attention that. Unfortunately and frustratingly, I had absolutely ZERO idea what attention was, even after some brief searching on the internet. Therefore, I will try to explain the concept of attention in a way that is hopefully straightforward even for people who are new to machine learning or computer science.

Why attention?

Before the arrival of Transformers, the Recurrent Neural Network (RNN) was one of the most popular architectures for sequential data. RNNs are known for their ability to capture local temporal dependencies. However, they struggle with long-range temporal dependencies, meaning they favor recent information over information from much earlier in the sequence. To address this weakness in leveraging information from the hidden states of an RNN, the attention mechanism was proposed.

What’s Attention?

In short, attention’s output sequence is a weighted average of the input sequence. More specifically, attention is a function that transforms an input sequence into an output sequence, which does not necessarily have the same length, using a learned, input-dependent weighted average.

$$A: \text{tokens} \to_{\text{weighted avg.}} \text{tokens}$$

Math behind attention

In this section, we formalize how the weighted average over the input sequence is taken to produce the output sequence.

Suppose we have $T_{in}$ input tokens and $T_{out}$ output tokens. Then, we have an input sequence $V$ and an output sequence $Z$:

$$V \in \R^{T_{in}\times D}, \quad Z \in \R^{T_{out} \times D}$$

Let $p_{i,j}$ be the weight of input token $j$ in output token $i$. Then, the output tokens are

$$z_i = \sum^{T_{in}}_{j=1} p_{i,j}v_j \implies Z = PV$$

Note that we require the weighting coefficients to satisfy $P \in [0,1]^{T_{out} \times T_{in}}$ and $\sum^{T_{in}}_{j=1} p_{i,j} = 1$ for every row $i$.
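As a quick sanity check, here is a minimal NumPy sketch of $Z = PV$; the toy shapes and the made-up row-stochastic $P$ are my own choices, just for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
T_in, T_out, D = 5, 3, 4

V = rng.standard_normal((T_in, D))      # input tokens, one row per token

# Toy weighting coefficients: non-negative, each row sums to 1.
P = rng.random((T_out, T_in))
P = P / P.sum(axis=1, keepdims=True)

Z = P @ V                               # each z_i is a weighted average of the v_j
assert Z.shape == (T_out, D)
assert np.allclose(P.sum(axis=1), 1.0)
```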

More math behind attention

We now know how the weighted-average output is calculated. But where do the weighting coefficients $P$ come from?

Suppose that we are given query tokens $Q \in \R^{T_{out}\times D_K}$ and key tokens $K \in \R^{T_{in}\times D_K}$. Then, we can determine the weight coefficient $p_{i,j}$ by calculating how similar $q_i$ and $k_j$ are. We would normally use cosine similarity to measure this; however, using just the numerator of cosine similarity, which is a dot product, not only works well but also saves a hefty amount of computation.

After we obtain the raw similarity from the dot product, we scale the result by dividing it by $\sqrt{D_K}$, where $D_K$ is the dimension of the keys and queries. This step is necessary because, at random initialization, the raw dot products grow with $D_K$, which would give a sharp (peaked) distribution of weighting coefficients $P$ that the model would need much more time to flatten out. The scaling factor keeps the distribution closer to uniform at the start, which helps the model converge faster.
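A rough way to see this effect is to check how the raw dot products scale with $D_K$. The sketch below (toy sizes and random unit-variance vectors, all my own choices) shows that their standard deviation grows roughly like $\sqrt{D_K}$, while the scaled version stays near 1:

```python
import numpy as np

rng = np.random.default_rng(0)
for D_K in (16, 256, 4096):
    q = rng.standard_normal((10_000, D_K))   # many queries, entries ~ N(0, 1)
    k = rng.standard_normal((10_000, D_K))   # many keys,    entries ~ N(0, 1)
    raw = (q * k).sum(axis=1)                # raw similarities q_i . k_i
    print(D_K, raw.std(), (raw / np.sqrt(D_K)).std())
    # std of raw grows like sqrt(D_K); the scaled version stays near 1
```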

Lastly, we normalize the scaled values with softmax to obtain a probability distribution.

$$P = \text{softmax} \left(\frac{QK^T}{\sqrt{D_K}}\right)$$
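Here is a minimal NumPy sketch of this formula; the `softmax` and `attention_weights` helpers and the toy shapes are assumptions of mine, for illustration only:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(Q, K):
    """P = softmax(Q K^T / sqrt(D_K)), with shape (T_out, T_in)."""
    D_K = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(D_K)
    return softmax(scores, axis=-1)

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 8))               # T_out = 3, D_K = 8
K = rng.standard_normal((5, 8))               # T_in  = 5
P = attention_weights(Q, K)
assert P.shape == (3, 5)
assert np.allclose(P.sum(axis=1), 1.0)        # each row is a probability distribution
```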

Self-Attention

Self-attention is a special case of attention where $T \coloneqq T_{in} = T_{out}$ and all of $V, K, Q$ are derived from the same input token sequence $X \in \R^{T\times D}$. This means that to calculate $V, K, Q$, we have learnable parameters $W_V, W_K, W_Q$ such that

$$\begin{aligned} V &= XW_V \in \R^{T\times D}, \quad W_V \in \R^{D \times D} \\ K &= XW_K \in \R^{T\times D_K}, \quad W_K \in \R^{D \times D_K} \\ Q &= XW_Q \in \R^{T\times D_K}, \quad W_Q \in \R^{D \times D_K} \end{aligned}$$

where $D_K$ is the dimension of the key and query tokens.

Therefore, the output of self-attention is

$$Z = \text{softmax} \left( \frac{X W_Q W_K^T X^T}{\sqrt{D_K}} \right) XW_V$$
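Putting the pieces together, here is a single-head self-attention sketch in NumPy. The weight scales and toy shapes are arbitrary choices of mine, and the `softmax` helper is repeated so the block runs on its own:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Z = softmax(X W_Q (X W_K)^T / sqrt(D_K)) X W_V."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    D_K = Q.shape[-1]
    P = softmax(Q @ K.T / np.sqrt(D_K), axis=-1)
    return P @ V

rng = np.random.default_rng(0)
T, D, D_K = 6, 16, 8
X = rng.standard_normal((T, D))
W_Q = 0.1 * rng.standard_normal((D, D_K))
W_K = 0.1 * rng.standard_normal((D, D_K))
W_V = 0.1 * rng.standard_normal((D, D))
Z = self_attention(X, W_Q, W_K, W_V)
assert Z.shape == (T, D)
```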

Multi-Head Self-Attention

Just like a convolution layer where we can run multiple convolutions, we can run multiple attention heads per layer! Suppose the output of each head $h_i$ is the $Z_i$ described above. Then, the final output is obtained by concatenating all the individual heads’ outputs and applying a linear transformation, $Z = [Z_1, \dots, Z_H]W_O$, where $W_O \in \R^{HD_V\times D}$ is a learnable parameter and $D_V$ is the value dimension of each head.
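Below is a minimal sketch of multi-head self-attention under the same toy assumptions as before (per-head value dimension `D_V`, random weights, helper functions repeated so the block is self-contained):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    P = softmax(Q @ K.T / np.sqrt(Q.shape[-1]), axis=-1)
    return P @ V

def multi_head_self_attention(X, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) per head; concatenate head outputs, then project with W_O."""
    Z_heads = [self_attention(X, W_Q, W_K, W_V) for W_Q, W_K, W_V in heads]
    return np.concatenate(Z_heads, axis=-1) @ W_O   # [Z_1, ..., Z_H] W_O

rng = np.random.default_rng(0)
T, D, D_K, D_V, H = 6, 16, 8, 8, 4
X = rng.standard_normal((T, D))
heads = [tuple(0.1 * rng.standard_normal((D, d)) for d in (D_K, D_K, D_V)) for _ in range(H)]
W_O = 0.1 * rng.standard_normal((H * D_V, D))
Z = multi_head_self_attention(X, heads, W_O)
assert Z.shape == (T, D)
```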