这道题要做两件事:
logits(未归一化分数)做 Softmax 变换,得到每一类的概率分布 probs;probs 和 One-hot 标签 labels 计算 交叉熵损失(Cross Entropy Loss)。关键点有两个:
exp(logits[i]) 可能溢出;下面按步骤拆开讲。
Softmax 定义为:
pi=∑jezjezi直接计算 exp(z_i) 可能出现数值溢出,因此常见做法是减去最大值:
因为对所有 i 来说:
$$\frac{e^{z_i}}{\sum_j e^{z_j}} = \frac{e^{z_i - C}}{\sum_j e^{z_j - C}}$$对任意常数 C 都成立,我们取 C=max(z),就能避免指数过大。
代码对应为:
max_logit = max(logits) # 最大值
exp_shifted = [math.exp(x - max_logit) for x in logits] # 先减再 exp
sum_exp = sum(exp_shifted) # 分母
probs = [v / sum_exp for v in exp_shifted] # Softmax 概率
最终 probs 是一个与 logits 等长的列表,且满足:
交叉熵定义为:
L=−i=1∑nyilog(pi)在本题标签是 One-hot 的前提下,只有真实类别那一项 yk=1,其它都是 0,于是损失就简化为:
L=−log(pk)因此在代码里可以简单写成:
eps = 1e-15 # 避免 log(0)
loss = 0.0
for y, p in zip(labels, probs):
if y == 1:
loss -= math.log(max(p, eps))
这里 max(p, eps) 是为了防止 p 特别小导致 log(0) 或数值问题。
import math
from typing import List, Tuple
class Solution:
def softmax_cross_entropy(self, logits: List[float], labels: List[int]) -> Tuple[List[float], float]:
# ---------- 1. 计算 Softmax 概率(数值稳定版) ----------
# 减去最大值,防止 exp 溢出
max_logit = max(logits)
exp_shifted = [math.exp(x - max_logit) for x in logits]
sum_exp = sum(exp_shifted)
probs = [v / sum_exp for v in exp_shifted]
# ---------- 2. 计算交叉熵损失 ----------
# 损失: L = -sum(y_i * log(p_i))
# 在 one-hot 情况下,只会取到 y_i=1 对应的那一项
eps = 1e-15
loss = 0.0
for y, p in zip(labels, probs):
if y == 1:
# 防止 log(0)
loss -= math.log(max(p, eps))
return probs, loss
你需要实现一个函数,计算 Softmax 概率分布和 交叉熵损失(Cross Entropy Loss)。
给定一组未归一化的对数概率(logits)列表 logits 和对应的 One-hot 标签列表 labels。
请你计算并返回:
probs。loss。设 z 为 logits 列表,y 为 labels 列表,长度均为 n。
Softmax 概率 pi 的计算公式:
$$p_i = \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$交叉熵损失 L 的计算公式:
L=−i=1∑nyilog(pi)(注意:对数底数为 natural logarithm, 即 ln)
logits: 包含 n 个浮点数的列表。labels: 包含 n 个整数(0 或 1)的列表,且其中恰好有一个元素为 1(One-hot 编码)。返回一个元组 (probs, loss):
probs: 包含 n 个浮点数的列表,表示 Softmax 概率。loss: 一个浮点数,表示交叉熵损失。输入:
logits = [2.0, 1.0, 0.1]
labels = [0, 1, 0]
输出:
probs = [0.6590, 0.2424, 0.0986]
loss = 1.4170
解释:
计算 Softmax 分母:$e^2 + e^1 + e^{0.1} \approx 7.389 + 2.718 + 1.105 = 11.212$
计算概率: p_0≈7.389/11.212≈0.6590
p_1≈2.718/11.212≈0.2424
p_2≈1.105/11.212≈0.0986
计算损失:由于标签是 [0, 1, 0],只有第 2 项参与计算。
L=−1⋅ln(0.2424)≈1.4170
labels 长度等于 logits 长度,且严格符合 One-hot 编码(只有一个 1,其余为 0)。