Introduction
On the Latent Space podcast, Alessio, Partner and CTO-in-Residence at Decibel Partners, hosts a discussion with guest Tri Dao. Tri recently completed his PhD at Stanford and is a lead author of the FlashAttention paper, a pivotal work of the Transformer era. He shares insights into efficient Transformer training, inference, and long-range sequence models. He is set to become an assistant professor of Computer Science at Princeton in the coming year, and he recently joined Together, the company behind RedPajama, as Chief Scientist.
Tri shares a personal tidbit: he initially intended to major in economics during his early days at Stanford, but after taking math classes he shifted his focus to mathematics. That decision played a significant role in steering him toward his current career in math, computer science, and AI research.
The discussion delves deep into FlashAttention and its recently released successor, FlashAttention 2. The key property of FlashAttention is that its memory usage scales linearly with sequence length, rather than quadratically as in standard attention; the amount of computation is still quadratic. Tri emphasizes the importance of avoiding approximation in attention mechanisms. He explains that while other methods focus on approximating attention, their own objective was wall-clock efficiency and memory savings while still computing exact attention. Their approach delivered a wall-clock speedup of 2 to 4 times, which makes it possible to train on sequences 2 to 4 times longer without added cost.
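To make the linear-versus-quadratic memory point concrete, here is a rough back-of-the-envelope sketch. It compares the size of a full attention score matrix with the size of a single per-row statistic for one attention head and one sequence. The function names and the byte sizes (2 bytes per score, as in fp16, and 4 bytes per row statistic, as in fp32) are assumptions chosen for illustration, not figures from the episode.

```python
def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to store a full seq_len x seq_len score matrix
    (one head, one sequence). Grows quadratically with seq_len."""
    return seq_len * seq_len * bytes_per_element


def row_statistics_bytes(seq_len, bytes_per_element=4):
    """Bytes needed to store one scalar per query row instead of the
    full matrix. Grows linearly with seq_len."""
    return seq_len * bytes_per_element


if __name__ == "__main__":
    for n in (2048, 4096, 8192):
        print(n, score_matrix_bytes(n), row_statistics_bytes(n))
```

Doubling the sequence length quadruples the first number but only doubles the second, which is why avoiding the full score matrix matters more and more as sequences get longer.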
A significant aspect of their innovation involves merging ideas from both machine learning and systems design, particularly kernel fusion. This technique reduces memory reads and writes, which account for most of the time spent in attention. While kernel fusion has its merits, Tri acknowledges that it may limit flexibility, especially for researchers keen on tweaking the attention computation. The main benefit comes from doing more of the work in fast on-chip memory (SRAM) rather than in the much larger but slower high-bandwidth memory (HBM), capitalizing on the asymmetric memory hierarchy present in GPUs.
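The idea of processing attention in small pieces that fit in fast memory can be illustrated with a minimal pure-Python sketch for a single query vector. It visits keys and values one block at a time and keeps only a running maximum, a running normalizer, and a running output, so the full row of scores is never stored, yet the result matches standard softmax attention exactly. This is only an illustration of the tiling idea under simplifying assumptions: the function names and `block_size` are hypothetical, plain Python lists stand in for tensors, and nothing here models actual SRAM or HBM or the real fused GPU kernel.

```python
import math


def _score(q, k):
    # Scaled dot product between one query and one key.
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))


def reference_attention(q, keys, values):
    """Standard attention for one query: compute every score, then softmax."""
    scores = [_score(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Same output, but keys and values are visited one block at a time."""
    dim = len(values[0])
    running_max = float("-inf")   # largest score seen so far
    running_sum = 0.0             # softmax normalizer, relative to running_max
    acc = [0.0] * dim             # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_block = keys[start:start + block_size]
        v_block = values[start:start + block_size]
        scores = [_score(q, k) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old maximum.
        rescale = math.exp(running_max - new_max)
        running_sum *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(scores, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    print(reference_attention(q, keys, values))
    print(tiled_attention(q, keys, values, block_size=2))
```

The two functions agree up to floating-point rounding for any block size, which reflects the point Tri makes about exactness: the computation is reorganized to reduce memory traffic, not approximated.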