以“排序”替代“預(yù)測”:Token Order Prediction(TOP)如何改進語言模型訓(xùn)練
在當(dāng)下以 Next-Token Prediction(NTP)為主導(dǎo)范式的 LLM 訓(xùn)練中,研究者一直在探索如何通過輔助目標(biāo)提升內(nèi)部表征,從而帶來更好的下游泛化。本文聚焦一篇來自 MBZUAI 的早期預(yù)印本,提出以“排序未來詞元的接近度”替代“精確預(yù)測未來詞元”的新思路,將 Multi-Token Prediction(MTP)的“難度過高”轉(zhuǎn)化為 Learning-to-Rank 的“近似排序”問題。核心結(jié)論顯示:在 340M、1.8B 與 7B 三個模型規(guī)模上,作為 NTP 的輔助目標(biāo),Token Order Prediction(TOP)整體優(yōu)于 NTP 與 MTP,且在標(biāo)準(zhǔn) NLP 基準(zhǔn)上更穩(wěn)健。
該研究的關(guān)鍵創(chuàng)新在于將未來詞元的“順序信息”轉(zhuǎn)為可學(xué)習(xí)的排名分布,用 Listwise 排序損失對單個額外的 unembedding 頭進行監(jiān)督,避免為每個未來偏移位置增設(shè)獨立 Transformer 層的額外參數(shù)與計算開銷。與 MTP 相比,TOP 訓(xùn)練更穩(wěn)定、擴展性更好,并在多項通用任務(wù)上呈現(xiàn)出一致的提升趨勢。
1. 基本信息
論文《Predicting the Order of Upcoming Tokens Improves Language Modeling》
作者為 Zayd M. K. Zuhri、Erland Hilman Fuadi 與 Alham Fikri Aji,研究單位為 MBZUAI。
該工作為早期預(yù)印本,版本 v1 發(fā)表于 2025-08-26,arXiv 鏈接為:https://arxiv.org/abs/2508.19228。
作者開源了實現(xiàn)與訓(xùn)練代碼,倉庫地址為:https://github.com/zaydzuhri/token-order-prediction
以 Next-token Prediction, NTP 為中心的自回歸訓(xùn)練范式在語言理解與生成上取得了卓越成績,但其局限與改進空間同樣被持續(xù)討論。近年來,一類代表性嘗試是以Multi-token prediction, MTP 為輔助目標(biāo),讓模型在共享 Backbone 的末端分出多個分支,分別預(yù)測 的精確詞元,從而促使主干表征“向前看”。這類方法在代碼生成、摘要等需要較強“前瞻性”的生成任務(wù)上常有收益,同時還可被用于一定形式的自我推測解碼,以提高推理速度。
然而,MTP 的泛化改善在通用 NLP 基準(zhǔn)上并不穩(wěn)定,尤其在小模型上常見無效甚至退化。經(jīng)驗上,當(dāng)未來偏移步數(shù) 增加時,訓(xùn)練難度顯著提升,且最優(yōu) 難以一刀切地確定。
這從側(cè)面表明,“精確預(yù)測多步未來詞元”的學(xué)習(xí)目標(biāo)可能過難,難度的不匹配反而削弱了其作為輔助目標(biāo)的普適價值。該研究據(jù)此提出一個關(guān)鍵判斷:與其強迫模型“精確命中遠(yuǎn)期詞元”,不如讓模型“學(xué)會排序哪些詞元更快出現(xiàn)”,即把“多步精確預(yù)測”放寬成“近似順序?qū)W習(xí)”的 Learning-to-Rank 問題,從難題轉(zhuǎn)化為“更可學(xué)”的目標(biāo)。
因此,研究的動機在于:在不犧牲“向前看”這一有益歸納偏置的前提下,構(gòu)建更溫和、參數(shù)與計算開銷更友好、且在通用任務(wù)上更穩(wěn)定的輔助訓(xùn)練目標(biāo)。Token Order Prediction, TOP 即在此語境下提出,用單一額外的 unembedding 頭和一個 listwise 排序損失,誘導(dǎo) Backbone 習(xí)得對“未來詞元接近度”的結(jié)構(gòu)化感知。
3. 方法
該方法的設(shè)計出發(fā)點是充分保留“未來結(jié)構(gòu)”的監(jiān)督信號,同時回避 MTP 在多步精確命中上的固有難度與擴展瓶頸。MTP 的每一個未來步都需配備一層單獨的 Transformer 頭,既增加參數(shù)又加重計算;更關(guān)鍵的是,越遠(yuǎn)的未來步難度越高、梯度信號越噪,容易拖累整體優(yōu)化。Token Order Prediction, TOP 將“精確分類到某個未來偏移位置”的任務(wù),替換為“對全詞表進行‘下一個出現(xiàn)時間’的接近度打分”,把硬分類變成軟排序,使訓(xùn)練信號更平滑、覆蓋面更廣、且隨著窗口大小變化不會線性地膨脹參數(shù)。
整體框架上,模型主干仍采用標(biāo)準(zhǔn)自回歸 Transformer,僅在輸出層并聯(lián)兩個線性頭:用于常規(guī) NTP 的 與用于排序監(jiān)督的 。訓(xùn)練時聯(lián)合最小化兩者的損失之和,推理時移除 ,僅保留 ,因此不會改變推理時的架構(gòu)與接口。
具體實現(xiàn)分為目標(biāo)構(gòu)造與損失定義兩部分。
首先,給定輸入序列 、詞表大小 與窗口大小 ,對每個時間步 構(gòu)造長度為 的“接近度向量” 。對任意詞元 ,若其在區(qū)間 的首次出現(xiàn)位置與 的距離為 ,則令 ;若在窗口內(nèi)不可達,則 。這意味著離當(dāng)前更近的“即將出現(xiàn)詞元”將獲得更高分值,從而形成對“未來順序”的隱式排序監(jiān)督。該目標(biāo)可以用如下偽代碼(與論文一致)刻畫其自右向左的一次掃描構(gòu)造過程:
Input: token sequence x (length T+W), vocab size V, window size W Output: target tensor y of shape (T, V) Initialize y[:] = -∞ Initialize next[v] = T + W for all v in [0, V-1] for t from T+W-1 down to 0: if x[t] in vocab: next[x[t]] = t if t < T: for each v in [0, V-1]: d = next[v] - t if 0 < d ≤ W: y[t, v] = W - d

損失函數(shù)借鑒 listwise Learning-to-Rank 思想。設(shè)主干最后一層隱藏狀態(tài)為 ,NTP 頭與 TOP 頭分別為線性映射 與 。標(biāo)準(zhǔn) NTP 損失為
TOP 的 listwise 排序損失將 視作“軟目標(biāo)分?jǐn)?shù)”,在歸一化后與預(yù)測打分的歸一化分布求交叉熵:
最終優(yōu)化目標(biāo)是兩者之和:
從表征學(xué)習(xí)角度看, 以“接近度排名”為監(jiān)督,迫使 捕獲“短期將出現(xiàn)哪些詞元及其大致先后”的結(jié)構(gòu)性信息;這種信息與 NTP 的“下一個詞元概率”目標(biāo)高度一致,二者在“靠近下一個詞元”的方向上形成合力,從而以較小的額外參數(shù)與顯著低于 MTP 的計算開銷,增強主干的建模能力。實現(xiàn)上,作者使用了融合的 Triton kernel 將 unembedding 與損失計算在 block 級別一次完成,幾乎不引入額外吞吐?lián)p失;由于 TOP 僅新增一個與 同形的線性層,參數(shù)與顯存開銷遠(yuǎn)低于隨 線性增長的 MTP 多層頭。
值得強調(diào)的是,若僅以 TOP 訓(xùn)練而移除 NTP,則推理時只能進行貪心生成,缺乏概率采樣的靈活性;因此該工作定位 TOP 為“輔助目標(biāo)”,而非替代 NTP 的主目標(biāo)。
4. 實驗與發(fā)現(xiàn)
實驗在三個模型規(guī)模上系統(tǒng)比較了 NTP、MTP 與 TOP:約 340M、1.8B 與 7B 參數(shù)。訓(xùn)練數(shù)據(jù)來自 FineWeb-Edu 的“sample-100BT”子集,340M 訓(xùn)練 52B tokens,1.8B 與 7B 訓(xùn)練 104B tokens。統(tǒng)一采用序列長度 4096、RoPE 、詞表大小 32k、未綁權(quán)重(untied embeddings),優(yōu)化器為 AdamW,余弦學(xué)習(xí)率調(diào)度與適當(dāng) warmup,并在 MTP 設(shè)置中使用 4 個未來步。為保證可重復(fù)性,論文詳細(xì)給出了每個規(guī)模的層數(shù)/隱藏維度/頭數(shù)、學(xué)習(xí)率與 batch 配置;實現(xiàn)基于 Flame 與 flash-linear-attention。
評測覆蓋 8 個標(biāo)準(zhǔn) NLP 基準(zhǔn):LAMBADA(準(zhǔn)確率與困惑度)、HellaSwag、ARC Challenge、PIQA、SciQ、Social IQa、NaturalQuestions Open、TriviaQA(Exact Match)。核心發(fā)現(xiàn)可以概括為三個層面。第一,在幾乎所有規(guī)模與多數(shù)任務(wù)上,TOP 相較基線 NTP 與 MTP 都帶來一致提升。例如,LAMBADA 上 TOP 在 340M/1.8B/7B 的準(zhǔn)確率分別優(yōu)于 NTP(36.35%→37.07%、49.58%→50.34%、55.89%→57.03%),困惑度也同步下降(340M:30.34→28.76;1.8B:11.38→11.19;7B:7.97→7.64)。HellaSwag 的歸一化準(zhǔn)確率在三個規(guī)模上亦全面提升(340M:42.53%→43.57%;1.8B:60.05%→60.45%;7B:67.44%→68.73%)。第二,MTP 在通用理解類任務(wù)上并不穩(wěn)定,尤其 7B 規(guī)模時常出現(xiàn)退化;TOP 則隨規(guī)模增大持續(xù)受益,在 TriviaQA(EM)上差距尤為明顯(1.8B:11.85→18.93;7B:24.28→30.90)。第三,盡管 TOP 的訓(xùn)練階段 NTP 頭上記錄到的訓(xùn)練損失略高于純 NTP(提示正則化效應(yīng)與更少過擬合),但在評測困惑度與準(zhǔn)確率上表現(xiàn)更佳,指向更強的泛化。
從統(tǒng)計與實際意義討論來看,這些改進不僅體現(xiàn)在平均數(shù)上,也具備一致的跨任務(wù)、跨規(guī)??蛇w移性。尤其是當(dāng)任務(wù)偏向理解與檢索式問答(如 NQ Open、TriviaQA)時,TOP 在較大規(guī)模上展現(xiàn)出明顯優(yōu)勢,說明“順序接近度”的結(jié)構(gòu)性監(jiān)督與語言模型的長程一致性建模、知識尋址與答案定位存在內(nèi)在耦合。相對地,個別社會常識類任務(wù)(如 Social IQa)在 7B 上略有波動,提示未來可以在窗口機制、權(quán)重共享或損失權(quán)重上做更細(xì)粒度的調(diào)節(jié)與消融。
值得注意的是,MTP 的小規(guī)模模型在該復(fù)現(xiàn)實驗中并非處處落后,這一結(jié)果與部分先前報告形成互補;但隨著規(guī)模擴大,MTP 在通用任務(wù)上的弱勢更加明顯。TOP 則展現(xiàn)出“越大越好”的單調(diào)趨勢,符合“排序監(jiān)督更易優(yōu)化、對主干更友好”的設(shè)計初衷。
5. 結(jié)論與展望
該研究以 Token Order Prediction(TOP)為核心貢獻,給出了一個輕量、可擴展且與 NTP 強一致的輔助訓(xùn)練目標(biāo)。通過將“未來詞元精確預(yù)測”的目標(biāo),替換為“未來詞元接近度排序”的 listwise 學(xué)習(xí),TOP 在相近或更小的額外開銷下,更穩(wěn)定地提升了 LLM 在通用 NLP 基準(zhǔn)上的表現(xiàn)。實驗顯示,隨著參數(shù)規(guī)模增大,TOP 的收益進一步擴大,這為大模型預(yù)訓(xùn)練中的輔助目標(biāo)設(shè)計提供了新的方向。
說明:本文基于作者公開的早期預(yù)印本撰寫,實驗細(xì)節(jié)與擴展結(jié)果以后續(xù)版本為準(zhǔn)。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.