機(jī)器之心報(bào)道
編輯:Panda
在正在舉辦的半導(dǎo)體行業(yè)會(huì)議 Hot Chips 2025 上,TogetherAI 首席科學(xué)家 Tri Dao 公布了FlashAttention-4
據(jù)介紹,在 Backwell 上,F(xiàn)lashAttention-4 的速度比英偉達(dá) cuDNN 庫(kù)中的注意力核實(shí)現(xiàn)快可達(dá) 22%!
在這個(gè)新版本的 FlashAttention 中,Tri Dao 團(tuán)隊(duì)實(shí)現(xiàn)了兩項(xiàng)關(guān)鍵的算法改進(jìn)。
一、它使用了一種新的在線 softmax 算法,可跳過(guò)了 90% 的輸出 rescaling。
二、為了更好地將 softmax 計(jì)算與張量核計(jì)算重疊,它使用了指數(shù) (MUFU.EX2) 的軟件模擬來(lái)提高吞吐量。
此外,F(xiàn)lashAttention-4 使用的是 CUTLASS CuTe Python DSL,其移植到 ROCm HIP 的難度要高出 10 倍,而 CUDA C++ 移植到 ROCm HIP 則更容易。
有意思的是,Tri Dao 還宣布,在執(zhí)行 A@B+C 計(jì)算時(shí),對(duì)于 Blackwell 上在歸約維度 K 較小的計(jì)算場(chǎng)景中,他使用 CUTLASS CuTe-DSL 編寫的核(kernel)比英偉達(dá)最新的 cuBLAS 13.0 庫(kù)快不少。而在標(biāo)準(zhǔn)矩陣算法 A@B 時(shí),兩者速度總體是相當(dāng)?shù)摹?/p>
據(jù)介紹,他的核通過(guò)使用兩個(gè)累積緩沖區(qū)來(lái)重疊 epilogue,從而擊敗了 cuBLAS。
Semi Analysis 表示,像 Tri Dao 這樣的開發(fā)者是 CUDA 護(hù)城河的核心優(yōu)勢(shì)之一,因?yàn)?Tri Dao 只使用英偉達(dá) GPU,并將其大部分核開源給其他英偉達(dá)開發(fā)者群體。Tri Dao 等研究者均不使用 ROCm AMD GPU 或 Trainium 芯片。
這對(duì)于 AMD 等來(lái)說(shuō)可不是好消息,假如 AMD 希望 Tri Dao 和他的團(tuán)隊(duì)在 ROCm 上實(shí)現(xiàn)算法突破。那么,它就應(yīng)該為 TogetherAI GPU 云服務(wù)上的 AMD GPU 提供優(yōu)惠支持。Semi Analysis 分析說(shuō):「谷歌為 Noam Shazeer 支付了 27 億美元,Zucc 為 OpenAI 工程師支付了 1 億美元,AMD 擁有足夠的現(xiàn)金,可以為 TogetherAI/Tri Dao 支付 5000 萬(wàn)美元來(lái)啟動(dòng) ROCm 生態(tài)系統(tǒng)?!?/p>
FlashAttention最早由 Tri Dao 等人在 2022 年提出,論文標(biāo)題為《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》。
論文地址:https://arxiv.org/pdf/2205.14135
其背景是傳統(tǒng)的注意力機(jī)制因需生成 N×N 的注意力矩陣,在序列長(zhǎng)度 N 增長(zhǎng)時(shí)引發(fā)二次的(quadratic)時(shí)間和內(nèi)存開銷。
而 FlashAttention 強(qiáng)調(diào)「IO-awareness」,不再將注意力矩陣完整載入,而是通過(guò)「tiling+softmax rescaling」策略,將數(shù)據(jù)塊臨時(shí)存入高速緩存(SRAM),在內(nèi)部積累,再寫回高帶寬內(nèi)存(HBM),避免了大量讀寫開銷,內(nèi)存復(fù)雜度得到顯著降低 —— 從 O (N2) 降至 O (N)。
如圖所示,在左圖中,F(xiàn)lashAttention 使用了 tiling 技術(shù)來(lái)防止在(相對(duì)較慢的)GPU HBM 上執(zhí)行很大的 × 注意力矩陣(虛線框)。在外層循環(huán)(紅色箭頭)中,F(xiàn)lashAttention 循環(huán)遍歷 K 和 V 矩陣的塊,并將其加載到快速片上 SRAM 中。在每個(gè)塊中,F(xiàn)lashAttention 循環(huán)遍歷 Q 矩陣的塊(藍(lán)色箭頭),將其加載到 SRAM 中,并將注意力計(jì)算的輸出寫回 HBM。
在右圖中,可以看到相比 GPT-2 上 PyTorch 注意力實(shí)現(xiàn),F(xiàn)lashAttention 速度更快 ——FlashAttention 無(wú)需將大型 × 注意力矩陣讀寫到 HBM,從而將注意力計(jì)算速度提升了 7.6 倍。
整體上,初代 FlashAttention 帶來(lái)的增益也很顯著:在 BERT-large(序列長(zhǎng)度 512)中相比 MLPerf 基線提升訓(xùn)練速度約 15%;GPT-2(序列長(zhǎng)度 1K)提升約 3 倍;在 Long-Range Arena(序列長(zhǎng)度 1K–4K)提升約 2.4 倍。
一年后,FlashAttention-2問(wèn)世,這一次,作者僅 Tri Dao 一人。順帶一提,他還在這一年的晚些時(shí)候與 Albert Gu 共同提出了 Mamba。
論文地址:https://arxiv.org/pdf/2307.08691
其改進(jìn)的焦點(diǎn)是:FlashAttention 已顯著提升性能,但在 GPU 上仍存在低吞吐率的問(wèn)題,僅能達(dá)到理論峰值很低的比例(約 25–40%)。
為此,Tri Dao 提出的解決策略包括:
- 工作劃分優(yōu)化:重新設(shè)計(jì)分塊策略與線程分配,提升并行效率,增加硬件利用率;
- 減少非矩陣運(yùn)算,加快整體執(zhí)行;
- 支持更大 head size(至 256) 及多查詢注意力(MQA) 和分組查詢注意力(GQA),適配更多模型架構(gòu)需求。
結(jié)果,相比初代 FlashAttention,F(xiàn)lashAttention-2 速度提高約 2–4×;在 A100 GPU 上 FP16/BF16 可達(dá)到高至 230 TFLOPs/s,達(dá) PyTorch 標(biāo)準(zhǔn)實(shí)現(xiàn) 9 倍速度提升。參閱機(jī)器之心報(bào)道《比標(biāo)準(zhǔn) Attention 提速 5-9 倍,大模型都在用的 FlashAttention v2 來(lái)了》。
又一年,FlashAttention-3誕生,這一次改進(jìn)的重點(diǎn)是適配 Hopper 架構(gòu),異步與低精度??梢钥吹?,Tri Dao 這一次的名字掛在最后。此時(shí)他雖然還繼續(xù)在普林斯頓大學(xué)任教,但也同時(shí)已經(jīng)是 Together AI 的首席科學(xué)家。
論文地址:https://arxiv.org/pdf/2407.08608
為了能加速在 Hopper GPU 上的注意力,F(xiàn)lashAttention-3 主要采用了三種技術(shù):
- 通過(guò) warp-specialization 重疊整體計(jì)算和數(shù)據(jù)移動(dòng);
- 交錯(cuò)分塊 matmul 和 softmax 運(yùn)算;
- 利用硬件支持 FP8 低精度的不連貫處理。
FlashAttention-3 的速度是 FlashAttention-2 的 1.5-2.0 倍,高達(dá) 740 TFLOPS,即 H100 理論最大 FLOPS 利用率為 75%。使用 FP8,F(xiàn)lashAttention-3 的速度更是接近 1.2 PFLOPS。參閱機(jī)器之心報(bào)道《英偉達(dá)又賺到了!FlashAttention3 來(lái)了:H100 利用率飆升至 75%》。
現(xiàn)在,到了 2025 年,FlashAttention-4準(zhǔn)時(shí)到來(lái),增加了對(duì) Blackwell GPU 的原生支持——之前,想要在 Blackwell 上跑 FlashAttention,如果直接用開源倉(cāng)庫(kù),常常會(huì)遇到編譯錯(cuò)誤、kernel 缺失或性能未優(yōu)化的情況,可用的 Blackwell 加速主要是借助英偉達(dá) Triton/cuDNN 的間接支持。
圖源:https://www.reddit.com/r/LocalLLaMA/comments/1mt9htu/flashattention_4_leak/
此時(shí),F(xiàn)lashAttention 的 GitHub 軟件庫(kù)已經(jīng)積累了超過(guò)1.91 萬(wàn)星。
項(xiàng)目地址:https://github.com/Dao-AILab/flash-attention
目前,Tri Dao 團(tuán)隊(duì)尚未發(fā)布 FlashAttention-4 的技術(shù)報(bào)告,更多細(xì)節(jié)還有待進(jìn)一步揭曉。
https://x.com/tri_dao/status/1960217005446791448
https://x.com/SemiAnalysis_/status/1960070677379133949
https://www.reddit.com/r/LocalLLaMA/comments/1mt9htu/flashattention_4_leak/
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.