#P3830. 第2题-美丽的三元组
-
1000ms
Tried: 87
Accepted: 24
Difficulty: 6
所属公司 :
京东
时间 :2025年9月27日
-
算法标签>组合数学哈希表
第2题-美丽的三元组
解题思路
给定长度为 n 的数组 a。把所有相邻的三个数形成的三元组 (ai,ai+1,ai+2) 记下来,共有 m=n−2 个。题目要求统计三元组对 (b,c) 中,恰好有且只有一个位置不同(海明距离为 1)的对数。
观察:一对三元组只在第 1 位不同 ⇔ 它们的 (b2,b3) 完全相同、但 b1 不同。于是可以把问题拆成三类并求和(互不重叠):
- 只在第 1 位不同;
- 只在第 2 位不同;
- 只在第 3 位不同。
做法(哈希计数):
-
枚举所有连续三元组,统计四个频次:
cnt_xyz[(x,y,z)]:完整三元组出现次数;cnt_yz[(y,z)]:固定后两位;cnt_xz[(x,z)]:固定第 1、3 位;cnt_xy[(x,y)]:固定前两位。
-
对于“只在第 1 位不同”的对数: 对每个键 (y,z),先计算同键内两两配对总数 (2cnt_yz),再减去同一完整三元组的配对 ∑(2cnt_xyz)(这些配对 0 位不同,需要剔除)。 记作 A=∑(y,z)(2cnt_yz(y,z))-∑(x,y,z)(2cnt_xyz(x,y,z))。
-
同理可得只在第 2 位不同的 B 与只在第 3 位不同的 C。 最终答案为 A+B+C。
该算法本质是用分组计数 + 组合数消重,属于哈希/计数类算法。
实现要点:
- 组合数 (2k)=k(k−1)/2。
- 由于答案可能很大,使用 64 位整型。
- 只需一次线性扫描与若干次哈希遍历,时间、空间都可控。
复杂度分析
- 时间复杂度: 构建三种二元键与一种三元键的计数是 O(n),之后遍历哈希表求和,规模不超过三元组数的常数倍,整体 O(n)。
- 空间复杂度: 存储若干哈希表,键的数量均不超过 n,故为 O(n)。
代码实现
Python
# -*- coding: utf-8 -*-
# 题意:统计所有相邻三元组中,海明距离为1的三元组对数量
# 思路:哈希计数 + 组合数消重(见题解)
import sys
from collections import Counter
def comb2(k: int) -> int:
return k * (k - 1) // 2
def solve_case(arr):
n = len(arr)
if n < 3:
return 0
# 构建所有连续三元组
triples = [(arr[i], arr[i+1], arr[i+2]) for i in range(n - 2)]
# 统计四类频次
cnt_xyz = Counter(triples)
cnt_yz = Counter([(y, z) for _, y, z in triples])
cnt_xz = Counter([(x, z) for x, _, z in triples])
cnt_xy = Counter([(x, y) for x, y, _ in triples])
# 计算三部分
same_xyz_pairs = sum(comb2(v) for v in cnt_xyz.values())
part1 = sum(comb2(v) for v in cnt_yz.values()) - same_xyz_pairs # 只在第1位不同
part2 = sum(comb2(v) for v in cnt_xz.values()) - same_xyz_pairs # 只在第2位不同
part3 = sum(comb2(v) for v in cnt_xy.values()) - same_xyz_pairs # 只在第3位不同
return part1 + part2 + part3
def main():
data = list(map(int, sys.stdin.buffer.read().split()))
it = iter(data)
t = next(it)
out_lines = []
for _ in range(t):
n = next(it)
arr = [next(it) for _ in range(n)]
out_lines.append(str(solve_case(arr)))
sys.stdout.write("\n".join(out_lines))
if __name__ == "__main__":
main()
Java
// 题意:统计所有相邻三元组中,海明距离为1的三元组对数量
// 思路:哈希计数 + 组合数消重(见题解)
import java.io.*;
import java.util.*;
public class Main {
// 自定义较快输入(数据范围较大)
static class FastScanner {
private final InputStream in;
private final byte[] buffer = new byte[1 << 16];
private int ptr = 0, len = 0;
FastScanner(InputStream is) { in = is; }
private int read() throws IOException {
if (ptr >= len) {
len = in.read(buffer);
ptr = 0;
if (len <= 0) return -1;
}
return buffer[ptr++];
}
int nextInt() throws IOException {
int c, sgn = 1, x = 0;
do { c = read(); } while (c <= ' '); // 跳过空白
if (c == '-') { sgn = -1; c = read(); }
while (c > ' ') { x = x * 10 + (c - '0'); c = read(); }
return x * sgn;
}
}
// 将两个/三个数编码为long键,避免对象开销;值范围<=1e6,用21位足够
static long enc2(int a, int b) { return (((long)a) << 21) | (long)b; }
static long enc3(int a, int b, int c) { return (((long)a) << 42) | (((long)b) << 21) | (long)c; }
static long C2(long k) { return k * (k - 1) / 2; }
static long solveCase(int[] a) {
int n = a.length;
if (n < 3) return 0L;
HashMap<Long, Integer> cntXYZ = new HashMap<>();
HashMap<Long, Integer> cntYZ = new HashMap<>();
HashMap<Long, Integer> cntXZ = new HashMap<>();
HashMap<Long, Integer> cntXY = new HashMap<>();
for (int i = 0; i + 2 < n; i++) {
int x = a[i], y = a[i+1], z = a[i+2];
long k3 = enc3(x, y, z);
long kyz = enc2(y, z);
long kxz = enc2(x, z);
long kxy = enc2(x, y);
cntXYZ.merge(k3, 1, Integer::sum);
cntYZ.merge(kyz, 1, Integer::sum);
cntXZ.merge(kxz, 1, Integer::sum);
cntXY.merge(kxy, 1, Integer::sum);
}
long sameXYZPairs = 0;
for (int v : cntXYZ.values()) sameXYZPairs += C2(v);
long part1 = 0, part2 = 0, part3 = 0;
for (int v : cntYZ.values()) part1 += C2(v);
for (int v : cntXZ.values()) part2 += C2(v);
for (int v : cntXY.values()) part3 += C2(v);
// 剔除“完全相同”的配对
part1 -= sameXYZPairs;
part2 -= sameXYZPairs;
part3 -= sameXYZPairs;
return part1 + part2 + part3;
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner(System.in);
StringBuilder sb = new StringBuilder();
int T = fs.nextInt();
for (int tc = 0; tc < T; tc++) {
int n = fs.nextInt();
int[] a = new int[n];
for (int i = 0; i < n; i++) a[i] = fs.nextInt();
sb.append(solveCase(a)).append('\n');
}
System.out.print(sb.toString());
}
}
C++
// 题意:统计所有相邻三元组中,海明距离为1的三元组对数量
// 思路:哈希计数 + 组合数消重(见题解)
#include <bits/stdc++.h>
using namespace std;
static inline long long C2(long long k){ return k*(k-1)/2; }
// 数值<=1e6,用21位编码足够
static inline long long enc2(int a, int b){
return ( (long long)a << 21 ) | (long long)b;
}
static inline long long enc3(int a, int b, int c){
return ( ( (long long)a << 42 ) | ( (long long)b << 21 ) | (long long)c );
}
long long solve_case(const vector<int>& a){
int n = (int)a.size();
if(n < 3) return 0LL;
unordered_map<long long,int> cntXYZ, cntYZ, cntXZ, cntXY;
cntXYZ.reserve(n*2);
cntYZ.reserve(n*2);
cntXZ.reserve(n*2);
cntXY.reserve(n*2);
for(int i=0;i+2<n;i++){
int x=a[i], y=a[i+1], z=a[i+2];
cntXYZ[enc3(x,y,z)]++;
cntYZ[enc2(y,z)]++;
cntXZ[enc2(x,z)]++;
cntXY[enc2(x,y)]++;
}
long long sameXYZPairs = 0;
for(auto &p: cntXYZ) sameXYZPairs += C2(p.second);
long long part1 = 0, part2 = 0, part3 = 0;
for(auto &p: cntYZ) part1 += C2(p.second);
for(auto &p: cntXZ) part2 += C2(p.second);
for(auto &p: cntXY) part3 += C2(p.second);
part1 -= sameXYZPairs;
part2 -= sameXYZPairs;
part3 -= sameXYZPairs;
return part1 + part2 + part3;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
if(!(cin>>T)) return 0;
while(T--){
int n; cin>>n;
vector<int> a(n);
for(int i=0;i<n;i++) cin>>a[i];
cout<<solve_case(a)<<"\n";
}
return 0;
}
题目内容
小 Q 得到了一个包含 n 个整数的数组 a 。他非常喜欢三元组,所以他把所有相邻的三个位置都看成一个三元组,然后抄到了他的本子上,一共写了 n−2 个三元组。
当且仅当两个三元组恰好有一个位置不同时,小 Q 认为这两个三元组蕴含了韵律的美。即满足以下条件之一时,小 Q 认为一对三元组符合他的要求:
-
b1=c1 且 b2=c2 且 b3=c3 ;
-
b1=c1 且 b2=c2 且 b3=c3 ;
-
b1=c1 且 b2=c2 且 b3=c3 。
找出所有他写在本子上的三元组中,符合他的要求的三元组对的数量。
输入描述
单个数据包含多组测试用例
第一行包含一个整数 t(1≤t≤104) 测试用例的数量。
对于每组数据:
第一行包含一个整数 n(3≤n≤2∗105) 表示数组 a 的长度。
第二行包含 n 个整数 a1,a2,an(1≤ai≤109) 表示数组的元素。
输出描述
对于每个测试用例,输出一个整数,表示符合小 Q 要求的三元组对的数量。
注意,答案可能无法用 32 位数据类型表示。
样例1
输入
3
5
2 1 1 1 2
5
4 3 4 3 4
8
5 5 4 3 4 5 5 3
输出
2
0
2