#P3521. 第2题-上升子序列
-
1000ms
Tried: 35
Accepted: 8
Difficulty: 7
所属公司 :
京东
时间 :2025年8月30日
-
算法标签>动态规划
第2题-上升子序列
思路总览
给定是一个 1∼n 的排列。设全局最长上升子序列(LIS)长度为 L。 对每个位置 i,我们只需统计:有多少条 LIS 会“经过”元素 ai。
经典分解法:
- 令
Lend[i]:以位置i结尾的最长上升子序列长度; - 令
CntL[i]:达到Lend[i]的方案数; - 令
Rbeg[i]:从位置i开始(包含i)的最长上升子序列长度(向右); - 令
CntR[i]:达到Rbeg[i]的方案数。
则:
- 全局 LIS 长度
L = max_i Lend[i]; - 一条 LIS 经过
i当且仅当Lend[i] + Rbeg[i] - 1 = L; - 此时经过
i的 LIS 条数 =CntL[i] * CntR[i](左段和右段独立选择,拼接中间的a_i,不会重叠计数)。
因为是排列(元素互不相同),无需处理相等元素的去重问题。
关键:如何在 O(nlogn) 内得到四个数组?
用“最大值 + 计数”的树状数组(Fenwick)。每个树状数组节点保存一对 (bestLen, ways),合并规则:
- 若
lenA > lenB取 A;若<取 B;若=则ways = (waysA + waysB) mod M; - 约定:查询得到的
bestLen = 0时,方案数记为1(空序列)。
计算方法:
-
从左到右:
- 对值域
[1, a_i-1]做查询,得(len, cnt); Lend[i] = len + 1,CntL[i] = (len==0 ? 1 : cnt);- 用
(Lend[i], CntL[i])更新位置a_i。
- 对值域
-
从右到左(要查“更大值”的区间):用值域镜像
rev(x)=n-x+1。 查询rev(a_i)-1即原值域中的(a_i+1..n)。- 得
(len, cnt)后:Rbeg[i] = len + 1,CntR[i] = (len==0 ? 1 : cnt); - 用
(Rbeg[i], CntR[i])更新位置rev(a_i)。
- 得
最终答案:
ans[i] = (Lend[i] + Rbeg[i] - 1 == L) ? (CntL[i] * CntR[i] % MOD) : 0
其中 MOD = 998244353。
复杂度
- 两次树状数组(正/反向)各
n次操作,单次O(log n); - 总时间:
O(n log n);空间:O(n)。
Python实现
import sys
sys.setrecursionlimit(1 << 25)
MOD = 998244353
class BIT:
# 维护前缀区间的 (bestLen, ways)
def __init__(self, n):
self.n = n
self.best = [0]*(n+1)
self.ways = [0]*(n+1)
def _merge(self, la, ca, lb, cb):
if la > lb:
return la, ca
if la < lb:
return lb, cb
return la, (ca + cb) % MOD
def upd(self, i, l, c):
# 用 (l,c) 更新单点 i
while i <= self.n:
if self.best[i] < l:
self.best[i] = l
self.ways[i] = c % MOD
elif self.best[i] == l:
self.ways[i] += c
if self.ways[i] >= MOD:
self.ways[i] -= MOD
i += i & -i
def qry(self, i):
bl, wc = 0, 0
while i > 0:
bl, wc = self._merge(bl, wc, self.best[i], self.ways[i])
i -= i & -i
return bl, wc
def main():
data = sys.stdin.read().strip().split()
n = int(data[0])
a = list(map(int, data[1:1+n]))
Lend = [0]*n
Rbeg = [0]*n
CntL = [0]*n
CntR = [0]*n
bit = BIT(n)
L = 0
for i in range(n):
v = a[i]
bl, wc = bit.qry(v-1)
Lend[i] = bl + 1
CntL[i] = 1 if bl == 0 else wc
bit.upd(v, Lend[i], CntL[i])
if Lend[i] > L:
L = Lend[i]
bit2 = BIT(n)
for i in range(n-1, -1, -1):
v = a[i]
rv = n - v + 1
bl, wc = bit2.qry(rv-1) # 相当于原值域 (v+1..n)
Rbeg[i] = bl + 1
CntR[i] = 1 if bl == 0 else wc
bit2.upd(rv, Rbeg[i], CntR[i])
out = []
for i in range(n):
if Lend[i] + Rbeg[i] - 1 == L:
out.append(str((CntL[i] * CntR[i]) % MOD))
else:
out.append("0")
print("\n".join(out))
if __name__ == "__main__":
main()
Java实现
import java.io.*;
import java.util.*;
public class Main {
static final long MOD = 998244353L;
static class BIT {
int n;
int[] best; // 最佳长度
long[] ways; // 方案数(取模)
BIT(int n) {
this.n = n;
best = new int[n+2];
ways = new long[n+2];
}
void update(int idx, int len, long cnt) {
while (idx <= n) {
if (best[idx] < len) {
best[idx] = len;
ways[idx] = cnt % MOD;
} else if (best[idx] == len) {
ways[idx] += cnt;
if (ways[idx] >= MOD) ways[idx] -= MOD;
}
idx += idx & -idx;
}
}
// 返回 (bestLen, ways)
int qLen; long qCnt;
void query(int idx) {
int bl = 0; long wc = 0;
while (idx > 0) {
if (best[idx] > bl) {
bl = best[idx];
wc = ways[idx];
} else if (best[idx] == bl) {
wc += ways[idx];
if (wc >= MOD) wc -= MOD;
}
idx -= idx & -idx;
}
qLen = bl; qCnt = wc;
}
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String s = br.readLine();
while (s != null && s.trim().isEmpty()) s = br.readLine();
int n = Integer.parseInt(s.trim());
int[] a = new int[n];
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
int[] Lend = new int[n], Rbeg = new int[n];
long[] CntL = new long[n], CntR = new long[n];
BIT bit = new BIT(n);
int L = 0;
for (int i = 0; i < n; i++) {
int v = a[i];
bit.query(v - 1);
int bl = bit.qLen;
long wc = bit.qCnt;
Lend[i] = bl + 1;
CntL[i] = (bl == 0 ? 1 : wc);
bit.update(v, Lend[i], CntL[i]);
if (Lend[i] > L) L = Lend[i];
}
BIT bit2 = new BIT(n);
for (int i = n - 1; i >= 0; i--) {
int v = a[i];
int rv = n - v + 1;
bit2.query(rv - 1); // 对应原值域 (v+1..n)
int bl = bit2.qLen;
long wc = bit2.qCnt;
Rbeg[i] = bl + 1;
CntR[i] = (bl == 0 ? 1 : wc);
bit2.update(rv, Rbeg[i], CntR[i]);
}
StringBuilder sb = new StringBuilder();
for (int i = 0; i < n; i++) {
if (Lend[i] + Rbeg[i] - 1 == L) {
long ans = (CntL[i] * CntR[i]) % MOD;
sb.append(ans).append('\n');
} else {
sb.append('0').append('\n');
}
}
System.out.print(sb.toString());
}
}
C++ 实现
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
struct BIT {
int n;
vector<int> best; // 最佳长度
vector<int> ways; // 方案数(取模)
BIT(int n=0): n(n), best(n+1,0), ways(n+1,0) {}
void init(int _n){ n=_n; best.assign(n+1,0); ways.assign(n+1,0); }
// 单点更新 (len, cnt)
void upd(int i, int len, int cnt){
while(i<=n){
if(best[i] < len){
best[i] = len;
ways[i] = cnt;
}else if(best[i] == len){
ways[i] += cnt;
if(ways[i] >= MOD) ways[i] -= MOD;
}
i += i & -i;
}
}
// 前缀查询,返回 (bestLen, ways)
pair<int,int> qry(int i){
int bl=0, wc=0;
while(i>0){
if(best[i] > bl){
bl = best[i];
wc = ways[i];
}else if(best[i] == bl){
wc += ways[i];
if(wc >= MOD) wc -= MOD;
}
i -= i & -i;
}
return {bl, wc};
}
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
if(!(cin>>n)) return 0;
vector<int> a(n);
for(int i=0;i<n;i++) cin>>a[i];
vector<int> Lend(n), Rbeg(n);
vector<int> CntL(n), CntR(n);
BIT bit(n);
int L = 0;
for(int i=0;i<n;i++){
int v = a[i];
auto p = bit.qry(v-1);
int bl = p.first, wc = p.second;
Lend[i] = bl + 1;
CntL[i] = (bl==0 ? 1 : wc);
bit.upd(v, Lend[i], CntL[i]);
L = max(L, Lend[i]);
}
BIT bit2(n);
for(int i=n-1;i>=0;i--){
int v = a[i];
int rv = n - v + 1;
auto p = bit2.qry(rv-1); // 原值域 (v+1..n)
int bl = p.first, wc = p.second;
Rbeg[i] = bl + 1;
CntR[i] = (bl==0 ? 1 : wc);
bit2.upd(rv, Rbeg[i], CntR[i]);
}
for(int i=0;i<n;i++){
if(Lend[i] + Rbeg[i] - 1 == L){
long long ans = 1LL * CntL[i] * CntR[i] % MOD;
cout << ans << "\n";
}else{
cout << 0 << "\n";
}
}
return 0;
}
题目内容
给定一个长为n 的排列 {a}。排列是指,{a} 包含1~n中的所有元素恰好一次。
回顾一下,{a} 的一个子序列是指从{a}中将若干元素(不一定连续)提取出来而不改变相对位置形成 的序列。 对于 {a} 的一个子序列 {b},如果 {b} 中的元素单调递增,就称{b}是{a}的一个上升子序列。 对于{a}的一个上升子序列{c},如果找不到比{c}元素个数更多的上升子序列了,就称 {c} 是 {a} 的 一个最长上升子序列。显然,{a} 可以有很多个最长上升子序列。 请你对每个i∈[1,n],求出有多少个不同的{a}的最长上升子序列包含ai。
输入描述
第一行一个正整数n
第二行n个正整数,以空格分隔,表示a1,a2,...,an
保证 {a} 是一个 1~n 的排列。 1≤n≤2×105
输出描述
输出 n 行,每一行一个非负整数。第i行的输出表示包含a 的最长上升子序列个数。
因为答案可能很大,所以你只需要输出答案对998244353取模的结果。
样例1
输入
5
3 1 4 2 5
输出
1
2
2
1
3
说明
对于输入的排列,其最长上升子序列长度为3,所有不同的最长上升了序列如下:
- 1,2,5。
- 1,4,5。
- 3,4,5。
以 a1和 a2为例。 包含了a1=3的最长上升子序列有1个,故答案为1。 包含了a2=1的最长上升子序列有2 个,故答案为2。