给定 n 个数组 a1,…,an, 找出所有的严格递增三元组(1≤i<j<k≤n) , 使得ai=ak=aj+1 , 输出其数量。
第一行输入一个正整数n(3≤n≤105)
第二行输入a1,...,an(1≤ai≤109)
一个正整数,代表符合条件的三元组数量
5
2 2 1 1 2
4
我们需要找出所有符合条件的严格递增三元组 (i, j, k)
,使得:
1 <= i < j < k <= n
a[i] = a[k] = a[j] + 1
换句话说,给定一个数 a[j]
,我们希望找到两个 i
和 k
,使得 a[i] = a[k] = a[j] + 1
且 i < j < k
。这可以形象地看作寻找形如(x+1,x,x+1)
的三元组。将求和过程先聚焦于中间的这个x,那么最终的答案 等价于
1.对于位置i的数a[i] , 寻找它前缀中a[i]+1 的个数和后缀中 a[i]+1的个数 , 根据乘法原理,答案是个数相乘
2.对每个位置的求和。
用数学公式来描述就是:
预处理手段:这个乘法的左侧和右侧,我们都可以使用前缀哈希表来维护。具体思路如下:
j
,我们希望找到它左侧和右侧分别有多少个数等于 a[j] + 1
。为了提高算法效率,我们可以使用两个哈希表 count_left
和 count_right
:
count_left[x]
表示 x
在索引小于 j
的位置出现的次数。count_right[x]
表示 x
在索引大于 j
的位置出现的次数。a[j]
,以j
为中心的所有满足条件的三元组数量等于count_left[a[j] + 1] * count_right[a[j] + 1]
。count_left
和 count_right
数组,以便为下一个元素计算。具体的,由于我们是从左往右遍历,所以一开始时,count_left
为空, 而 count_right
包含所有数组。在遍历的过程中,往count_left
里添加a[i],往count_right
里删除a[i]count_right
数组的时间复杂度为 O(n)。j
对应的三元组数,时间复杂度是 O(n)。from collections import defaultdict
def main():
n = int(input())
res = 0
count_left = defaultdict(int)
count_right = defaultdict(int)
a = list(map(int, input().split()))
for i in a:
count_right[i] += 1 # 统计每个元素的频率
for i in a:
t = i + 1 # 计算t = a[j] + 1
res += count_left[t] * count_right[t] # 计算符合条件的三元组数量
count_left[i] += 1 # 更新count_left
count_right[i] -= 1 # 更新count_right
print(res)
if __name__ == "__main__":
main()