How to Think About Self-Attention Intuitively

Neural networks are universal function approximators: given enough parameters and data, they can approximate essentially any function. In practice, we can structure our models to encode priors about the problem we are working on, which significantly reduces the complexity of what they have to learn. For instance, a convolutional neural network bakes in the assumption that local spatial patterns are important. Attention assumes that context matters, and that any token could be a relevant bit of context for any other. Relevance is determined by a token's content.

The formula for self-attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

A standard one-sentence description of what is going on here reads something like this:

Each token is projected into three separate vectors:

  • Query (what am I looking for?)
  • Key (what do I advertise to others?)
  • Value (what information do I contribute?)
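
To make the mechanics concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. The embeddings and projection matrices are random placeholders purely for illustration; in a real model the projections are learned parameters.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention. X has shape (seq_len, d_model)."""
    Q = X @ W_q                                    # queries: what each token is looking for
    K = X @ W_k                                    # keys: what each token advertises
    V = X @ W_v                                    # values: what each token contributes
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                # (seq_len, seq_len) relevance scores
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = w / w.sum(axis=-1, keepdims=True)    # row-wise softmax
    return weights @ V                             # each output is a weighted mix of value vectors

rng = np.random.default_rng(0)
d_model = d_k = 4
X = rng.normal(size=(10, d_model))                             # 10 toy token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out = self_attention(X, W_q, W_k, W_v)                         # shape (10, 4)
```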

I've always found that concept very hard to visualise. This is a short post stepping through an imaginary and overly simplified description of how attention could capture the relationship between a single token pair.

Imagine our model's embedding size $d$ is 4; that is to say, imagine our model represents each token as 4 numbers. We say such a vector has four dimensions, or equivalently, that it lives in a 4-dimensional vector space. Imagine each of these dimensions represents an easily described concept (in practice, learned dimensions aren't so neatly interpretable):

| Dimension | Meaning |
|-----------|---------|
| d1 | Animacy (high = animate) |
| d2 | Number (high = singular) |
| d3 | Syntactic role (high = subject) |
| d4 | Semantic content: "cat-ness" vs "mat-ness" |

Now imagine we feed in the sentence:

The cat sat on the mat because it was tired.

cat might project to:

Query: [0.1, 0.2, 0.3, 0.1] → I'm not really looking for anything right now
Key:   [0.9, 0.9, 0.8, 0.2] → I'm animate, singular, a subject
Value: [0.7, 0.6, 0.4, 0.9] → Here's my actual content: cat-ness

it might project to:

Query: [0.9, 0.8, 0.7, 0.1] → I need an animate, singular, subject-like thing
Key:   [0.3, 0.9, 0.5, 0.1] → I'm singular, but not much else to offer
Value: [0.1, 0.5, 0.3, 0.1] → I have little semantic content of my own

mat might project to:

Query: [0.1, 0.1, 0.2, 0.1] → Not searching for much
Key:   [0.1, 0.9, 0.2, 0.2] → I'm singular, but inanimate, not a subject
Value: [0.1, 0.6, 0.3, 0.9] → Here's my content: mat-ness
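
For anyone who wants to play with these numbers, here is the same toy setup as NumPy arrays, with the per-token vectors stacked into the Q, K and V matrices that the formula above operates on. The values are just the hand-picked ones from this post, not anything a trained model produced.

```python
import numpy as np

# Rows are tokens (cat, it, mat); columns are the four toy dimensions
# [animacy, number, syntactic role, cat-ness/mat-ness].
Q = np.array([[0.1, 0.2, 0.3, 0.1],   # cat: not really looking for anything
              [0.9, 0.8, 0.7, 0.1],   # it:  looking for an animate, singular subject
              [0.1, 0.1, 0.2, 0.1]])  # mat: not searching for much
K = np.array([[0.9, 0.9, 0.8, 0.2],   # cat: advertises animate, singular, subject
              [0.3, 0.9, 0.5, 0.1],   # it:  advertises little beyond being singular
              [0.1, 0.9, 0.2, 0.2]])  # mat: advertises singular but inanimate
V = np.array([[0.7, 0.6, 0.4, 0.9],   # cat: carries cat-ness
              [0.1, 0.5, 0.3, 0.1],   # it:  carries little content of its own
              [0.1, 0.6, 0.3, 0.9]])  # mat: carries mat-ness
```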

Let's see what happens when we calculate the attention scores for it:

$$
\begin{align*}
Q_{\text{it}} \cdot K_{\text{cat}} &= (0.9, 0.8, 0.7, 0.1) \cdot (0.9, 0.9, 0.8, 0.2) \\
&= 0.81 + 0.72 + 0.56 + 0.02 \\
&= 2.11 \\[1ex]
Q_{\text{it}} \cdot K_{\text{mat}} &= (0.9, 0.8, 0.7, 0.1) \cdot (0.1, 0.9, 0.2, 0.2) \\
&= 0.09 + 0.72 + 0.14 + 0.02 \\
&= 0.97
\end{align*}
$$

it attends much more strongly to cat than to mat, as we would expect.
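
Taking the hand calculation one step further in code, here is a sketch of the scaling, the softmax, and the mixing of value vectors, still restricted to cat and mat as above. The arrays are restated so the snippet runs on its own.

```python
import numpy as np

q_it = np.array([0.9, 0.8, 0.7, 0.1])
K_cm = np.array([[0.9, 0.9, 0.8, 0.2],    # k_cat
                 [0.1, 0.9, 0.2, 0.2]])   # k_mat
V_cm = np.array([[0.7, 0.6, 0.4, 0.9],    # v_cat
                 [0.1, 0.6, 0.3, 0.9]])   # v_mat

scores  = K_cm @ q_it                             # [2.11, 0.97], matching the arithmetic above
scaled  = scores / np.sqrt(4)                     # d_k = 4, so divide by sqrt(4) = 2
weights = np.exp(scaled) / np.exp(scaled).sum()   # softmax -> roughly [0.64, 0.36]
new_it  = weights @ V_cm                          # "it" now carries a mix weighted towards cat's value
```

After the softmax, it puts roughly twice as much weight on cat as on mat, so the mixed value it carries forward is weighted towards cat's, which is the kind of coreference behaviour we would hope a trained model learns.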

If a single vector had to serve the roles of K and V together, the model would face a trade-off: it could preserve either the syntactic, structural information carried by the Key vectors or the semantic content carried by the Value vectors, but not both.
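
We can see one horn of that dilemma with the same toy numbers. Suppose cat's single shared vector kept its current key-like pattern, so that the pronoun could still find it easily; the sketch below, reusing the approximate attention weights from earlier, shows that the content mixed into the pronoun would then carry almost no cat-ness. The other horn is symmetric: a value-like shared vector would preserve cat-ness but be a much weaker match for the pronoun's query.

```python
import numpy as np

k_cat = np.array([0.9, 0.9, 0.8, 0.2])    # advertises syntax; d4 (cat-ness) is low
v_cat = np.array([0.7, 0.6, 0.4, 0.9])    # carries the actual cat-ness
k_mat = np.array([0.1, 0.9, 0.2, 0.2])
v_mat = np.array([0.1, 0.6, 0.3, 0.9])
weights = np.array([0.64, 0.36])           # approximate attention weights for "it", from above

separate_kv = weights @ np.stack([v_cat, v_mat])  # d4 stays ~0.9: "it" inherits the cat-ness
shared_kv   = weights @ np.stack([k_cat, k_mat])  # d4 drops to ~0.2: the semantic content is lost
```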

Tags: AI