
Overview of Efficient Transformers

Published: at 01:32 PM

Please note that the following content is intended for personal note-taking and is therefore rather unorganized. If you find any issues, please email me or open a thread in the comment section below.

Base:


  1. embed: $\mathbb{R}^B \times \mathbb{R}^N \rightarrow \mathbb{R}^B \times \mathbb{R}^N \times \mathbb{R}^{d_{\text{model}}}$
  2. +positional encoding
  3. Multi-Head Self-Attention: $A_h=\operatorname{Softmax}\left(\alpha Q_h K_h^{\top}\right) V_h$, with $H$ heads and $W_q, W_k, W_v \in \mathbb{R}^{d \times \frac{d}{H}}$; concatenate, then dense. Computed in parallel: $\mathbb{R}^B \times \mathbb{R}^N \times \mathbb{R}^H \times \mathbb{R}^{\frac{d}{H}}$
  4. Position-Wise FFN: $F_2\left(\operatorname{ReLU}\left(F_1\left(X_A\right)\right)\right)$, applying the same MLP at every time step to improve the network's expressiveness (expand, then reduce dimensionality). It is a trick to keep the scale under control. (A sketch of steps 3 and 4 follows this list.)
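
To make the block concrete, here is a minimal NumPy sketch of steps 3 and 4 above (single sequence, toy sizes, random weights; residual connections and LayerNorm omitted):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

N, d_model, H = 8, 64, 4                       # toy sizes (assumed)
d_head = d_model // H
rng = np.random.default_rng(0)

X = rng.normal(size=(N, d_model))              # embedded + position-encoded input
Wq, Wk, Wv = (rng.normal(size=(H, d_model, d_head)) for _ in range(3))
Wo = rng.normal(size=(d_model, d_model))
F1 = rng.normal(size=(d_model, 4 * d_model))
F2 = rng.normal(size=(4 * d_model, d_model))

# 3. Multi-head self-attention: A_h = softmax(alpha Q_h K_h^T) V_h, alpha = 1/sqrt(d_head)
heads = []
for h in range(H):
    Q, K, V = X @ Wq[h], X @ Wk[h], X @ Wv[h]  # each (N, d_head)
    heads.append(softmax(Q @ K.T / np.sqrt(d_head)) @ V)
X_A = np.concatenate(heads, axis=-1) @ Wo      # concatenate heads, then dense

# 4. Position-wise FFN: F_2(ReLU(F_1(X_A))), expand then reduce dimensionality
out = np.maximum(X_A @ F1, 0) @ F2             # (N, d_model)
```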

Computation Cost

memory & cc: quadratic. N×NN\times N

Mode

Modifications


Methodology

| Model/Article | Complexity | Decode | Class |
| --- | --- | --- | --- |
| Memory Compressed† [48] | $O(N_c^2)$ | ✓ | FP+M |
| Image Transformer† [55] | $O(N \cdot m)$ | ✓ | FP |
| Set Transformer† [43] | $O(kN)$ | ✗ | M |
| Transformer-XL† [16] | $O(N^2)$ | ✓ | RC |
| Sparse Transformer [9] | $O(N\sqrt{N})$ | ✓ | FP |
| Reformer† [37] | $O(N\log N)$ | ✓ | LP |
| Routing Transformer [62] | $O(N\sqrt{N})$ | ✓ | LP |
| Axial Transformer [28] | $O(N\sqrt{N})$ | ✓ | FP |
| Compressive Transformer† [59] | $O(N^2)$ | ✓ | RC |
| Sinkhorn Transformer† [74] | $O(B^2)$ | ✓ | LP |
| Longformer [5] | $O(n(k+m))$ | ✓ | FP+M |
| ETC [3] | $O(N_g^2 + N N_g)$ | ✗ | FP+M |
| Synthesizer [73] | $O(N^2)$ | ✓ | LR+LP |
| Performer [10] | $O(N)$ | ✓ | KR |
| Funnel Transformer [15] | $O(N^2)$ | ✓ | FP+DS |
| Linformer [87] | $O(N)$ | ✗ | LR |
| Linear Transformers† [36] | $O(N)$ | ✓ | KR |
| Big Bird [93] | $O(N)$ | ✗ | FP+M |
| Random Feature Attention† [56] | $O(N)$ | ✓ | KR |
| Long Short Transformers† [96] | $O(kN)$ | ✓ | FP+LR |
| Poolingformer† [95] | $O(N)$ | ✗ | FP+M |
| Nyströmformer† [91] | $O(kN)$ | ✗ | M+DS |
| Perceiver [31] | $O(kN)$ | ✓ | M+DS |
| Clusterformer† [88] | $O(N\log N)$ | ✗ | LP |
| Luna [50] | $O(kN)$ | ✓ | M |
| TokenLearner† [63] | $O(k^2)$ | ✗ | DS |
| Adaptive Sparse Transformer† [14] | $O(N^2)$ | ✓ | Sparse |
| Product Key Memory [41] | $O(N^2)$ | ✓ | Sparse |
| Switch Transformer [23] | $O(N^2)$ | ✓ | Sparse |
| GShard [45] | $O(N^2)$ | ✓ | Sparse |
| Scaling Transformers [32] | $O(N^2)$ | ✓ | Sparse |
| GLaM [21] | $O(N^2)$ | ✓ | Sparse |

Class abbreviations (per the survey [1]): FP = fixed patterns, M = memory, LP = learnable patterns, LR = low rank, KR = kernel, RC = recurrence, DS = downsampling, Sparse = sparse models.

Detailed Walk-Through

Image Transformer

local attention

Self-attention is computed within blocks independently. For blocks of length $b$, $cc=O(b^2 \cdot (n/b))=O(nb)$, i.e. linear in $n$ for fixed $b$.

Memory-Compressed Attention

For $K \in \mathbb{R}^{d\times N}$, apply a convolution along the sequence axis with kernel size and stride $k$ to reduce the dimension to $d\times \frac{N}{k}$. The cc of attention then becomes $O(n \cdot n / k)$ (a sketch follows the list below). However, it often either

  1. does not significantly improve cc, since $N$ and $N/k$ are of the same order of magnitude; or
  2. loses information during compression.
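
A minimal sketch of the compression step, substituting simple strided mean-pooling over the sequence axis for the learned strided convolution (an assumption for illustration only):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def compressed_attention(Q, K, V, k):
    """Compress K, V along the sequence axis by factor k, then attend: O(N * N/k)."""
    N, d = K.shape
    m = N // k                                        # assume N divisible by k for simplicity
    K_c = K[: m * k].reshape(m, k, d).mean(axis=1)    # (N/k, d) pooled keys
    V_c = V[: m * k].reshape(m, k, d).mean(axis=1)    # (N/k, d) pooled values
    return softmax(Q @ K_c.T / np.sqrt(d)) @ V_c      # (N, d)

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
out = compressed_attention(Q, K, V, k=4)              # (16, 8)
```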

Two Attention Schemes


Pixels are flattened in raster order and partitioned into non-overlapping query blocks of length $l_q$, each extended by $l_m$ positions into a memory block. With $m=l_q+l_m$, $cc=O(n \cdot m)$. Loses the global receptive field.
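
A rough sketch of this local 1D scheme (flattened input assumed, causal masking inside each block omitted):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def local_attention(Q, K, V, l_q, l_m):
    """Each query block of length l_q attends to itself plus the previous l_m positions."""
    N, d = Q.shape
    out = np.zeros_like(V)
    for start in range(0, N, l_q):
        end = min(start + l_q, N)
        mem_start = max(0, start - l_m)               # memory block: l_m positions before the query block
        Kb, Vb = K[mem_start:end], V[mem_start:end]   # m = l_q + l_m keys/values per block
        out[start:end] = softmax(Q[start:end] @ Kb.T / np.sqrt(d)) @ Vb
    return out                                        # cc ~ O(n * m)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8))
out = local_attention(x, x, x, l_q=8, l_m=8)
```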

Sparse Transformer

Assumption: In softmax Attention, effective weights are sparsely distributed.

heads

$1/2$ fixed:

$$\hat{FA}_{i j}^{(1)}= \begin{cases}Q_i K_j^{\top}, & \text{if } \lfloor j / L\rfloor=\lfloor i / L\rfloor \\ 0, & \text{otherwise}\end{cases} \qquad \hat{FA}_{i j}^{(2)}= \begin{cases}Q_i K_j^{\top}, & \text{if } j \bmod L \geq L-c \\ 0, & \text{otherwise}\end{cases}$$

$+\ 1/2$ strided:

$$\hat{SA}_{i j}^{(1)}= \begin{cases}Q_i K_j^{\top}, & \text{if } \max(0, i-L) \leq j \leq i \\ 0, & \text{otherwise}\end{cases} \qquad \hat{SA}_{i j}^{(2)}= \begin{cases}Q_i K_j^{\top}, & \text{if } (i-j) \bmod L = 0 \\ 0, & \text{otherwise}\end{cases}$$
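
The two patterns can be sanity-checked by building the corresponding boolean masks directly; a small sketch with toy values of $L$ and $c$:

```python
import numpy as np

def fixed_mask(N, L, c):
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    same_block = (j // L) == (i // L)        # FA^(1): within the same block
    summary = (j % L) >= (L - c)             # FA^(2): last c "summary" columns of each block
    return same_block | summary

def strided_mask(N, L):
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    local = (j <= i) & (j >= i - L)          # SA^(1): max(0, i-L) <= j <= i
    stride = (i - j) % L == 0                # SA^(2): every L-th position (causal mask applied separately)
    return local | stride

print(fixed_mask(16, L=4, c=1).astype(int))
print(strided_mask(16, L=4).astype(int))
```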

usage

  1. alternate $1$ and $2$: $\operatorname{attention}(X)=W_p \cdot \operatorname{attend}\left(X, A^{(r \bmod p)}\right)$. Justification: $1$ synthesizes block information, so the striding in $2$ does not affect the receptive field.
  2. merge: $\operatorname{attention}(X)=W_p \cdot \operatorname{attend}\left(X, \bigcup_{m=1}^p A^{(m)}\right)$
  3. multi-head, then concatenate: $\operatorname{attention}(X)=W_p\left(\operatorname{attend}(X, A)^{(i)}\right)_{i \in\left\{1, \ldots, n_h\right\}}$

cc

Set $L=\sqrt{N}$; then $cc = O(N \sqrt{N})$. Strided attention (more local) is better suited for images and audio; fixed attention (more global) is better suited for text.

Blockwise Self-Attention

Memory: for BERT, the model takes $2.1\%$, the optimizer (Adam) $10.3\%$, and activations $87.6\%$, the last of which scales as $O(N^2)$. BlockBERT: split $N\rightarrow n\times \frac{N}{n}$, i.e. $Q\rightarrow Q_{1},Q_{2},\ldots,Q_{n}$, and the same for $K, V$. Then $cc = O\left(\frac{N^2}{n^2} \times n\right)=O\left(\frac{N^2}{n}\right)$.

Axial Transformer

Compute attention along each axis. For image data, $B\times N\times hw\rightarrow Bw\times N\times h+Bh\times N\times w$, which saves $O(N^2)/O(N\sqrt{N})=O(\sqrt{N})$. Generally, for $N=N^{1/d} \times \cdots \times N^{1/d}$, the Axial Transformer saves $O(N^2)/O\left(d\times\left(N^{1/d}\right)^{d-1}\times \left(N^{1/d}\right)^{2}\right)=O\left(N^{(d-1)/d}\right)$.
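
A sketch of axial attention for an image-shaped input (toy sizes): fold the other spatial axis into the batch, attend along one axis at a time, and combine.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(x):
    """Plain self-attention over the second-to-last axis of x, shape (..., seq, d)."""
    d = x.shape[-1]
    return softmax(x @ np.swapaxes(x, -1, -2) / np.sqrt(d)) @ x

B, h, w, d = 2, 8, 8, 16
x = np.random.default_rng(0).normal(size=(B, h, w, d))

# Attention along the width axis: fold h into the batch, sequence length w
along_w = attend(x.reshape(B * h, w, d)).reshape(B, h, w, d)

# Attention along the height axis: fold w into the batch, sequence length h
xt = np.swapaxes(x, 1, 2).reshape(B * w, h, d)
along_h = np.swapaxes(attend(xt).reshape(B, w, h, d), 1, 2)

# Cost ~ O(hw * (h + w)) instead of O((hw)^2) for full attention over the flattened image
out = along_w + along_h
```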

models

auto-regressive: $p_\theta(x)=\prod_{i=1}^N p_\theta\left(x_i \mid x_{<i}\right)$

Inner decoder (row-wise model):

$$\begin{aligned} h &\leftarrow \operatorname{Embed}(x) \\ h &\leftarrow \operatorname{ShiftRight}(h)+\text{PositionEmbeddings} \\ h &\leftarrow \operatorname{MaskedTransformerBlock}_2(h) \qquad \times L_{\text{row}}\end{aligned}$$

where $h$ is the initial $D$-dimensional embedding of size $H\times W\times D$. ShiftRight ensures the current pixel is outside its own receptive field.

Outer decoder (capturing the rows above):

$$\begin{aligned} h &\leftarrow \operatorname{Embed}(x) \\ u &\leftarrow h+\text{PositionEmbeddings} \\ u &\leftarrow \operatorname{MaskedTransformerBlock}_1\left(\operatorname{TransformerBlock}_2(u)\right) \qquad \times L_{\text{upper}}/2 \\ h &\leftarrow \operatorname{ShiftDown}(u)+\operatorname{ShiftRight}(h)+\text{PositionEmbeddings} \\ h &\leftarrow \operatorname{MaskedTransformerBlock}_2(h) \qquad \times L_{\text{row}}\end{aligned}$$

The tensor $u$ represents the context captured above the current pixel. Finally, pass through LayerNorm and a dense layer to produce logits of shape $H \times W \times 256$. Effective for point clouds, etc.

Longformer

Dilated CNNs $\rightarrow$ dilated sliding windows with $Q_s, K_s, V_s$. The receptive field grows layer by layer as in a CNN, so this indirect approach performs similarly poorly at long-distance modeling. Special tokens are given global attention (like BERT's [CLS]) with $Q_g, K_g, V_g$; they attend to and are attended by all tokens. Crucial.
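
A sketch of the combined attention mask, assuming a sliding-window width `w` and a list of global token positions (dilation omitted):

```python
import numpy as np

def longformer_mask(N, w, global_positions):
    """Sliding-window attention of width w plus symmetric global attention for selected tokens."""
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    mask = np.abs(i - j) <= w // 2               # local sliding window
    g = np.array(global_positions)
    mask[g, :] = True                            # global tokens attend to everything
    mask[:, g] = True                            # and are attended by everything
    return mask

mask = longformer_mask(N=16, w=4, global_positions=[0])   # e.g. a [CLS]-style token at position 0
print(mask.astype(int))
```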

Big Bird

Global tokens + fixed patterns (local sliding windows) + random attention (queries attend to random keys). Justification for randomness: the standard Transformer is a complete digraph, which can be approximated with random graphs. The model is Turing-complete.

Routing Transformer

Learns attention sparsity with $k$-means clustering, where $k=\sqrt{n}$. Cluster centroid vectors $\boldsymbol{\mu}=\left(\mu_1, \cdots, \mu_k\right) \in \mathbb{R}^{k \times d}$ are shared between $Q$ and $K$.

$$X_i^{\prime}=\sum_{j:\, K_j \in \mu\left(Q_i\right),\ j<i} A_{i j} V_j$$

For the decoder (causal attention), solutions include (a sketch follows this list):

  1. additional lower-triangular masking;
  2. share queries and keys, i.e. $K\leftarrow Q$. Works better.
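
A rough sketch of routed attention with shared queries/keys (option 2), using a plain Lloyd-style k-means for the cluster assignment; causal masking is applied inside each cluster, and tokens with no earlier cluster member simply get a zero output here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def kmeans_assign(X, k, iters=10, seed=0):
    """Plain Lloyd iterations; returns the cluster index of each row of X."""
    rng = np.random.default_rng(seed)
    mu = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(iters):
        assign = np.argmin(((X[:, None, :] - mu[None]) ** 2).sum(-1), axis=1)
        for c in range(k):
            if np.any(assign == c):
                mu[c] = X[assign == c].mean(axis=0)
    return assign

def routing_attention(Q, V, k):
    """Shared Q/K: each position attends only to earlier positions routed to the same cluster."""
    N, d = Q.shape
    assign = kmeans_assign(Q, k)
    out = np.zeros_like(V)
    for c in range(k):
        idx = np.where(assign == c)[0]               # positions routed to centroid c
        scores = Q[idx] @ Q[idx].T / np.sqrt(d)
        causal = idx[None, :] < idx[:, None]         # keep only j < i inside the cluster
        scores = np.where(causal, scores, -1e9)
        has_ctx = causal.any(axis=1, keepdims=True)  # first member of a cluster has no context
        out[idx] = np.where(has_ctx, softmax(scores) @ V[idx], 0.0)
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 16))
out = routing_attention(x, rng.normal(size=(64, 16)), k=8)   # k ~ sqrt(N)
```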

Reformer

Locality-Sensitive Hashing (LSH). For $b$ buckets, define a random matrix $R \in \mathbb{R}^{d_k\times \frac{b}{2}}$ and let $h(x)=\arg \max ([x R ;-x R])$, i.e. similarity against $b$ random vectors.

import numpy as np

# x: (n_tokens, hidden_dim) shared query/key vectors; n_buckets must be even
random_rotations = np.random.randn(hidden_dim, n_buckets // 2)
rotated_vectors = np.dot(x, random_rotations)                      # (n_tokens, n_buckets // 2)
rotated_vectors = np.hstack([rotated_vectors, -rotated_vectors])   # (n_tokens, n_buckets)
buckets = np.argmax(rotated_vectors, axis=-1)                      # bucket id h(x) per token

attention: $o_i=\sum_{j \in \mathcal{P}_i} \exp \left(q_i \cdot k_j-z\left(i, \mathcal{P}_i\right)\right) v_j$ where $\mathcal{P}_i=\left\{j: h\left(q_i\right)=h\left(k_j\right)\right\}$ and $z$ is the normalizing term of the softmax. To avoid queries with no keys, set $k_j=\frac{q_j}{\left\|q_j\right\|}$ so that $h\left(k_j\right)=h\left(q_j\right)$. Multi-round: $\mathcal{P}_i=\bigcup_{r=1}^{n_{\text{round}}} \mathcal{P}_i^{(r)}$.
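
Continuing the snippet above, a simplified single-round sketch that attends within each bucket only (no sorting/chunking and no causal masking, unlike the real Reformer):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def lsh_attention(q, v, n_buckets, seed=0):
    """Shared-QK LSH attention, one hash round, bucket by bucket (a sketch)."""
    n, d = q.shape
    k = q / np.linalg.norm(q, axis=-1, keepdims=True)         # k_j = q_j / ||q_j||, so h(k_j) = h(q_j)
    rng = np.random.default_rng(seed)
    R = rng.normal(size=(d, n_buckets // 2))
    proj = q @ R
    buckets = np.argmax(np.hstack([proj, -proj]), axis=-1)    # h(x) = argmax([xR; -xR])
    out = np.zeros_like(v)
    for b in np.unique(buckets):
        idx = np.where(buckets == b)[0]                       # P_i = {j : h(q_i) = h(k_j)}
        scores = q[idx] @ k[idx].T                            # dot products within the bucket
        out[idx] = softmax(scores) @ v[idx]                   # softmax normalization plays the role of z(i, P_i)
    return out

rng = np.random.default_rng(1)
x = rng.normal(size=(32, 8))
out = lsh_attention(x, rng.normal(size=(32, 8)), n_buckets=4)
```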

Revnet:

Reduce memory (activation) cost with extra computation: with $y_1 = x_1 + F(x_2)$ and $y_2 = x_2 + G(y_1)$, activations can be recomputed from the layer outputs during the backward pass instead of being stored.

In Reformer, set $F$ to the LSH attention block and $G$ to the FFN.

Linformer

$N\times d\rightarrow k\times d$: a reduction along the $N$ dimension (rather than the feature dimension). Causal masking is hard to maintain (hence no decoding).

$$\operatorname{Softmax}\left(\frac{1}{\sqrt{d_k}} X W_i^Q\left(E_i X W_i^K\right)^{\top}\right) \cdot F_i X W_i^V$$

where $E_i, F_i$ are $k \times N$ projections. Reminiscent of depth-wise convolutions / pooling.
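
A sketch of the projection along the sequence dimension with random stand-ins for the learned $E_i, F_i$ (single head, the $W^Q, W^K, W^V$ projections omitted):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, d, k = 128, 32, 16
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
E, F = rng.normal(size=(k, N)), rng.normal(size=(k, N))    # learned k x N projections in the real model

K_proj, V_proj = E @ K, F @ V                              # (k, d): sequence length reduced N -> k
out = softmax(Q @ K_proj.T / np.sqrt(d)) @ V_proj          # (N, k) @ (k, d) -> (N, d), O(Nk) instead of O(N^2)
```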

Performer

Fast Attention via Orthogonal Random Features (FAVOR): with a kernel $\mathrm{K}(\mathbf{x}, \mathbf{y})=\mathbb{E}\left[\phi(\mathbf{x})^{\top} \phi(\mathbf{y})\right]$, where $\phi$ is a random feature map, we can write attention as $\mathbf{A}(i, j)=\mathrm{K}\left(\mathbf{q}_i^{\top}, \mathbf{k}_j^{\top}\right)$; then

$$\widehat{\operatorname{Att}_{\leftrightarrow}}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\widehat{\mathbf{D}}^{-1}\left(\mathbf{Q}^{\prime}\left(\left(\mathbf{K}^{\prime}\right)^{\top} \mathbf{V}\right)\right), \quad \widehat{\mathbf{D}}=\operatorname{diag}\left(\mathbf{Q}^{\prime}\left(\left(\mathbf{K}^{\prime}\right)^{\top} \mathbf{1}_L\right)\right)$$

Slower in the causal (autoregressive) setting due to extra masking steps that cannot be parallelized.
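
A sketch of the bidirectional (non-causal) factorization above with a positive random-feature map for the softmax kernel; the random projections are not orthogonalized here, so this only roughly approximates exact attention.

```python
import numpy as np

def favor_features(x, W):
    """Positive random features: E[phi(x)^T phi(y)] = exp(x . y)."""
    m = W.shape[0]
    return np.exp(x @ W.T - (x ** 2).sum(-1, keepdims=True) / 2) / np.sqrt(m)

rng = np.random.default_rng(0)
N, d, m = 64, 16, 256
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
W = rng.normal(size=(m, d))                        # random projections (orthogonalized in the real FAVOR+)

Qp = favor_features(Q / d ** 0.25, W)              # Q', K': the d^{-1/4} scaling folds in 1/sqrt(d)
Kp = favor_features(K / d ** 0.25, W)

KV = Kp.T @ V                                      # (m, d), computed once: O(N m d)
norm = Qp @ Kp.sum(axis=0)                         # D_hat diagonal: Q'((K')^T 1_L)
out = (Qp @ KV) / norm[:, None]                    # D_hat^{-1}(Q'((K')^T V)); never forms the N x N matrix
```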

Linear Transformer

Kernel method; linear-time, constant-memory RNN. With $\operatorname{sim}(q, k)=\exp \left(\frac{q^T k}{\sqrt{d}}\right)$, observe: $V_i^{\prime}=\frac{\sum_{j=1}^p \operatorname{sim}\left(Q_i, K_j\right) V_j}{\sum_{j=1}^p \operatorname{sim}\left(Q_i, K_j\right)}$. With the kernel method, $\operatorname{sim}(q, k):=\phi(q)^T \phi(k)$, then

$$\begin{aligned} V_i^{\prime} & =\frac{\phi\left(Q_i\right)^T S_p}{\phi\left(Q_i\right)^T Z_p}, \\ S_p & :=\sum_{j=1}^p \phi\left(K_j\right) V_j^T, \\ Z_p & :=\sum_{j=1}^p \phi\left(K_j\right).\end{aligned}$$

For the unmasked case, every $Q_i$ attends to the same $N$ keys, so we simply reuse $S_p, Z_p$. For the masked case, update incrementally:

$$\begin{aligned} S_i &= S_{i-1}+\phi\left(K_i\right) V_i^T, \\ Z_i &= Z_{i-1}+\phi\left(K_i\right).\end{aligned}$$

If computing $\phi$ costs $O(c)$, then the cc is $O(Ncd)$. Choice: $\phi(x)=\operatorname{elu}(x)+1$, so $c=d$.
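
A sketch of the masked (causal) case as the $S_i, Z_i$ recurrence with $\phi(x)=\operatorname{elu}(x)+1$:

```python
import numpy as np

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1)

def phi(x):
    return elu(x) + 1                                  # chosen feature map, so c = d

def causal_linear_attention(Q, K, V):
    """O(N c d) time, O(c d) state: V'_i = phi(Q_i)^T S_i / phi(Q_i)^T Z_i."""
    N, d = Q.shape
    S = np.zeros((d, V.shape[1]))                      # S_i = sum_{j<=i} phi(K_j) V_j^T
    Z = np.zeros(d)                                    # Z_i = sum_{j<=i} phi(K_j)
    out = np.zeros_like(V)
    for i in range(N):
        S = S + np.outer(phi(K[i]), V[i])
        Z = Z + phi(K[i])
        out[i] = (phi(Q[i]) @ S) / (phi(Q[i]) @ Z)
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(32, 8)) for _ in range(3))
out = causal_linear_attention(Q, K, V)
```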

References

[1] Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Metzler. 2022. Efficient Transformers: A Survey. ACM Comput. Surv. 55, 6, Article 109 (December 2022), 28 pages. https://doi.org/10.1145/3530811

