#P3897. 第3题-小美的最长上升子序列
-
ID: 3261
Tried: 2
Accepted: 2
Difficulty: 7
所属公司 :
美团
时间 :2025年10月11日-开发岗
-
算法标签>数据结构树状数组
第3题-小美的最长上升子序列
解题思路
我们要求的是“按数值区分”的 LIS(最长上升子序列)的条数:同一数值序列只算 1 次,即使出现位置不同。 关键观察:
-
把数组中每个不同的数值当作一个节点 v。设
- minPos[v] 为该值首次出现的位置;
- maxPos[v] 为该值最后出现的位置。
-
两个不同值 u<v 之间存在一条可连接的边,当且仅当数组中存在某个 u 的出现位置在某个 v 的出现位置之前。这个条件等价于
minPos[u]<maxPos[v]只要最早的 u 在最晚的 v 之前,就能取到一对 i<j 使得 ai=u,aj=v。
于是问题化为:在“不同值”为点、满足上式即有边的 DAG 上,求按值递增的最长路径条数(本质不同的 LIS)。
如何高效做 DP:
-
将所有不同值按数值升序压缩为 1…m。
-
设 DP 在“值”这一维按升序进行。对每个值 v,只允许从更小的值 u<v 转移,同时要求 minPos[u]<maxPos[v]。
-
这相当于在二维平面上做支配查询:
- 第一维:值序(我们按 v 递增处理,保证 u<v)
- 第二维:minPos 要求在 [1,maxPos[v)−1] 内
-
因此可以用树状数组按 minPos 这一维维护前缀的“最佳 LIS 长度与对应方案数”。 树状数组每个节点存一对 (bestLen,ways),合并规则:
- 若新的长度更大,则覆盖并带上对应 ways;
- 若长度相等,则把 ways 相加(取模);
- 否则不变。
-
对值 v:
-
查询区间 [1,maxPos[v)−1] 的结果得到 (L,C),则
dpLen[v] = L+1,dpCnt[v]= (C>0?C:1)
(没有前驱时,空序列的方案数视为 1)
-
然后在位置 minPos[v] 处用 (dpLen[v],dpCnt[v]) 更新树状数组。
-
-
最终答案为所有 dpLen[v] 的最大值 L\* 对应的 ∑dpCnt[v](取模)。
该做法天然去重:每个不同数值仅出现一次节点,转移只关心“是否存在能衔接的位置”,而不计同值不同位置的多重选择,因此得到的是“本质不同序列”的计数。
复杂度分析
-
设 n 为长度、m 为不同值个数(m≤n)。
-
每个测试用例:
- 压缩与统计 minPos,maxPos:O(nlogn)(排序/映射)
- 每个值做 1 次树状数组查询与更新:O(mlogn)
-
由于所有测试的 n 之和 ≤2×105,总体复杂度 O(nlogn),空间 O(n)。
代码实现
Python
# -*- coding: utf-8 -*-
import sys
MOD = 998244353
class BIT:
# 维护 (bestLen, ways) 的前缀“最大长度+同长求和”
def __init__(self, n):
self.n = n
self.best = [0] * (n + 1) # 最佳长度
self.cnt = [0] * (n + 1) # 该最佳长度的方案数
def _merge_into(self, i, length, ways):
if length > self.best[i]:
self.best[i] = length
self.cnt[i] = ways % MOD
elif length == self.best[i]:
self.cnt[i] += ways
if self.cnt[i] >= MOD:
self.cnt[i] -= MOD
def update(self, pos, length, ways):
# 在 pos 位置用 (length, ways) 更新
i = pos
n = self.n
while i <= n:
self._merge_into(i, length, ways)
i += i & -i
def query(self, pos):
# 查询 [1..pos] 的结果
bestLen = 0
ways = 0
i = pos
while i > 0:
if self.best[i] > bestLen:
bestLen = self.best[i]
ways = self.cnt[i]
elif self.best[i] == bestLen:
ways += self.cnt[i]
if ways >= MOD:
ways -= MOD
i -= i & -i
return bestLen, ways
def solve():
data = list(map(int, sys.stdin.buffer.read().split()))
t = data[0]
idx = 1
out_lines = []
for _ in range(t):
n = data[idx]; idx += 1
arr = data[idx:idx+n]; idx += n
# 压缩不同值
uniq = sorted(set(arr))
m = len(uniq)
pos_first = [10**9] * m
pos_last = [0] * m
# 建立值->压缩下标
mp = {v:i for i, v in enumerate(uniq)}
# 统计各值的首次/末次出现位置(1-based)
for i, v in enumerate(arr, 1):
k = mp[v]
if i < pos_first[k]:
pos_first[k] = i
if i > pos_last[k]:
pos_last[k] = i
bit = BIT(n) # 第二维按 minPos 放在 [1..n]
dp_len = [0] * m
dp_cnt = [0] * m
# 按数值递增做 DP
for v in range(m):
mx_pos = pos_last[v] - 1 # 需要 minPos[u] <= mx_pos
if mx_pos >= 1:
L, C = bit.query(mx_pos)
else:
L, C = (0, 0)
if L == 0:
C = 1 # 空前驱
dp_len[v] = L + 1
dp_cnt[v] = C % MOD
mn_pos = pos_first[v]
bit.update(mn_pos, dp_len[v], dp_cnt[v])
Lstar = max(dp_len) if m else 0
ans = 0
for v in range(m):
if dp_len[v] == Lstar:
ans += dp_cnt[v]
if ans >= MOD:
ans -= MOD
out_lines.append(str(ans % MOD))
sys.stdout.write("\n".join(out_lines))
if __name__ == "__main__":
solve()
Java
import java.io.*;
import java.util.*;
/** ACM 风格主类,统计本质不同 LIS 数量(取模 998244353) */
public class Main {
static final int MOD = 998244353;
// 树状数组:维护 (bestLen, ways) 的前缀最大+同长求和
static class BIT {
int n;
int[] best; // 最佳长度
int[] ways; // 该最佳长度的方案数(取模)
BIT(int n) {
this.n = n;
best = new int[n + 1];
ways = new int[n + 1];
}
private void mergeInto(int i, int length, int cnt) {
if (length > best[i]) {
best[i] = length;
ways[i] = cnt % MOD;
} else if (length == best[i]) {
int x = ways[i] + cnt;
if (x >= MOD) x -= MOD;
ways[i] = x;
}
}
void update(int pos, int length, int cnt) {
for (int i = pos; i <= n; i += i & -i) {
mergeInto(i, length, cnt);
}
}
int[] query(int pos) {
int bestLen = 0;
int cnt = 0;
for (int i = pos; i > 0; i -= i & -i) {
if (best[i] > bestLen) {
bestLen = best[i];
cnt = ways[i];
} else if (best[i] == bestLen) {
int x = cnt + ways[i];
if (x >= MOD) x -= MOD;
cnt = x;
}
}
return new int[]{bestLen, cnt};
}
}
// 简单高效读入
static class FastScanner {
final InputStream in;
final byte[] buffer = new byte[1 << 16];
int ptr = 0, len = 0;
FastScanner(InputStream is) { in = is; }
int read() throws IOException {
if (ptr >= len) {
len = in.read(buffer);
ptr = 0;
if (len <= 0) return -1;
}
return buffer[ptr++];
}
int nextInt() throws IOException {
int c, sgn = 1, x = 0;
do { c = read(); } while (c <= 32);
if (c == '-') { sgn = -1; c = read(); }
for (; c > 32; c = read()) x = x * 10 + (c - '0');
return x * sgn;
}
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner(System.in);
StringBuilder out = new StringBuilder();
int T = fs.nextInt();
while (T-- > 0) {
int n = fs.nextInt();
int[] a = new int[n];
for (int i = 0; i < n; i++) a[i] = fs.nextInt();
// 压缩不同值
int[] b = a.clone();
Arrays.sort(b);
int m = 0;
for (int i = 0; i < n; i++) {
if (i == 0 || b[i] != b[i - 1]) b[m++] = b[i];
}
int[] vals = Arrays.copyOf(b, m);
int[] minPos = new int[m];
int[] maxPos = new int[m];
Arrays.fill(minPos, Integer.MAX_VALUE);
Arrays.fill(maxPos, 0);
// 值 -> 下标 映射(用哈希避免反复二分)
HashMap<Integer, Integer> map = new HashMap<>(m * 2 + 3);
for (int i = 0; i < m; i++) map.put(vals[i], i);
for (int i = 0; i < n; i++) {
int k = map.get(a[i]);
int pos = i + 1; // 1-based
if (pos < minPos[k]) minPos[k] = pos;
if (pos > maxPos[k]) maxPos[k] = pos;
}
BIT bit = new BIT(n);
int[] dpLen = new int[m];
int[] dpCnt = new int[m];
for (int v = 0; v < m; v++) {
int mx = maxPos[v] - 1;
int L = 0, C = 0;
if (mx >= 1) {
int[] res = bit.query(mx);
L = res[0];
C = res[1];
}
if (L == 0) C = 1; // 空前驱
dpLen[v] = L + 1;
dpCnt[v] = C % MOD;
int mn = minPos[v];
bit.update(mn, dpLen[v], dpCnt[v]);
}
int Lstar = 0;
for (int x : dpLen) if (x > Lstar) Lstar = x;
int ans = 0;
for (int v = 0; v < m; v++) if (dpLen[v] == Lstar) {
ans += dpCnt[v];
if (ans >= MOD) ans -= MOD;
}
out.append(ans % MOD).append('\n');
}
System.out.print(out);
}
}
C++
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
/* 树状数组:维护 (bestLen, ways) 的前缀最大+同长求和 */
struct BIT {
int n;
vector<int> best; // 最佳长度
vector<int> ways; // 方案数(取模)
BIT(int n=0): n(n), best(n+1, 0), ways(n+1, 0) {}
void init(int n_) { n = n_; best.assign(n+1, 0); ways.assign(n+1, 0); }
inline void mergeInto(int i, int len, int cnt){
if (len > best[i]) {
best[i] = len;
ways[i] = cnt % MOD;
} else if (len == best[i]) {
int x = ways[i] + cnt;
if (x >= MOD) x -= MOD;
ways[i] = x;
}
}
void update(int pos, int len, int cnt){
for (int i = pos; i <= n; i += i & -i) {
mergeInto(i, len, cnt);
}
}
pair<int,int> query(int pos){
int bestLen = 0, cnt = 0;
for (int i = pos; i > 0; i -= i & -i) {
if (best[i] > bestLen) {
bestLen = best[i];
cnt = ways[i];
} else if (best[i] == bestLen) {
int x = cnt + ways[i];
if (x >= MOD) x -= MOD;
cnt = x;
}
}
return {bestLen, cnt};
}
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
if(!(cin >> T)) return 0;
while (T--) {
int n; cin >> n;
vector<int> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
// 压缩不同值
vector<int> vals = a;
sort(vals.begin(), vals.end());
vals.erase(unique(vals.begin(), vals.end()), vals.end());
int m = (int)vals.size();
vector<int> minPos(m, INT_MAX), maxPos(m, 0);
// 值 -> 下标 映射(lower_bound 也可)
unordered_map<int,int> id;
id.reserve(m*2+3);
for (int i = 0; i < m; ++i) id[vals[i]] = i;
for (int i = 0; i < n; ++i) {
int k = id[a[i]];
int p = i + 1; // 1-based
if (p < minPos[k]) minPos[k] = p;
if (p > maxPos[k]) maxPos[k] = p;
}
BIT bit(n);
vector<int> dpLen(m, 0), dpCnt(m, 0);
for (int v = 0; v < m; ++v) {
int mx = maxPos[v] - 1;
int L = 0, C = 0;
if (mx >= 1) {
auto res = bit.query(mx);
L = res.first;
C = res.second;
}
if (L == 0) C = 1; // 空前驱
dpLen[v] = L + 1;
dpCnt[v] = C % MOD;
int mn = minPos[v];
bit.update(mn, dpLen[v], dpCnt[v]);
}
int Lstar = 0;
for (int x: dpLen) Lstar = max(Lstar, x);
int ans = 0;
for (int v = 0; v < m; ++v) if (dpLen[v] == Lstar) {
ans += dpCnt[v];
if (ans >= MOD) ans -= MOD;
}
cout << (ans % MOD) << "\n";
}
return 0;
}
题目内容
给定一个长度为n的数组a,请你计算其中所有本质不同最长上升子序列的数量。由于答案可能非常大,请将结果对998244353取模后输出。
**[子序列]**如果一个序列可以通过删除原序列的若干(可能为零)元素得到,则称前者为后者的一个子序列。
**[上升序列]**我们称s为一个上升序列,当且仅当其中任意相邻元素满足si<si+1
**[本质不同序列]**若两个序列的长度不同,或在某一位置上的元素不同,则认为它们是不同的序列。
输入描述
每个测试文件均包含多组测试数据。
第一行输入一个整数T(1≤T≤104)代表数据组数,每组测试数据描述如下:
第一行输入一个正整数n(1≤n≤2⋅105),代表数组a的长度。
第二行输入n个正整数a1,a2,...,an(1≤ai≤109),代表a中的元素。
除此之外,保证单个测试文件的n之和不超过2×105。
输出描述
对于每组测试数据,输出一行一个整数,代表数组a中本质不同 最长上升了序列的数量,对998244353取模后的结果。
样例1
输入
5
1
1
2
1 1
6
1 1 4 5 1 4
7
3 2 2 4 5 8 7
7
1 9 1 9 8 1 0
输出
1
1
1
4
2
说明
在最后一组测试数据中,两个本质不同最长上升子序列分别为{1,9}和{1,8}