Self Attention

The three big papers that led to the original form of self-attention were:

Standard self-attention, as described in Attention Is All You Need (Vaswani et al., 2017), is expressed as:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
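
The 1/\sqrt{d_k} scaling keeps the softmax inputs well behaved as the key dimension grows: assuming the components of q and k are independent with mean 0 and variance 1 (the argument given in the paper), the raw dot product has variance d_k, so dividing by \sqrt{d_k} brings it back to unit variance and avoids pushing the softmax into a saturated, small-gradient regime:

\mathrm{Var}(q \cdot k) = \mathrm{Var}\!\left(\sum_{i=1}^{d_k} q_i k_i\right) = d_k \quad\Rightarrow\quad \mathrm{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1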

In code, this looks like:

import torch
import torch.nn.functional as F

def self_attention(X, W_q, W_k, W_v):
    """
    X: (seq_len, d_model) - input embeddings
    W_q, W_k: (d_k, d_model) - query/key projection weights
    W_v: (d_v, d_model) - value projection weights

    Returns:
        output: (seq_len, d_v) - attention-weighted values
        attn_weights: (seq_len, seq_len) - attention distribution per position
    """
    # Project to Q, K, V
    Q = X @ W_q.T  # (seq_len, d_k)
    K = X @ W_k.T  # (seq_len, d_k)
    V = X @ W_v.T  # (seq_len, d_v)

    # Scaled dot-product attention
    d_k = K.shape[1]
    scores = Q @ K.T / (d_k**0.5)  # (seq_len, seq_len)
    attn_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = attn_weights @ V  # (seq_len, d_v)
    return output, attn_weights
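
As a quick sanity check, here is a minimal usage sketch, assuming PyTorch and randomly initialised projection weights (the dimensions below are illustrative, not from the original):

import torch

seq_len, d_model, d_k, d_v = 5, 16, 8, 8

X = torch.randn(seq_len, d_model)    # input embeddings
W_q = torch.randn(d_k, d_model)      # query projection weights
W_k = torch.randn(d_k, d_model)      # key projection weights
W_v = torch.randn(d_v, d_model)      # value projection weights

output, attn_weights = self_attention(X, W_q, W_k, W_v)
print(output.shape)              # torch.Size([5, 8])  -> (seq_len, d_v)
print(attn_weights.shape)        # torch.Size([5, 5])  -> (seq_len, seq_len)
print(attn_weights.sum(dim=-1))  # each row sums to 1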

And visually it looks like this (from here):

[Figure: Self-attention visualisation]