#P5131. 第3题-Certainty Forcing 训练损失计算
-
1000ms
Tried: 11
Accepted: 1
Difficulty: 7
所属公司 :
华为
时间 :2026年7月1日-AI方向
第3题-Certainty Forcing 训练损失计算
解题思路
先对每个 token 根据 logits 计算稳定版 softmax。
令:
mi=jmaxzi,j题目内容
自回归语言模型(AR LLM)通常按从左到右的顺序逐个生成 token。
而 Diffusion LLM 的特点是:它会同时对一批位置进行预测,再逐步把整条序列“修正”出来,因此天然更适合并行解码。
但在实际中,一个关键问题是:很多 token 虽然已经朝正确答案靠近了,但它们的预测分布还不够集中,置信度提升得不够快,这会拖慢并行解码。
dParallel(dParallel: Learnable Parallel Decoding for dLLMs)提出的 Certainty Forcing 的核心思路,就是在训练时对部分 token 额外加入一个“让分布更确定”的约束,从而让更多 token 更早达到高置信状态,减少解码步数。
本题考虑一个简化版训练目标。
对于每个 token,模型会输出它在整个词表上的 logits。
普通训练使用 Cross Entropy;如果对某个token施加 Certainty Forcing,则额外加入该token在整个词表上的预测分布熵(entropy)作为训练增强。
但并不是所有 token 都适合被加入这种额外训练损失。比如一些简单 token(如 the、a)如果被过度强化,可能会影响训练稳定性。
因此,我们预先为每个 token 计算了一个“稳定性影响值”,表示若对它加入 Certainty Forcing,可能对训练稳定性带来的影响。现在需要在总影响不超过限制的前提下,选择一部分 token 加入 Certainty Forcing,并计算最终训练损失。
数学定义
设共有 N 个token,词表大小为 V。 对于第 i 个token,给定:
- logits:zi,0, zi,1, …, zi,V−1
- 正确标签:yi
- 稳定性影响值:ci
1. softmax 概率
第 i 个token在词表第 j 个词上的预测概率为:
pi,j=∑t=0V−1ezi,tezi,j实际实现时,为了避免数值不稳定,建议使用等价的写法。先令:
mi=0≤t<Vmaxzi,t再计算:
pi,j=∑t=0V−1ezi,t−miezi,j−mi两种写法数学上等价,但后者更稳定。
2. Cross Entropy
第 i 个token的基础交叉熵损失为:
ℓi=−logpi,yi其中 yi 是正确标签。
3. Certainty Forcing loss
第 i 个token的额外约束项定义为该预测分布的熵:
Hi=−j=0∑V−1pi,jlogpi,j注意:
- ℓi 只看正确标签对应的概率;
- Hi 看整个词表上的概率分布;
- 两者不是同一个量;
- Hi 不是输入给定,需要你根据 logits 先计算 softmax 后再求出。
4. token 的选择价值
本题规定: 第 i 个token的选择价值等于它的 Cross Entropy:
vi=ℓi也就是说,Cross Entropy 越大,越优先考虑对该 token 加入 Certainty Forcing。
5. 选择约束
设被选中的token集合为 S,要求满足:
i∈S∑ci≤B其中 B 是总稳定性影响上限。
在满足约束的前提下,选择一个集合 S ,使得:
i∈S∑vi=i∈S∑li6.最终训练损失
最终总 loss 定义为:
loss=i=1∑Nℓi+λ⋅i∈S∑Hi其中 λ 为给定系数。
任务要求
你需要完成以下步骤:
- 根据 logits 计算每个 token 的 softmax 概率;
- 计算每个 token 的 Cross Entropy;
- 计算每个 token 的熵 Hi;
- 在总稳定性影响不超过 B 的条件下,选择一部分 token,使被选中 token 的 Cross Entropy 之和最大;
- 输出最终总 loss。
输入描述
第一行输入四个数: N V B λ 分别表示:
- N:token 数量
- V:词表大小
- B:总稳定性影响上限
- λ:Certainty Forcing 系数
接下来共有 N 行,每行输入:
zi,0 zi,1 … zi,V−1 yi ci
含义为:
- 前 V 个实数:该token的 logits(保留2位小数)
- 接着 1 个整数:正确标签 yi
- 接着 1 个整数:稳定性影响值 ci
输出描述
输出一行,一个实数,表示最终总 loss。
结果四舍五入保留 2 位小数。
样例1
输入
4 4 4 0.8
3.00 1.00 0.00 -1.00 0 2
0.00 1.00 3.00 0.00 2 2
1.00 1.00 1.00 1.00 1 3
2.00 0.00 1.00 2.00 3 1
输出
4.75
说明
共有 4 个token,词表大小为 4,总稳定性影响上限 B=4,λ=0.8。
Token 1
logits 为 [3.00, 1.00, 0.00, −1.00],正确标签 y1 = 0。
softmax 概率:
p1≈[0.830953, 0.112457, 0.041371, 0.015219]Cross Entropy:
ℓ1≈0.185182熵:
H1≈0.595087影响值 c1 = 2。
Token 2
logits 为 [0.00, 1.00, 3.00, 0.00],正确标签 y2 = 2。
softmax 概率:
p2≈[0.040316, 0.109591, 0.809776, 0.040316]Cross Entropy:
ℓ2≈0.210998熵:
H2≈0.672078影响值 c2 = 2。
Token 3
logits 为 [1.00, 1.00, 1.00, 1.00],正确标签 y3 = 1。
softmax 概率:
p3=[0.25, 0.25, 0.25, 0.25]Cross Entropy:
ℓ3≈1.386294熵:
H3≈1.386294影响值 c3 = 3。
Token 4
logits 为 [2.00, 0.00, 1.00, 2.00],正确标签 y4 = 3。
softmax 概率:
p4≈[0.365529, 0.049217, 0.130025, 0.455229]Cross Entropy:
ℓ4≈0.791756熵:
H4≈1.172668影响值 c4 = 1。
基础 loss
Lbase≈0.185182+0.210998+1.386294+0.917576=2.700050
选择 token
各token的选择价值就是对应的 Cross Entropy:
- token 1:0.185182
- token 2:0.210998
- token 3:1.386294
- token 4:0.917576
总影响上限 B=4。
可行方案中:
- 选 token 3:价值 1.386294,影响 3
- 选 token 4:价值 0.9175766,影响 1
- 选 token 3 和 token 4:价值 2.303870,影响 4
- 选 token 2 和 token 4:价值 1.128674,影响 3
- 选 token 1 和 token 2:价值 0.396180,影响 4
最优方案为选择 token 3 和 token 4。
因此:
i∈S∑Hi=H3+H4≈1.386294+1.172668=2.558962最终总 loss
L=Lbase+λi∈S∑Hi L≈2.700050+0.8×2.558962=4.747220四舍五入到 2 位小数后为: 4.75
样例2
输入
3 3 3 0.5
2.00 1.00 0.00 0 2
0.00 2.00 1.00 1 1
1.00 1.00 1.00 2 2
输出
2.88
说明
共有 3 个token,词表大小为 3,总稳定性影响上限 B=3,λ=0.5。
Token 1
logits 为 [2.00, 1.00, 0.00],正确标签 y1 = 0。
softmax 概率:
p1≈[0.665241, 0.244728, 0.090031]Cross Entropy:
ℓ1=−log(0.665241)≈0.407606熵:
H1≈0.832396影响值 c1 = 2。
Token 2
logits 为 [0.00, 2.00, 1.00],正确标签 y2 = 1。
softmax 概率:
p2≈[0.090031, 0.665241, 0.244728]Cross Entropy:
ℓ2=−log(0.665241)≈0.407606熵:
H2≈0.832396影响值 c2 = 1。
Token 3
logits 为 [1.00, 1.00, 1.00],正确标签 y3 = 2。
softmax 概率:
p3=[0.333333, 0.333333, 0.333333]Cross Entropy:
ℓ3=−log(0.333333)≈1.098612熵:
H3≈1.098612影响值 c3 = 2。
基础 loss
Lbase=0.407606+0.407606+1.098612=1.913824
选择 token
本题中 token 的选择价值就是它的 Cross Entropy,因此:
- token 1 的价值约为 0.407606
- token 2 的价值约为 0.407606
- token 3 的价值约为 1.098612
总影响上限 B=3。
可行方案中:
- 选 token 1:价值 0.407606,影响 2
- 选 token 2:价值 0.407606,影响 1
- 选 token 3:价值 1.098612,影响 2
- 选 token 1 和 2:价值 0.815212,影响 3
- 选 token 2 和 3:价值 1.506218,影响 3
最优方案为选择 token 2 和 token 3。
因此:
i∈S∑Hi=H2+H3≈0.832396+1.098612=1.931008最终总 loss
L=Lbase+λi∈S∑Hi L≈1.913824+0.5×1.931008=2.879328四舍五入到 2 位小数后为: 2.88
提示
- 计算 softmax 时,建议使用上面的数值稳定写法。
- 对数使用自然对数。
- 题目保证最优总价值唯一,因此最终答案唯一。