P4490.Softmax 与交叉熵损失
题目描述:
你需要实现一个函数,计算 Softmax 概率分布和 交叉熵损失(Cross Entropy Loss)。
给定一组未归一化的对数概率(logits)列表 logits 和对应的 One-hot 标签列表 labels。
请你计算并返回:
- 经过 Softmax 变换后的概率列表
probs。
- 模型预测与真实标签之间的交叉熵损失值
loss。
数学公式
设 z 为 logits 列表,y 为 labels 列表,长度均为 n。
-
Softmax 概率 pi 的计算公式:
pi=Softmax(zi)=∑j=1nezjezi
-
交叉熵损失 L 的计算公式:
L=−i=1∑nyilog(pi)
(注意:对数底数为 natural logarithm, 即 ln)
输入参数:
logits: 包含 n 个浮点数的列表。
labels: 包含 n 个整数(0 或 1)的列表,且其中恰好有一个元素为 1(One-hot 编码)。
返回值:
-
返回一个元组 (probs, loss):
probs: 包含 n 个浮点数的列表,表示 Softmax 概率。
loss: 一个浮点数,表示交叉熵损失。
示例 1:
输入:
logits = [2.0, 1.0, 0.1]
labels = [0, 1, 0]
输出:
probs = [0.6590, 0.2424, 0.0986]
loss = 1.4170
解释:
-
计算 Softmax 分母:e2+e1+e0.1≈7.389+2.718+1.105=11.212
-
计算概率:
p_0≈7.389/11.212≈0.6590
p_1≈2.718/11.212≈0.2424
p_2≈1.105/11.212≈0.0986
-
计算损失:由于标签是 [0, 1, 0],只有第 2 项参与计算。
L=−1⋅ln(0.2424)≈1.4170
提示
- 1≤n≤1000
- −100≤logits[i]≤100
labels 长度等于 logits 长度,且严格符合 One-hot 编码(只有一个 1,其余为 0)。