Low-Dimensional Transformer Model Visualization

I’ve been trying to learn the details of the transformer architecture over the last couple of weeks. Here are some resources I found useful:

  • Neel Nanda’s YouTube videos - good for double-checking your understanding of the details, or for implementing a transformer yourself. Everything is done with tensor operations.
  • Nelson Elhage’s blog post - a different (non-matrix-math) perspective, in which he walks through a template implementation in Rust. Easier to follow if you know how to code but don’t feel as confident manipulating tensors.
  • Paper: A Mathematical Framework for Transformer Circuits - gives many intuitive explanations of how information flows through a transformer.
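One point the Mathematical Framework paper makes about information flow is that attention heads act independently and additively. Concatenating the head outputs and applying a single output matrix W_O gives exactly the same result as letting each head write into the residual stream through its own block of rows of W_O and then summing. Below is a minimal NumPy check at the diagram's sizes. The weights are random placeholders, and the names z and W_O are chosen here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, d_head, n_heads, n_ctx = 3, 2, 2, 2  # the toy sizes used in the diagram

# Placeholder per-head outputs z_h (one d_head vector per position) and a random W_O.
z = [rng.normal(size=(n_ctx, d_head)) for _ in range(n_heads)]
W_O = rng.normal(size=(n_heads * d_head, d_model))

# Usual implementation: concatenate the heads, then apply one output projection.
concat_out = np.concatenate(z, axis=-1) @ W_O  # shape (n_ctx, d_model)

# Framework-paper view: each head writes to the residual stream independently
# through its own block of rows of W_O, and the contributions are simply added.
sum_out = sum(z[h] @ W_O[h * d_head:(h + 1) * d_head] for h in range(n_heads))

print(np.allclose(concat_out, sum_out))  # True
```

This is just block matrix multiplication, but it is why a diagram can draw each head as its own path that adds into the residual stream.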

In the spirit of the saying “A picture is worth a thousand words”, I decided to draw a visualization of the parts I found hardest to understand:

  • The overall information flow between the different parts of the model.
  • How self-attention heads work.
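Here is one self-attention head at the diagram's sizes (d_model = 3, d_head = 2, context length 2), written as a minimal NumPy sketch. It assumes causal (GPT-style) masking and the usual 1/sqrt(d_head) scaling, and it omits biases. The weights are random placeholders and are not taken from the diagram.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(resid, W_Q, W_K, W_V, W_O):
    """One causal self-attention head.

    resid: (n_ctx, d_model) residual stream, one row per token position.
    W_Q, W_K, W_V: (d_model, d_head); W_O: (d_head, d_model).
    Returns the head's (n_ctx, d_model) contribution and its attention pattern.
    """
    q = resid @ W_Q  # queries: what each position is looking for
    k = resid @ W_K  # keys: what each position offers
    v = resid @ W_V  # values: what each position passes on if attended to
    d_head = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_head)  # (n_ctx, n_ctx): query position x key position
    n_ctx = resid.shape[0]
    # Causal mask: a position may only attend to itself and earlier positions.
    mask = np.triu(np.ones((n_ctx, n_ctx), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    pattern = softmax(scores, axis=-1)  # each row sums to 1
    z = pattern @ v  # weighted average of value vectors, (n_ctx, d_head)
    return z @ W_O, pattern  # project back up to d_model

rng = np.random.default_rng(0)
d_model, d_head, n_ctx = 3, 2, 2
resid = rng.normal(size=(n_ctx, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_head)) for _ in range(3))
W_O = rng.normal(size=(d_head, d_model))

out, pattern = attention_head(resid, W_Q, W_K, W_V, W_O)
print(out.shape)   # (2, 3)
print(pattern[0])  # the first token can only attend to itself: [1. 0.]
```

Each head reads 3 numbers per position from the residual stream, works internally with 2, and writes 3 back. In the picture this shows up as the lines narrowing and then widening again.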

Each line in the diagram below depicts one (usually floating-point) number passing through the model. Note: some parts are left out to declutter the picture (for example, layer normalization).

Transformer Diagram

For reference, here are the hypothetical model parameters used in the diagram:

  • Residual stream/model dimension: 3
  • Internal head dimension: 2
  • Number of heads: 2
  • Number of [Attention + MLP] layers: 2
  • Context length (in tokens): 2
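To connect these parameters to actual computation, here is a minimal NumPy forward pass for a model of exactly this size. Like the diagram, it leaves out layer normalization (and biases). Several details are assumptions made for this sketch and do not come from the diagram: the vocabulary size (5), the MLP width (4 × d_model), the ReLU activation, learned positional embeddings, causal masking, and the random weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sizes from the list above.
d_model, d_head, n_heads, n_layers, n_ctx = 3, 2, 2, 2, 2
# Assumptions for this sketch only: a tiny vocabulary and a 4x-wide MLP.
d_vocab, d_mlp = 5, 4 * d_model

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def init(*shape):
    # Random placeholder weights.
    return rng.normal(scale=0.5, size=shape)

W_E, W_pos, W_U = init(d_vocab, d_model), init(n_ctx, d_model), init(d_model, d_vocab)
layers = [
    dict(W_Q=init(n_heads, d_model, d_head), W_K=init(n_heads, d_model, d_head),
         W_V=init(n_heads, d_model, d_head), W_O=init(n_heads, d_head, d_model),
         W_in=init(d_model, d_mlp), W_out=init(d_mlp, d_model))
    for _ in range(n_layers)
]

def forward(tokens):
    n = len(tokens)
    # Residual stream: n rows of d_model numbers (2 x 3 = 6 lines in the diagram).
    resid = W_E[tokens] + W_pos[:n]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)  # causal mask
    for layer in layers:
        # Attention: every head reads the same residual stream; outputs are summed.
        attn_out = np.zeros_like(resid)
        for h in range(n_heads):
            q = resid @ layer["W_Q"][h]
            k = resid @ layer["W_K"][h]
            v = resid @ layer["W_V"][h]
            scores = np.where(mask, -np.inf, q @ k.T / np.sqrt(d_head))
            attn_out += softmax(scores) @ v @ layer["W_O"][h]
        resid = resid + attn_out
        # MLP: acts on each position separately (ReLU here; GPT-2 uses GELU).
        resid = resid + np.maximum(0, resid @ layer["W_in"]) @ layer["W_out"]
        assert resid.shape == (n, d_model)  # the stream never changes width
    return resid @ W_U  # logits, (n, d_vocab)

logits = forward(np.array([1, 3]))
print(logits.shape)  # (2, 5)
```

The structure to notice is that both attention and the MLP only ever add to `resid`. Nothing overwrites the residual stream, which is why the diagram can draw it as one continuous set of lines running from the embedding to the unembedding.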

Hope this helps someone! If you have any questions or corrections, send me an email.