
Linear Attention

Original Linear Attention

The original attention mechanism is defined as:

Complexity: $O(N^2 d_k + N^2 d_v)$

$$O = \operatorname{softmax}\left(QK^\top\right)V$$

If we omit the softmax operator it becomes:

Complexity: $O(N d_k d_v)$

$$O = (QK^\top)V = Q(K^\top V)$$

which is linear in the sequence length $N$.
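As a concrete illustration, here is a minimal PyTorch sketch of the two groupings (unnormalized, no feature map; the shapes and names are our own choices):

```python
import torch

def softmax_attention(Q, K, V):
    # Quadratic: materializes the (N, N) score matrix.
    scores = Q @ K.transpose(-1, -2)            # (N, N): O(N^2 d_k)
    return torch.softmax(scores, dim=-1) @ V    # (N, d_v): O(N^2 d_v)

def linear_attention(Q, K, V):
    # Linear: associativity lets us form K^T V first, a small (d_k, d_v) matrix.
    KV = K.transpose(-1, -2) @ V                # (d_k, d_v): O(N d_k d_v)
    return Q @ KV                               # (N, d_v):   O(N d_k d_v)

N, d_k, d_v = 1024, 64, 64
Q, K, V = torch.randn(N, d_k), torch.randn(N, d_k), torch.randn(N, d_v)
print(linear_attention(Q, K, V).shape)          # torch.Size([1024, 64])
```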

In Autoregressive Models

In autoregressive models like GPT we need a causal mask:

Training: $O = \operatorname{softmax}\left(QK^\top \odot M\right)V$

Inference: $o_t = \sum_{j=1}^{t} \dfrac{\exp(q_t^\top k_j)}{\sum_{l=1}^{t} \exp(q_t^\top k_l)}\, v_j$

Removing softmax yields:

Training: $O = \left(QK^\top \odot M\right)V$

Inference: $o_t = \sum_{j=1}^{t} (q_t^\top k_j)\, v_j$

Unfortunately the Hadamard product ($\odot$) with the mask does not commute with matrix multiplication, so we can no longer regroup the product as $Q(K^\top V)$ and the training form remains quadratic.

RNN-like Sequential Form

However, Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention observes that the inference equation can be rewritten in an RNN-like form:

$$o_t = \sum_{j=1}^{t} (v_j k_j^\top)\, q_t = S_t q_t, \qquad S_t = S_{t-1} + v_t k_t^\top$$
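A minimal sketch of this recurrence (PyTorch, again without the softmax normalizer or a feature map; the names are our own):

```python
import torch

def linear_attention_recurrent(Q, K, V):
    # RNN-like form: S_t = S_{t-1} + v_t k_t^T,  o_t = S_t q_t.
    N, d_k = K.shape
    d_v = V.shape[-1]
    S = torch.zeros(d_v, d_k)                 # 2D hidden state
    outputs = []
    for t in range(N):
        S = S + torch.outer(V[t], K[t])       # accumulate v_t k_t^T
        outputs.append(S @ Q[t])              # o_t = S_t q_t
    return torch.stack(outputs)               # (N, d_v)
```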

This achieves linear time, but introduces two drawbacks:

  1. Memory during autograd:

    High memory footprint during autograd: every intermediate state $S_t$ must be stored, resulting in $O(L d^2)$ memory usage. The authors alleviate this by recomputing $S_t$ on the fly during back-propagation.

  2. Parallelism:

    Poor training parallelism: the update is element-wise instead of large matrix multiplications, which under-utilizes GPU tensor cores.

    A compromise is the chunkwise algorithm proposed in Transformer Quality in Linear Time, which recovers matrix-multiplication parallelism within each chunk while remaining linear overall; see the sketch after this list.

    Figure: from top to bottom, the parallel form, the recurrent form, and the chunkwise parallel form.
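A compact sketch of the chunkwise idea for the ungated case (PyTorch; the chunk size and names are our own, not the paper's exact formulation):

```python
import torch

def chunkwise_linear_attention(Q, K, V, chunk=64):
    # Inter-chunk contribution comes from the running state S (recurrent across
    # chunks); the intra-chunk contribution is a small masked matmul (parallel).
    N, d_k = K.shape
    d_v = V.shape[-1]
    S = torch.zeros(d_k, d_v)
    out = torch.empty(N, d_v)
    causal = torch.tril(torch.ones(chunk, chunk))
    for s in range(0, N, chunk):
        q, k, v = Q[s:s+chunk], K[s:s+chunk], V[s:s+chunk]
        m = causal[: len(q), : len(q)]
        inter = q @ S                          # attends to all previous chunks
        intra = ((q @ k.T) * m) @ v            # causal attention inside the chunk
        out[s:s+chunk] = inter + intra
        S = S + k.T @ v                        # fold this chunk into the state
    return out
```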

Gated Linear Attention

A learnable 2D forget gate Gt(0,1)dk×dv is added:

$$S_t = G_t \odot S_{t-1} + k_t v_t^\top$$

This formulation is very general and encompasses many recent RNNs with 2D hidden states.

Gated Linear Attention Transformers with Hardware-Efficient Training also proposes its own variant, GLA, in which the gate factorizes as $G_t = \alpha_t \mathbf{1}^\top$:

  • Recurrent form:
$$S_t = (\alpha_t \mathbf{1}^\top) \odot S_{t-1} + k_t v_t^\top = \operatorname{Diag}(\alpha_t)\, S_{t-1} + k_t v_t^\top$$
  • Parallel form:
$$S_t = \sum_{i=1}^{t} \left( \left( \prod_{j=i+1}^{t} \alpha_j \right) \odot k_i \right) v_i^\top$$
  • Chunkwise parallel form: analogous to the chunkwise algorithm above, with the per-step decays $\alpha_j$ folded into the inter-chunk and intra-chunk terms.
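For concreteness, a minimal sketch of the GLA recurrent form (PyTorch; the way alpha is produced below is only a placeholder, not the paper's parameterization):

```python
import torch

def gla_recurrent(Q, K, V, alpha):
    # Gated recurrence: S_t = Diag(alpha_t) S_{t-1} + k_t v_t^T,  o_t = S_t^T q_t.
    N, d_k = K.shape
    d_v = V.shape[-1]
    S = torch.zeros(d_k, d_v)
    out = torch.empty(N, d_v)
    for t in range(N):
        S = alpha[t].unsqueeze(-1) * S + torch.outer(K[t], V[t])  # row-wise decay
        out[t] = S.T @ Q[t]                                       # o_t = q_t^T S_t
    return out

N, d_k, d_v = 128, 32, 32
Q, K, V = torch.randn(N, d_k), torch.randn(N, d_k), torch.randn(N, d_v)
alpha = torch.sigmoid(torch.randn(N, d_k))    # per-dimension decay in (0, 1)
print(gla_recurrent(Q, K, V, alpha).shape)    # torch.Size([128, 32])
```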

References

Sonta. "Zhihu answer." https://www.zhihu.com/question/9740764576/answer/80735153803

Katharopoulos A., Vyas A., Pappas N., Fleuret F. "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention." https://arxiv.org/abs/2006.16236

Hua W., Dai Z., Liu H., Le Q.V. "Transformer Quality in Linear Time." https://arxiv.org/abs/2202.10447

Yang S.L., Wang B.L., et al. "Gated Linear Attention Transformers with Hardware-Efficient Training." https://arxiv.org/abs/2312.06635