#P3521. 第2题-上升子序列
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 35
            Accepted: 8
            Difficulty: 7
            
          
          
          
                       所属公司 : 
                              京东
                                
            
                        
              时间 :2025年8月30日
                              
                      
          
 
- 
                        算法标签>动态规划          
 
第2题-上升子序列
思路总览
给定是一个 1∼n 的排列。设全局最长上升子序列(LIS)长度为 L。 对每个位置 i,我们只需统计:有多少条 LIS 会“经过”元素 ai。
经典分解法:
- 令 
Lend[i]:以位置i结尾的最长上升子序列长度; - 令 
CntL[i]:达到Lend[i]的方案数; - 令 
Rbeg[i]:从位置i开始(包含i)的最长上升子序列长度(向右); - 令 
CntR[i]:达到Rbeg[i]的方案数。 
则:
- 全局 LIS 长度 
L = max_i Lend[i]; - 一条 LIS 经过 
i当且仅当Lend[i] + Rbeg[i] - 1 = L; - 此时经过 
i的 LIS 条数 =CntL[i] * CntR[i](左段和右段独立选择,拼接中间的a_i,不会重叠计数)。 
因为是排列(元素互不相同),无需处理相等元素的去重问题。
关键:如何在 O(nlogn) 内得到四个数组?
用“最大值 + 计数”的树状数组(Fenwick)。每个树状数组节点保存一对 (bestLen, ways),合并规则:
- 若 
lenA > lenB取 A;若<取 B;若=则ways = (waysA + waysB) mod M; - 约定:查询得到的 
bestLen = 0时,方案数记为1(空序列)。 
计算方法:
- 
从左到右:
- 对值域 
[1, a_i-1]做查询,得(len, cnt); Lend[i] = len + 1,CntL[i] = (len==0 ? 1 : cnt);- 用 
(Lend[i], CntL[i])更新位置a_i。 
 - 对值域 
 - 
从右到左(要查“更大值”的区间):用值域镜像
rev(x)=n-x+1。 查询rev(a_i)-1即原值域中的(a_i+1..n)。- 得 
(len, cnt)后:Rbeg[i] = len + 1,CntR[i] = (len==0 ? 1 : cnt); - 用 
(Rbeg[i], CntR[i])更新位置rev(a_i)。 
 - 得 
 
最终答案:
ans[i] = (Lend[i] + Rbeg[i] - 1 == L) ? (CntL[i] * CntR[i] % MOD) : 0
其中 MOD = 998244353。
复杂度
- 两次树状数组(正/反向)各 
n次操作,单次O(log n); - 总时间:
O(n log n);空间:O(n)。 
Python实现
import sys
sys.setrecursionlimit(1 << 25)
MOD = 998244353
class BIT:
    # 维护前缀区间的 (bestLen, ways)
    def __init__(self, n):
        self.n = n
        self.best = [0]*(n+1)
        self.ways = [0]*(n+1)
    def _merge(self, la, ca, lb, cb):
        if la > lb: 
            return la, ca
        if la < lb:
            return lb, cb
        return la, (ca + cb) % MOD
    def upd(self, i, l, c):
        # 用 (l,c) 更新单点 i
        while i <= self.n:
            if self.best[i] < l:
                self.best[i] = l
                self.ways[i] = c % MOD
            elif self.best[i] == l:
                self.ways[i] += c
                if self.ways[i] >= MOD:
                    self.ways[i] -= MOD
            i += i & -i
    def qry(self, i):
        bl, wc = 0, 0
        while i > 0:
            bl, wc = self._merge(bl, wc, self.best[i], self.ways[i])
            i -= i & -i
        return bl, wc
def main():
    data = sys.stdin.read().strip().split()
    n = int(data[0])
    a = list(map(int, data[1:1+n]))
    Lend = [0]*n
    Rbeg = [0]*n
    CntL = [0]*n
    CntR = [0]*n
    bit = BIT(n)
    L = 0
    for i in range(n):
        v = a[i]
        bl, wc = bit.qry(v-1)
        Lend[i] = bl + 1
        CntL[i] = 1 if bl == 0 else wc
        bit.upd(v, Lend[i], CntL[i])
        if Lend[i] > L:
            L = Lend[i]
    bit2 = BIT(n)
    for i in range(n-1, -1, -1):
        v = a[i]
        rv = n - v + 1
        bl, wc = bit2.qry(rv-1)  # 相当于原值域 (v+1..n)
        Rbeg[i] = bl + 1
        CntR[i] = 1 if bl == 0 else wc
        bit2.upd(rv, Rbeg[i], CntR[i])
    out = []
    for i in range(n):
        if Lend[i] + Rbeg[i] - 1 == L:
            out.append(str((CntL[i] * CntR[i]) % MOD))
        else:
            out.append("0")
    print("\n".join(out))
if __name__ == "__main__":
    main()
Java实现
import java.io.*;
import java.util.*;
public class Main {
    static final long MOD = 998244353L;
    static class BIT {
        int n;
        int[] best;     // 最佳长度
        long[] ways;    // 方案数(取模)
        BIT(int n) {
            this.n = n;
            best = new int[n+2];
            ways = new long[n+2];
        }
        void update(int idx, int len, long cnt) {
            while (idx <= n) {
                if (best[idx] < len) {
                    best[idx] = len;
                    ways[idx] = cnt % MOD;
                } else if (best[idx] == len) {
                    ways[idx] += cnt;
                    if (ways[idx] >= MOD) ways[idx] -= MOD;
                }
                idx += idx & -idx;
            }
        }
        // 返回 (bestLen, ways)
        int qLen; long qCnt;
        void query(int idx) {
            int bl = 0; long wc = 0;
            while (idx > 0) {
                if (best[idx] > bl) {
                    bl = best[idx];
                    wc = ways[idx];
                } else if (best[idx] == bl) {
                    wc += ways[idx];
                    if (wc >= MOD) wc -= MOD;
                }
                idx -= idx & -idx;
            }
            qLen = bl; qCnt = wc;
        }
    }
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String s = br.readLine();
        while (s != null && s.trim().isEmpty()) s = br.readLine();
        int n = Integer.parseInt(s.trim());
        int[] a = new int[n];
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
        int[] Lend = new int[n], Rbeg = new int[n];
        long[] CntL = new long[n], CntR = new long[n];
        BIT bit = new BIT(n);
        int L = 0;
        for (int i = 0; i < n; i++) {
            int v = a[i];
            bit.query(v - 1);
            int bl = bit.qLen;
            long wc = bit.qCnt;
            Lend[i] = bl + 1;
            CntL[i] = (bl == 0 ? 1 : wc);
            bit.update(v, Lend[i], CntL[i]);
            if (Lend[i] > L) L = Lend[i];
        }
        BIT bit2 = new BIT(n);
        for (int i = n - 1; i >= 0; i--) {
            int v = a[i];
            int rv = n - v + 1;
            bit2.query(rv - 1); // 对应原值域 (v+1..n)
            int bl = bit2.qLen;
            long wc = bit2.qCnt;
            Rbeg[i] = bl + 1;
            CntR[i] = (bl == 0 ? 1 : wc);
            bit2.update(rv, Rbeg[i], CntR[i]);
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++) {
            if (Lend[i] + Rbeg[i] - 1 == L) {
                long ans = (CntL[i] * CntR[i]) % MOD;
                sb.append(ans).append('\n');
            } else {
                sb.append('0').append('\n');
            }
        }
        System.out.print(sb.toString());
    }
}
C++ 实现
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
struct BIT {
    int n;
    vector<int> best; // 最佳长度
    vector<int> ways; // 方案数(取模)
    BIT(int n=0): n(n), best(n+1,0), ways(n+1,0) {}
    void init(int _n){ n=_n; best.assign(n+1,0); ways.assign(n+1,0); }
    // 单点更新 (len, cnt)
    void upd(int i, int len, int cnt){
        while(i<=n){
            if(best[i] < len){
                best[i] = len;
                ways[i] = cnt;
            }else if(best[i] == len){
                ways[i] += cnt;
                if(ways[i] >= MOD) ways[i] -= MOD;
            }
            i += i & -i;
        }
    }
    // 前缀查询,返回 (bestLen, ways)
    pair<int,int> qry(int i){
        int bl=0, wc=0;
        while(i>0){
            if(best[i] > bl){
                bl = best[i];
                wc = ways[i];
            }else if(best[i] == bl){
                wc += ways[i];
                if(wc >= MOD) wc -= MOD;
            }
            i -= i & -i;
        }
        return {bl, wc};
    }
};
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n; 
    if(!(cin>>n)) return 0;
    vector<int> a(n);
    for(int i=0;i<n;i++) cin>>a[i];
    vector<int> Lend(n), Rbeg(n);
    vector<int> CntL(n), CntR(n);
    BIT bit(n);
    int L = 0;
    for(int i=0;i<n;i++){
        int v = a[i];
        auto p = bit.qry(v-1);
        int bl = p.first, wc = p.second;
        Lend[i] = bl + 1;
        CntL[i] = (bl==0 ? 1 : wc);
        bit.upd(v, Lend[i], CntL[i]);
        L = max(L, Lend[i]);
    }
    BIT bit2(n);
    for(int i=n-1;i>=0;i--){
        int v = a[i];
        int rv = n - v + 1;
        auto p = bit2.qry(rv-1); // 原值域 (v+1..n)
        int bl = p.first, wc = p.second;
        Rbeg[i] = bl + 1;
        CntR[i] = (bl==0 ? 1 : wc);
        bit2.upd(rv, Rbeg[i], CntR[i]);
    }
    for(int i=0;i<n;i++){
        if(Lend[i] + Rbeg[i] - 1 == L){
            long long ans = 1LL * CntL[i] * CntR[i] % MOD;
            cout << ans << "\n";
        }else{
            cout << 0 << "\n";
        }
    }
    return 0;
}
        题目内容
给定一个长为n 的排列 {a}。排列是指,{a} 包含1~n中的所有元素恰好一次。
回顾一下,{a} 的一个子序列是指从{a}中将若干元素(不一定连续)提取出来而不改变相对位置形成 的序列。 对于 {a} 的一个子序列 {b},如果 {b} 中的元素单调递增,就称{b}是{a}的一个上升子序列。 对于{a}的一个上升子序列{c},如果找不到比{c}元素个数更多的上升子序列了,就称 {c} 是 {a} 的 一个最长上升子序列。显然,{a} 可以有很多个最长上升子序列。 请你对每个i∈[1,n],求出有多少个不同的{a}的最长上升子序列包含ai。
输入描述
第一行一个正整数n
第二行n个正整数,以空格分隔,表示a1,a2,...,an
保证 {a} 是一个 1~n 的排列。 1≤n≤2×105
输出描述
输出 n 行,每一行一个非负整数。第i行的输出表示包含a 的最长上升子序列个数。
因为答案可能很大,所以你只需要输出答案对998244353取模的结果。
样例1
输入
5
3 1 4 2 5
输出
1
2
2
1
3
说明
对于输入的排列,其最长上升子序列长度为3,所有不同的最长上升了序列如下:
- 1,2,5。
 - 1,4,5。
 - 3,4,5。
 
以 a1和 a2为例。 包含了a1=3的最长上升子序列有1个,故答案为1。 包含了a2=1的最长上升子序列有2 个,故答案为2。