#P3713. 第3题-大模型分词
-
1000ms
Tried: 1086
Accepted: 237
Difficulty: 5
所属公司 :
华为
时间 :2025年9月17日-AI岗
-
算法标签>动态规划
第3题-大模型分词
解题思路
本题要求把一段不含空格的小写英文字符串切分为若干“词元”,并最大化: 所有词元的置信度分数之和 + 相邻词元之间的转移加分之和。 若无法完全用词表中的词覆盖整句,则输出 0。
核心算法:动态规划(DP)
-
设原串为
text,长度为L。 -
用哈希表保存词表
score[w](词w的置信度分数)与转移加分bonus[u][v](从词u到v的转移分)。 -
令
dp[i]表示能切到下标i(前缀text[0:i])的所有方案中,“以某个词结尾”的最优分数集合。- 具体用:
dp[i]是一个映射{last_word -> best_score},表示前缀text[0:i]且最后一个词为last_word的最优总分。
- 具体用:
-
转移:枚举以
i结尾的词w = text[j:i](w必须在词表中),再看前一段的结尾:- 若
j == 0:这是首个词,dp[i][w] = max(dp[i][w], score[w])。 - 若
j > 0:需要从dp[j]的每个候选u转移,dp[i][w] = max(dp[i][w], dp[j][u] + score[w] + bonus[u][w](若无则为0))。
- 若
-
答案是
dp[L]中所有值的最大值;若dp[L]为空(无法完整切分),输出 0。 -
为降复杂度,可预先计算词表中最长词长
maxLen,只枚举长度不超过maxLen的后缀。
为什么不能用贪心?
- 因为转移加分依赖于相邻两个词,局部最优(当前词分高)并不一定带来全局最优(可能与下一词的转移分差)。因此需要 DP 统筹考虑上下文。
复杂度分析
- 设原串长度
L ≤ 100,词表大小N ≤ 100,最长词长maxLen ≤ L。 - 外层位置
i共L次;对每个i,仅尝试长度 ≤maxLen的后缀,近似O(maxLen); - 对每个合法后缀
w,需遍历dp[j]的状态数(不超过词表大小N)。 - 因此时间复杂度近似为 O(L × maxLen × N),在给定数据范围内完全可行。
- 空间复杂度:
dp存每个位置最多N个词状态,故 O(L × N)。
代码实现
Python
import sys
from ast import literal_eval
def solve(text, vocab_list, trans_list):
# 构建词表分数
score = {}
max_len = 0
for w, p in vocab_list:
score[w] = p
if len(w) > max_len:
max_len = len(w)
# 构建转移加分表
bonus = {}
for u, v, x in trans_list:
if u not in bonus:
bonus[u] = {}
bonus[u][v] = x
L = len(text)
# dp[i]: dict {last_word: best_score} 覆盖 text[0:i]
dp = [dict() for _ in range(L + 1)]
# 枚举前缀终点 i
for i in range(1, L + 1):
# 只需要尝试不超过词表最长长度的后缀
up = min(max_len, i)
for l in range(1, up + 1):
w = text[i - l:i]
if w not in score:
continue
base = score[w]
j = i - l
if j == 0:
# 首词
dp[i][w] = max(dp[i].get(w, float("-inf")), base)
else:
if not dp[j]:
continue
# 从所有可能的前一词转移
for u, val in dp[j].items():
add = bonus.get(u, {}).get(w, 0)
cand = val + base + add
if cand > dp[i].get(w, float("-inf")):
dp[i][w] = cand
if not dp[L]:
return 0
return max(dp[L].values())
def main():
data = sys.stdin.read().strip().splitlines()
idx = 0
text = data[idx].strip(); idx += 1
n = literal_eval(data[idx].strip()); idx += 1
vocab_list = []
for _ in range(n):
parts = data[idx].strip().split()
w = parts[0]
p = literal_eval(parts[1])
vocab_list.append((w, p))
idx += 1
m = literal_eval(data[idx].strip()); idx += 1
trans_list = []
for _ in range(m):
parts = data[idx].strip().split()
u, v = parts[0], parts[1]
x = literal_eval(parts[2])
trans_list.append((u, v, x))
idx += 1
ans = solve(text, vocab_list, trans_list)
print(ans)
if __name__ == "__main__":
main()
Java
import java.io.*;
import java.util.*;
public class Main {
static int solve(String text, List<String> words, List<Integer> scores,
List<String[]> trans, List<Integer> adds) {
int L = text.length();
// 构建词表分数与最长词长
Map<String, Integer> score = new HashMap<>();
int maxLen = 0;
for (int i = 0; i < words.size(); i++) {
score.put(words.get(i), scores.get(i));
maxLen = Math.max(maxLen, words.get(i).length());
}
// 构建转移加分表 bonus[u][v] = x
Map<String, Map<String, Integer>> bonus = new HashMap<>();
for (int i = 0; i < trans.size(); i++) {
String u = trans.get(i)[0];
String v = trans.get(i)[1];
int x = adds.get(i);
if (!bonus.containsKey(u)) bonus.put(u, new HashMap<String, Integer>());
bonus.get(u).put(v, x);
}
// dp[i]: 以某个词结尾覆盖 text[0:i] 的最优分数集合
@SuppressWarnings("unchecked")
HashMap<String, Integer>[] dp = new HashMap[L + 1];
for (int i = 0; i <= L; i++) dp[i] = new HashMap<String, Integer>();
for (int i = 1; i <= L; i++) {
int up = Math.min(maxLen, i);
for (int len = 1; len <= up; len++) {
String w = text.substring(i - len, i);
if (!score.containsKey(w)) continue;
int base = score.get(w);
int j = i - len;
if (j == 0) {
int prev = dp[i].getOrDefault(w, Integer.MIN_VALUE / 4);
dp[i].put(w, Math.max(prev, base));
} else {
if (dp[j].isEmpty()) continue;
for (Map.Entry<String, Integer> e : dp[j].entrySet()) {
String u = e.getKey();
int val = e.getValue();
int add = 0;
if (bonus.containsKey(u)) {
add = bonus.get(u).getOrDefault(w, 0);
}
int cand = val + base + add;
int prev = dp[i].getOrDefault(w, Integer.MIN_VALUE / 4);
if (cand > prev) dp[i].put(w, cand);
}
}
}
}
if (dp[L].isEmpty()) return 0;
int ans = Integer.MIN_VALUE / 4;
for (int v : dp[L].values()) ans = Math.max(ans, v);
return ans;
}
public static void main(String[] args) throws Exception {
// 使用行读取 + 简单解析(替换字符/输入流组合)
BufferedReader br = new BufferedReader(new InputStreamReader(System.in, "UTF-8"));
String text = br.readLine().trim();
int n = Integer.parseInt(br.readLine().trim());
List<String> words = new ArrayList<>();
List<Integer> scores = new ArrayList<>();
for (int i = 0; i < n; i++) {
String line = br.readLine().trim();
String[] parts = line.split("\\s+");
words.add(parts[0]);
scores.add(Integer.parseInt(parts[1]));
}
int m = Integer.parseInt(br.readLine().trim());
List<String[]> trans = new ArrayList<>();
List<Integer> adds = new ArrayList<>();
for (int i = 0; i < m; i++) {
String line = br.readLine().trim();
String[] parts = line.split("\\s+");
trans.add(new String[]{parts[0], parts[1]});
adds.add(Integer.parseInt(parts[2]));
}
int ans = solve(text, words, scores, trans, adds);
System.out.println(ans);
}
}
C++
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string text;
if (!getline(cin, text)) return 0;
string line;
// 读取 n
getline(cin, line);
int n = stoi(line);
// 词表
unordered_map<string, int> score;
int maxLen = 0;
for (int i = 0; i < n; i++) {
getline(cin, line);
// 简单分割(替换字符+输入流)
stringstream ss(line);
string w; int p;
ss >> w >> p;
score[w] = p;
maxLen = max(maxLen, (int)w.size());
}
// 读取 m
getline(cin, line);
int m = stoi(line);
// 转移加分表
unordered_map<string, unordered_map<string, int>> bonus;
for (int i = 0; i < m; i++) {
getline(cin, line);
stringstream ss(line);
string u, v; int x;
ss >> u >> v >> x;
bonus[u][v] = x;
}
int L = (int)text.size();
// dp[i]: map<最后一个词, 最优总分>
vector<unordered_map<string, int>> dp(L + 1);
for (int i = 1; i <= L; i++) {
int up = min(maxLen, i);
for (int len = 1; len <= up; len++) {
string w = text.substr(i - len, len);
auto itw = score.find(w);
if (itw == score.end()) continue;
int base = itw->second;
int j = i - len;
if (j == 0) {
auto it = dp[i].find(w);
if (it == dp[i].end()) dp[i][w] = base;
else it->second = max(it->second, base);
} else {
if (dp[j].empty()) continue;
for (auto &pr : dp[j]) {
const string &u = pr.first;
int val = pr.second;
int add = 0;
auto itu = bonus.find(u);
if (itu != bonus.end()) {
auto itv = itu->second.find(w);
if (itv != itu->second.end()) add = itv->second;
}
int cand = val + base + add;
auto it = dp[i].find(w);
if (it == dp[i].end()) dp[i][w] = cand;
else it->second = max(it->second, cand);
}
}
}
}
if (dp[L].empty()) {
cout << 0 << "\n";
} else {
int ans = INT_MIN / 4;
for (auto &pr : dp[L]) ans = max(ans, pr.second);
cout << ans << "\n";
}
return 0;
}
题目内容
您正在为一种罕见的语言构建一个专用的大语言模型。由于训练样本缺失,传统BPE等标准的分词器效果不佳,使得大模型推理生成的句子不理想。
幸运的是,一位语言学家为罕见的语言的已知词根和词缀(我们统称为“词元"或“Token”)都标注了一个“置信度”分数,这个分数代表了该词元作为一个“独立单位”的合理性,同时,语言学家还总结出了一个转移分数表,表示当前词元选择对下一个词元"置信度"的影响。
您的任务是设计并实现一个“最优分词器”,它能将输入的罕见语言句子(一个不含空格的英文小写字符多也串)切分成一系列词元,并使得所有词元的置信度分数之和达到最大,从而帮助大语言模型在后续处理中能够输出更合理的句了
输入描述
第一行输入待分词的字符串 text,假设只包含英文小写字母;
接着输入词典词条数 n;
然后输入n行,每一行包含一个单词和对应的分值,以空格分隔。
第 n+3 行为转移分数的个数 m。
随后m行为转移分数数据。包括起始词、下一个词、转移分数加分X。以空格分隔。
参数范围说明:
- 0<len(text)≤100
- −100≤ 词典中单词的得分 ≤100
- −100≤ 词汇表置信度分数P≤100
- 输入的字符串都是英文小写字母
- 0<词汇表大小 n≤100
输出描述
返回最高的分词得分,若根据已知间汇表无法拆分则返回0、我们约定若切分成一系列词元中含有不在已知词汇表中的词,则最终得分为0。
样例1
输入
applepie
2
pen 3
apple 10
2
pen apple 5
pie apple 2
输出
0
说明
text中的字符不能和词典词条匹配出切分结果,无法计算得分。
样例2
输入
goodeats
4
good 15
goo 12
deats 14
eats 10
1
good eats -5
输出
26
说明
切分为["good","eats"] 的总分=15+10−5=20;
切分为 ["goo","deats"] 的总分=12+14=26;
所以最大得分为 26。