#P3401. 第4题-三元组数组
-
1000ms
Tried: 52
Accepted: 4
Difficulty: 7
所属公司 :
字节
时间 :2025年8月17日
-
算法标签>数学
第4题-三元组数组
思路
对固定中间位置j,记y=aj。令频次函数为f(x),并定义
- G(t):数组中≥t的元素个数,
- L(t):数组中≤t的元素个数。
由条件max(ai−aj,aj−ak)=2aj可拆分为两种达到最大值的方式,并用容斥合并:
- 当ai−aj=2aj时,ai=3y且aj−ak≤2y⇒ak≥−y,贡献为f(3y)⋅G(−y);
- 当aj−ak=2aj时,ak=−y且ai−aj≤2y⇒ai≤3y,贡献为f(−y)⋅L(3y);
- 二者同时成立时(ai=3y且ak=−y)被重复计算一次,需要减去f(3y)⋅f(−y)。
因此固定y的一次j的贡献为: f(3y)∗G(−y)+f(−y)∗L(3y)−f(3y)∗f(−y)。
总答案为对所有位置j的和,等价于对所有不同的y按出现次数f(y)加权:

实现要点:
- 先排序整个数组,用二分得到G(−y)与L(3y);
- 用哈希表统计f(⋅);
- 遍历不同的y计算并累加贡献。时间复杂度O(nlogn),空间O(n)。
- 结果量级可达n3,在n≤2×105时不超过8×1015,用64位整型保存即可。
C++
#include <bits/stdc++.h>
using namespace std;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
if (!(cin >> n)) return 0;
vector<long long> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
// 统计频次
unordered_map<long long, long long> freq;
freq.reserve(n * 2);
for (auto v : a) ++freq[v];
// 排序用于二分求 G 和 L
vector<long long> s = a;
sort(s.begin(), s.end());
auto ge_count = [&](long long t) -> long long {
// 统计 >= t 的个数
auto it = lower_bound(s.begin(), s.end(), t);
return (long long)s.size() - (it - s.begin());
};
auto le_count = [&](long long t) -> long long {
// 统计 <= t 的个数
auto it = upper_bound(s.begin(), s.end(), t);
return (long long)(it - s.begin());
};
long long ans = 0;
ans = 0;
for (const auto &kv : freq) {
long long y = kv.first;
long long fy = kv.second;
long long f3y = 0, fny = 0;
auto it1 = freq.find(3 * y);
if (it1 != freq.end()) f3y = it1->second;
auto it2 = freq.find(-y);
if (it2 != freq.end()) fny = it2->second;
long long G = ge_count(-y);
long long L = le_count(3 * y);
long long contrib_one_j = f3y * G + fny * L - f3y * fny;
ans += fy * contrib_one_j;
}
cout << ans << '\n';
return 0;
}
Python
import sys
import bisect
def main():
data = sys.stdin.read().strip().split()
if not data:
return
it = iter(data)
n = int(next(it))
a = [int(next(it)) for _ in range(n)]
# 统计频次
from collections import Counter
freq = Counter(a)
# 排序用于二分
s = sorted(a)
def ge_count(t):
# >= t
idx = bisect.bisect_left(s, t)
return len(s) - idx
def le_count(t):
# <= t
idx = bisect.bisect_right(s, t)
return idx
ans = 0
for y, fy in freq.items():
f3y = freq.get(3 * y, 0)
fny = freq.get(-y, 0)
G = ge_count(-y)
L = le_count(3 * y)
contrib_one_j = f3y * G + fny * L - f3y * fny
ans += fy * contrib_one_j
print(ans)
if __name__ == "__main__":
main()
Java
import java.io.*;
import java.util.*;
public class Main {
static class FastScanner {
BufferedInputStream in = new BufferedInputStream(System.in);
byte[] buffer = new byte[1 << 16];
int ptr = 0, len = 0;
int read() throws IOException {
if (ptr >= len) {
len = in.read(buffer);
ptr = 0;
if (len <= 0) return -1;
}
return buffer[ptr++];
}
long nextLong() throws IOException {
int c;
do { c = read(); } while (c <= 32 && c != -1);
boolean neg = false;
if (c == '-') { neg = true; c = read(); }
long x = 0;
while (c > 32 && c != -1) {
x = x * 10 + (c - '0');
c = read();
}
return neg ? -x : x;
}
}
static int lowerBound(long[] arr, long key) {
int l = 0, r = arr.length;
while (l < r) {
int m = (l + r) >>> 1;
if (arr[m] >= key) r = m; else l = m + 1;
}
return l;
}
static int upperBound(long[] arr, long key) {
int l = 0, r = arr.length;
while (l < r) {
int m = (l + r) >>> 1;
if (arr[m] <= key) l = m + 1; else r = m;
}
return l;
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner();
int n = (int)fs.nextLong();
long[] a = new long[n];
for (int i = 0; i < n; i++) a[i] = fs.nextLong();
// 统计频次
HashMap<Long, Long> freq = new HashMap<>(n * 2);
for (long v : a) freq.put(v, freq.getOrDefault(v, 0L) + 1);
// 排序用于二分
long[] s = a.clone();
Arrays.sort(s);
long ans = 0L;
for (Map.Entry<Long, Long> e : freq.entrySet()) {
long y = e.getKey();
long fy = e.getValue();
long f3y = freq.getOrDefault(3L * y, 0L);
long fny = freq.getOrDefault(-y, 0L);
long G = (long)(s.length - lowerBound(s, -y)); // >= -y
long L = (long)upperBound(s, 3L * y); // <= 3y
long contrib_one_j = f3y * G + fny * L - f3y * fny;
ans += fy * contrib_one_j;
}
System.out.println(ans);
}
}
题目内容
小红拿到一个长度为 n 的数组 {a1,a2,…,an} 。
她想要知道有多少个三元组 (i,j,k) 满足 max(ai−aj,aj−ak)=2×aj 。
请你帮她数一数。
输入描述
第一行输入一个整数 n(3≦n≦2×105) 代表数组长度。
第二行输入 n 个整数 a1,a2,...,an(−109≦ai≦109) ,表示数组中的元素。
输出描述
输出一个整数,表示满足条件的三元组个数。
样例1
输入
3
0 1 -1
输出
6
说明
在这个样例中,满足条件的三元组选取为
-
(i,j,k)=(1,1,1) ,此时左边 max {0−0,0−0} =0 ,右边 2×a1=0 ;
-
(i,j,k)=(1,1,2) ,此时左边 max {0−0,0−1} =0 ,右边 2×a1=0 ;
-
(i,j,k)=(1,2,3) ,此时左边 max {0−1,1−(−1)} =2 ,右边 2×a2=2 ;
-
(i,j,k)=(2,2,3) ,此时左边 max {1−1,1−(−1)} =2 ,右边 2×a2=2 ;
-
(i,j,k)=(3,1,1) ,此时左边 max {−1−0,0−0} =0 ,右边 2×a1=0 ;
-
(i,j,k)=(3,2,3) ,此时左边 max {−1−1,1−(−1)} =2 ,右边 2×a2=2 。
样例2
输入
4
0 0 -1 1
输出
20