Llama3 Laboratory @ MediaTek

In this lab, we explain the essential design choices used by Llama3, including RMSNorm, rotary embedding, and the KV cache.

This lab was conducted with my smartest colleague, Jason, during our time at the Institute of Information Science, Academia Sinica. The target audience is engineers from MediaTek, one of the world's leading semiconductor solution companies.
The overall architecture of Llama3: each block stacks an RMSNorm, grouped-query attention (GQA) with a KV cache, and a feed-forward layer with SwiGLU as its activation function.
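Since SwiGLU only appears in this overview, below is a minimal sketch of that feed-forward block (the layer names and sizes are our own illustrative choices, not the official implementation): the output of one linear projection is gated elementwise by the SiLU of another before being projected back down.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """Feed-forward block with SwiGLU gating, as used in Llama-style models (sketch)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)  # gating projection
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)    # up projection
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)  # back down to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(x W_gate) acts as an elementwise gate on (x W_up)
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```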

RMSNorm was selected as the normalization method because it is computationally cheaper than LayerNorm while remaining just as effective, as originally demonstrated by Zhang and Sennrich (2019). Its advantages are twofold: it reduces computational cost by eliminating the mean statistic and centering operation present in LayerNorm, and it maintains comparable model performance. This favorable efficiency-to-performance trade-off has led to its adoption in several prominent language models, including Gopher (Rae et al., 2021), LLaMA (Touvron et al., 2023), and LLaMA 2 (Touvron et al., 2023).
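As a minimal sketch of the idea (module and parameter names here are our own, not taken from the Llama3 source), RMSNorm rescales each hidden vector by its root-mean-square and applies a learned per-dimension gain, with no mean subtraction and no bias:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean centering, no bias (Zhang & Sennrich, 2019)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-dimension gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rms(x) = sqrt(mean(x^2)); divide by it, then apply the gain
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```

Compared with LayerNorm, this drops the mean statistic and the bias term, which is exactly where the computational saving comes from.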

Compared to standard positional embeddings (categorized as ‘absolute embeddings’), rotary position embedding (Su et al., 2021) offers greater flexibility. Its key advantage lies in its relative nature, which enables the model to learn position-dependent representations without being constrained to fixed absolute positions. The computation involves two main steps: first, calculating a position-dependent rotation angle for each pair of embedding dimensions, and then applying these rotations to the word embeddings, so that the inner product between two rotated embeddings depends only on their relative distance. This is visualized across three plots: the leftmost shows the original word embeddings, the middle displays the rotary embedding base vectors, and the rightmost demonstrates the final rotated embeddings. The transformation preserves the relative relationships between tokens while allowing for more flexible position modeling.
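A simplified sketch of this computation follows (the function name, tensor layout, and the pairing of adjacent dimensions are our own simplifying choices; real implementations differ in layout and apply the rotation to the query and key vectors inside attention):

```python
import torch


def rotary_embedding(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (seq_len, dim), with dim even."""
    seq_len, dim = x.shape
    # One rotation frequency per pair of dimensions, decreasing geometrically.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    pos = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)              # (seq_len, dim/2): angle = position * frequency
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]              # split dimensions into (even, odd) pairs
    # 2-D rotation of each pair: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)                       # interleave the pairs back to (seq_len, dim)
```

Because a rotated query at position m and a rotated key at position n meet at a combined angle of (n − m)·θ per pair, their inner product is a function of the relative offset n − m only.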

A 2D visualization of rotary embedding; we did not rescale the original embeddings, in order to preserve their lengths. The left plot shows the original embeddings, and the middle shows the rotary positional embedding. The rotary mechanism ensures that no matter where exactly the two words sit in the sequence, they remain at the same distance from each other (whereas with standard absolute positional embeddings, the distance between the same two words changes with their absolute positions). The right plot takes the inner product of the word and positional embeddings. Note: this visualization only shows the first two dimensions of each embedding. An interactive version is provided in the Colab.
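A quick numeric check of this property, reusing the hypothetical `rotary_embedding` sketch above: placing the same two word vectors at positions (0, 2) and then at (5, 7) yields the same inner product after rotation, because only the offset of 2 matters (the Euclidean distance between the two rotated vectors is preserved in the same way).

```python
import torch

torch.manual_seed(0)
d = 64
q_vec, k_vec = torch.randn(d), torch.randn(d)

# Place the same word pair at positions (0, 2) and at (5, 7): the relative offset is 2 in both cases.
seq_a = torch.zeros(8, d); seq_a[0], seq_a[2] = q_vec, k_vec
seq_b = torch.zeros(8, d); seq_b[5], seq_b[7] = q_vec, k_vec

rot_a, rot_b = rotary_embedding(seq_a), rotary_embedding(seq_b)
score_a = rot_a[0] @ rot_a[2]   # inner product at absolute positions (0, 2)
score_b = rot_b[5] @ rot_b[7]   # inner product at absolute positions (5, 7)
print(torch.allclose(score_a, score_b, atol=1e-4))  # True: the score depends only on the offset
```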

The KV (Key-Value) Cache implementation optimizes transformer inference by storing and reusing previously computed key and value tensors. When generating tokens sequentially, instead of recomputing attention for the entire sequence, the implementation caches past key-value pairs and only computes attention for the new token (Dao et al., 2022). This significantly reduces computational overhead, especially for long sequences. As shown in the diagram, the top half illustrates standard attention without caching (computing full Q×K^T for all tokens), while the bottom half demonstrates cached attention (computing Q×K^T only for the new token by reusing stored keys). The implementation includes a KVCache dataclass for managing the cache state and a CachedAttention module that leverages this cache during inference.
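Below is a minimal sketch of those two pieces (single-head, unbatched, with field and argument names that are our own simplifications rather than the lab's exact code): the cache simply appends each new key/value pair, and the attention module projects and attends only for the newest token.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class KVCache:
    """Stores keys/values from previous decoding steps (single head, no batch dim, for clarity)."""
    keys: Optional[torch.Tensor] = None
    values: Optional[torch.Tensor] = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # Append the new token's key/value to the cached ones and return the full history.
        self.keys = k if self.keys is None else torch.cat([self.keys, k], dim=0)
        self.values = v if self.values is None else torch.cat([self.values, v], dim=0)
        return self.keys, self.values


class CachedAttention(nn.Module):
    """Single-head attention that projects and attends only for the newest token."""

    def __init__(self, dim: int):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** -0.5

    def forward(self, x_new: torch.Tensor, cache: KVCache) -> torch.Tensor:
        # x_new: (num_new_tokens, dim) -- a single token per step during generation.
        q = self.wq(x_new)
        k, v = cache.update(self.wk(x_new), self.wv(x_new))  # reuse cached K/V, append the new ones
        scores = (q @ k.T) * self.scale                       # Q x K^T only for the new token
        return F.softmax(scores, dim=-1) @ v
```

During generation, each step feeds only the newly sampled token through `CachedAttention`, so the per-step Q×K^T work grows linearly with the cached sequence length instead of quadratically over the whole sequence; if more than one new token were processed per step, a causal mask over the new tokens would also be needed.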