Attention#

Why attention?#

Attention means keeping tabs on the most important parts. It comes from a key observation: not all words are equal, and some words are more crucial to understanding a sentence than others. Take the sentence “It is raining outside”. You would probably still understand that it’s raining outside if I just said: “Rain! Out!”. In this case, it and is are largely redundant, and if a model is trying to understand the sentence, throwing out it and is is probably not going to make a difference.

What’s attention?#

So how can a model focus only on the most important parts? One way is to multiply the important parts by a large factor while scaling down the values of the unimportant parts (those parts are, after all, just numbers inside the machine). That is exactly what the attention mechanism does: it computes a set of weights and uses them to form a weighted average of the inputs.
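
To make that concrete before the fuller example below, here is a tiny sketch with hand-picked importance weights for the words of “It is raining outside”. The weights and the one-number word representations are made up purely for illustration; the next section computes the weights with a softmax instead of hand-picking them.

import numpy as np

# Hand-picked importance weights, made up for illustration only;
# a real attention mechanism computes these from the data itself.
words = ["It", "is", "raining", "outside"]
weights = np.array([0.02, 0.03, 0.60, 0.35])   # sums to 1.0

# Pretend each word is represented by a single number (a stand-in for a vector).
values = np.array([0.1, 0.2, 5.0, 3.0])

# The weighted combination is dominated by "raining" and "outside":
# weights @ values is about 4.06, while the plain mean of values is about 2.08.
print(weights @ values)
print(values.mean())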

Try attention in code#

%matplotlib inline

import numpy as np
from matplotlib import pyplot as plt


def softmax(x, t=1):
    # Temperature t controls how sharp the weights are:
    # a small t pushes nearly all of the weight onto the largest entry.
    exp = np.exp(x / t)

    # Sum over the last axis so the weights along that axis add up to 1.
    sum_exp = exp.sum(-1, keepdims=True)

    return exp / sum_exp
num = 5

# Random attention weights: with t=0.1 the softmax is very sharp,
# so almost all of the weight lands on a single element.
weights = softmax(np.random.randn(num), t=0.1)

# Random data values that the weights will be applied to.
data = np.random.randn(num)

print(weights)
print(data)
[3.02235853e-05 9.63224842e-01 3.70953352e-08 2.89464043e-02
 7.79849294e-03]
[ 1.01624972 -0.88726094 -2.05834714 -1.96504155 -2.27676688]
# Plain average: every element counts equally.
average = data.sum() / data.size

# Attention-weighted average: a dot product between the weights and the data.
attn_applied = weights @ data

print(average)
print(attn_applied)

# The element with the largest weight dominates the result.
print(weights.argmax())
print(data[weights.argmax()])
-1.2342333580912643
-0.9292373757333592
1
-0.8872609375458659

See how the attention weights pull the weighted average of data toward the element that received the largest weight (data[1] here), while the plain average treats every element equally.
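
A closing note on the temperature t used in softmax above: it controls how concentrated the weights are. The snippet below is a small sketch with a fixed score vector chosen for illustration; at t=1 the weights stay fairly spread out, while at t=0.1 they collapse onto a single element, which is why one weight in the run above was so close to 1.

import numpy as np

def softmax(x, t=1):
    exp = np.exp(x / t)
    return exp / exp.sum(-1, keepdims=True)

# Fixed scores chosen for illustration.
scores = np.array([2.0, 1.0, 0.1])

print(softmax(scores, t=1.0))   # fairly spread out: roughly [0.66, 0.24, 0.10]
print(softmax(scores, t=0.1))   # nearly one-hot: almost all weight on the first score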