Llama3 Laboratory @ MediaTek
In this lab, we explain the essential design choices used by Llama3, including RMSNorm, rotary position embedding, and the KV cache.


RMSNorm was selected as the normalization method because it is computationally cheaper than LayerNorm while remaining just as effective at normalization, as originally demonstrated by Zhang and Sennrich (2019). Its advantages are twofold: it reduces computational cost by eliminating the mean statistic and the re-centering operation present in LayerNorm, while maintaining comparable model performance. This favorable efficiency-to-performance trade-off has led to its adoption in several prominent language models, including Gopher (Rae et al., 2021), LLaMA (Touvron et al., 2023), and LLaMA 2 (Touvron et al., 2023).
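A minimal PyTorch sketch of the idea (not necessarily the exact module used in the lab code) might look like the following: each feature vector is scaled by the reciprocal of its root mean square, with a learnable per-dimension gain and no mean subtraction or bias.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by the RMS of the features.
    Unlike LayerNorm, there is no mean subtraction and no bias term."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / RMS over the last (feature) dimension, with eps for stability
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
```

Dropping the mean/centering step is exactly where the savings come from: only one statistic (the mean square) is computed per vector instead of two.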

Compared to standard (absolute) positional embeddings, rotary position embedding (Su et al., 2021) offers greater flexibility. Its key advantage lies in its relative nature: positional information is injected by rotating the query and key vectors, so the model learns position-dependent representations without being constrained to fixed absolute positions. The computation involves two main steps: first, computing a position-dependent rotation angle for each pair of embedding dimensions; second, applying these rotations to the query and key vectors before their inner product is taken in attention, so that the resulting score depends only on the relative distance between tokens. The accompanying figure visualizes this in three plots: the leftmost shows the original word embeddings, the middle shows the rotary embedding base vectors, and the rightmost shows the final rotated embeddings. The transformation preserves the relative relationships between tokens while allowing for more flexible position modeling.
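One common way to implement the rotation is to treat consecutive pairs of dimensions as complex numbers and multiply them by unit-magnitude rotation factors. The sketch below follows that approach under the simplifying assumption of a per-head input of shape (batch, seq, head_dim); the function names and shapes are illustrative, not necessarily those used in the lab code.

```python
import torch

def precompute_rope_freqs(head_dim: int, max_seq_len: int, base: float = 10000.0) -> torch.Tensor:
    """Complex rotation factors exp(i * m * theta_k) for every position m and
    every dimension pair k. Returns a (max_seq_len, head_dim // 2) complex tensor."""
    theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, theta)                 # (seq, head_dim // 2)
    return torch.polar(torch.ones_like(angles), angles)    # unit-norm complex numbers

def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent dimension pairs of x (batch, seq, head_dim) by the
    position-dependent angles in freqs (seq, head_dim // 2)."""
    x_pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_pairs * freqs                            # broadcast over batch
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

# Usage: rotate the queries of one attention head before computing attention scores.
q = torch.randn(2, 16, 64)                 # (batch, seq, head_dim)
freqs = precompute_rope_freqs(64, 16)
q_rotated = apply_rope(q, freqs)           # same shape, now position-encoded
```

Because both queries and keys are rotated by their own positions, the angle that survives in q·k is the difference of the two positions, which is what makes the encoding relative.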

The KV (Key-Value) Cache implementation optimizes transformer inference by storing and reusing previously computed key and value tensors. When generating tokens sequentially, instead of recomputing attention for the entire sequence, the implementation caches past key-value pairs and only computes attention for the new token (Dao et al., 2022). This significantly reduces computational overhead, especially for long sequences. As shown in the diagram, the top half illustrates standard attention without caching (computing full Q×K^T for all tokens), while the bottom half demonstrates cached attention (computing Q×K^T only for the new token by reusing stored keys). The implementation includes a KVCache dataclass for managing the cache state and a CachedAttention module that leverages this cache during inference.
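The snippet below is a simplified approximation of that design, assuming single-head tensors of shape (batch, seq_len, dim); the actual KVCache and CachedAttention in the lab may differ in details such as multi-head layout and cache pre-allocation.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class KVCache:
    """Holds the keys and values of all previously processed tokens."""
    keys: Optional[torch.Tensor] = None     # (batch, past_len, dim)
    values: Optional[torch.Tensor] = None   # (batch, past_len, dim)

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Append the new token's key/value and return the full history."""
        if self.keys is None:
            self.keys, self.values = k, v
        else:
            self.keys = torch.cat([self.keys, k], dim=1)
            self.values = torch.cat([self.values, v], dim=1)
        return self.keys, self.values

def cached_attention_step(q_new, k_new, v_new, cache: KVCache) -> torch.Tensor:
    """Attention for the newly generated token only: Q covers one position,
    while K and V cover the whole prefix via the cache."""
    k_all, v_all = cache.update(k_new, v_new)                        # (batch, total_len, dim)
    scores = q_new @ k_all.transpose(-2, -1) / k_all.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)                          # (batch, 1, total_len)
    return weights @ v_all                                           # (batch, 1, dim)

# Usage: an autoregressive decoding loop that processes one new token per step.
dim, cache = 64, KVCache()
for _ in range(5):
    q = torch.randn(1, 1, dim)   # projections for the single new token
    k = torch.randn(1, 1, dim)
    v = torch.randn(1, 1, dim)
    out = cached_attention_step(q, k, v, cache)
```

With the cache, each decoding step performs attention over one query row instead of re-running the full Q×K^T for the whole sequence, which is where the savings for long generations come from.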
