
Online Normalizer Calculation for Softmax

by Maxim Milakov and Natalia Gimelshein, 2018

https://arxiv.org/abs/1805.02867

Original softmax

$$y_i = \frac{e^{x_i}}{\sum_{j=1}^{V} e^{x_j}}$$

Naive algorithm

```py
d[0] = 0
for j in range(1, V+1):
  d[j] = d[j-1] + e**x[j]
for i in range(1, V+1):
  y[i] = e**x[i] / d[V]
```
  • May overflow: e**x[j] exceeds the floating-point range for large x[j] (above roughly 88.7 in single precision).
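
To see the overflow concretely, here is a minimal sketch in plain Python. The function name `naive_softmax` and the test vectors are illustrative and not taken from the paper. Python floats are double precision, so `math.exp` raises `OverflowError` only once the argument passes roughly 709.78, but the failure mode is the same as in single precision.

```python
import math

def naive_softmax(x):
    """Softmax without max subtraction (illustrative; overflows for large inputs)."""
    d = 0.0
    for xj in x:
        d += math.exp(xj)  # exp() leaves the double range for large xj
    return [math.exp(xi) / d for xi in x]

small = naive_softmax([1.0, 2.0, 3.0])  # moderate inputs work fine

try:
    naive_softmax([1000.0, 1001.0, 1002.0])
    overflowed = False
except OverflowError:
    overflowed = True  # math.exp(1000.0) cannot be represented
```

The second call fails even though the true softmax of that vector is perfectly well defined, which is what motivates the safe version.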

Safe version

```py
m[0] = float('-inf')
for k in range(1, V+1):
  m[k] = max(m[k-1], x[k])
d[0] = 0
for j in range(1, V+1):
  d[j] = d[j-1] + e**(x[j] - m[V])
for i in range(1, V+1):
  y[i] = e**(x[i] - m[V]) / d[V]
```
  • Requires three passes over the input vector (max, normalizer, output), i.e. 4 memory accesses per element: 3 loads and 1 store.
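
Subtracting the maximum is valid because multiplying the numerator and the denominator by the same factor e**(-m) leaves the ratio unchanged, while every exponent becomes non-positive. A minimal runnable sketch, assuming plain Python lists and the illustrative name `safe_softmax`:

```python
import math

def safe_softmax(x):
    """Three-pass softmax: maximum, normalizer, then outputs."""
    m = max(x)                                 # pass 1: maximum
    d = sum(math.exp(xj - m) for xj in x)      # pass 2: normalizer, all exponents <= 0
    return [math.exp(xi - m) / d for xi in x]  # pass 3: outputs

big = safe_softmax([1000.0, 1001.0, 1002.0])   # no overflow this time
ref = safe_softmax([0.0, 1.0, 2.0])            # the same vector shifted by -1000
```

Both calls produce the same probabilities, since softmax is invariant to adding a constant to every input.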

Online normalizer calculation

```py
m[0] = float('-inf')
d[0] = 0
for j in range(1, V+1):
  m[j] = max(m[j-1], x[j])
  d[j] = d[j-1] * e**(m[j-1] - m[j]) + e**(x[j] - m[j])
for i in range(1, V+1):
  y[i] = e**(x[i] - m[V]) / d[V]
```
  • Reduced to two passes.
  • Reduced memory accesses from 4 to 3 per element.
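
The recurrence works because d[j] is defined relative to the running maximum m[j]: when the maximum changes, the old sum is rescaled by e**(m[j-1] - m[j]) before the new term is added. A minimal runnable sketch of the same idea with scalar state instead of arrays, assuming finite inputs (the names `online_normalizer` and `online_softmax` are illustrative):

```python
import math

def online_normalizer(x):
    """One pass: return the maximum m and the normalizer d = sum(exp(x_j - m))."""
    m = float("-inf")
    d = 0.0
    for xj in x:
        m_new = max(m, xj)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(xj - m_new)
        m = m_new
    return m, d

def online_softmax(x):
    m, d = online_normalizer(x)                # pass 1: max and normalizer together
    return [math.exp(xi - m) / d for xi in x]  # pass 2: outputs

x = [3.0, 1000.0, -5.0, 999.0]
m, d = online_normalizer(x)
y = online_softmax(x)
```

Keeping only the latest m and d is enough, because each step of the recurrence reads only the previous pair.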

Parallel online normalizer calculation

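$$\begin{bmatrix} m_V \\ d_V \end{bmatrix} = \begin{bmatrix} x_1 \\ 1 \end{bmatrix} \oplus \begin{bmatrix} x_2 \\ 1 \end{bmatrix} \oplus \cdots \oplus \begin{bmatrix} x_V \\ 1 \end{bmatrix}$$

where

$$\begin{bmatrix} m_i \\ d_i \end{bmatrix} \oplus \begin{bmatrix} m_j \\ d_j \end{bmatrix} = \begin{bmatrix} \max(m_i, m_j) \\ d_i \times e^{m_i - \max(m_i, m_j)} + d_j \times e^{m_j - \max(m_i, m_j)} \end{bmatrix}$$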
  • Associative, which enables parallel evaluation.
  • Commutative, which provides more flexibility.
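
These two properties can be checked with a short sketch: reduce disjoint chunks independently, then merge the partial (m, d) pairs. The helper names `combine` and `normalizer` and the split into two chunks are assumptions made for illustration. The sketch assumes non-empty chunks of finite values.

```python
import math
from functools import reduce

def combine(a, b):
    """The ⊕ operator on (m, d) pairs."""
    (m_a, d_a), (m_b, d_b) = a, b
    m = max(m_a, m_b)
    return m, d_a * math.exp(m_a - m) + d_b * math.exp(m_b - m)

def normalizer(xs):
    """Fold a non-empty chunk; each element x contributes the pair (x, 1)."""
    return reduce(combine, ((x, 1.0) for x in xs))

x = [0.5, 2.0, -1.0, 3.0, 1.5, 0.0]
whole = normalizer(x)

# Two "workers" reduce disjoint chunks; their partial results are then merged.
left, right = normalizer(x[:2]), normalizer(x[2:])
merged = combine(left, right)   # associativity: same result as the sequential fold
swapped = combine(right, left)  # commutativity: merge order does not matter
```

Up to floating-point rounding, `whole`, `merged`, and `swapped` agree, which is what allows a tree-shaped reduction across threads.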

Softmax and TopK fusion

  • TopK on its own requires 1 memory access per element (one load).
  • Safe softmax + separate TopK = 4 + 1 = 5 accesses per element; online softmax + separate TopK = 3 + 1 = 4.
```py
m[0] = float('-inf')
d[0] = 0
u = [float('-inf') for _ in range(K+2)]  # u[1..K]: top-K values, descending; u[K+1]: scratch slot
p = [-1 for _ in range(K+2)]             # p[1..K]: indices of those values
for j in range(1, V+1):
  m[j] = max(m[j-1], x[j])
  d[j] = d[j-1] * e**(m[j-1] - m[j]) + e**(x[j] - m[j])
  u[K+1] = x[j]
  p[K+1] = j
  k = K
  while k >= 1 and u[k] < u[k+1]:  # insertion step: bubble x[j] up to its rank
    u[k], u[k+1] = u[k+1], u[k]
    p[k], p[k+1] = p[k+1], p[k]
    k -= 1
for i in range(1, K+1):  # softmax only for the K winners
  v[i] = e**(u[i]-m[V])/d[V]
  z[i] = p[i]
```
  • Fused version requires 1 access per element: each x[j] is loaded once, and only K values and K indices are stored.
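
The fused loop can be sketched as runnable Python, assuming 0-based indexing and the illustrative function name `online_softmax_topk` (neither follows the paper's notation). It reads each input element once and evaluates the softmax only for the K largest values.

```python
import math

def online_softmax_topk(x, K):
    """One pass over x: running max and normalizer plus the K largest values so far."""
    m, d = float("-inf"), 0.0
    u = [float("-inf")] * (K + 1)  # top-K values, descending; last slot is scratch
    p = [-1] * (K + 1)             # matching indices into x
    for j, xj in enumerate(x):
        m_new = max(m, xj)
        d = d * math.exp(m - m_new) + math.exp(xj - m_new)
        m = m_new
        u[K], p[K] = xj, j         # place the candidate in the scratch slot
        k = K
        while k >= 1 and u[k - 1] < u[k]:  # bubble it up to its rank
            u[k - 1], u[k] = u[k], u[k - 1]
            p[k - 1], p[k] = p[k], p[k - 1]
            k -= 1
    n = min(K, len(x))
    v = [math.exp(u[i] - m) / d for i in range(n)]  # softmax for the winners only
    return v, p[:n]

x = [1.0, 4.0, 2.0, 5.0, 3.0]
values, indices = online_softmax_topk(x, 2)
```

For this input the two largest entries are 5.0 and 4.0, so `indices` is `[3, 1]` and `values` holds their softmax probabilities in descending order.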

Benchmark

  • Online softmax:
    • V>1000: ~1.3x faster than safe softmax.
    • Performance is close to naive softmax.
  • Softmax TopK Fused:
    • K=5, V=25000: 5x speedup = 2.5x (softmax) × 2x (fusion).
    • The improvement shrinks as K grows.

For smaller vector sizes, the GPU is underutilized, and performance is limited by various latencies rather than by memory bandwidth.