#P3401. 第4题-三元组数组
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 52
            Accepted: 4
            Difficulty: 7
            
          
          
          
                       所属公司 : 
                              字节
                                
            
                        
              时间 :2025年8月17日
                              
                      
          
 
- 
                        算法标签>数学          
 
第4题-三元组数组
思路
对固定中间位置j,记y=aj。令频次函数为f(x),并定义
- G(t):数组中≥t的元素个数,
 - L(t):数组中≤t的元素个数。
 
由条件max(ai−aj,aj−ak)=2aj可拆分为两种达到最大值的方式,并用容斥合并:
- 当ai−aj=2aj时,ai=3y且aj−ak≤2y⇒ak≥−y,贡献为f(3y)⋅G(−y);
 - 当aj−ak=2aj时,ak=−y且ai−aj≤2y⇒ai≤3y,贡献为f(−y)⋅L(3y);
 - 二者同时成立时(ai=3y且ak=−y)被重复计算一次,需要减去f(3y)⋅f(−y)。
 
因此固定y的一次j的贡献为: f(3y)∗G(−y)+f(−y)∗L(3y)−f(3y)∗f(−y)。
总答案为对所有位置j的和,等价于对所有不同的y按出现次数f(y)加权:

实现要点:
- 先排序整个数组,用二分得到G(−y)与L(3y);
 - 用哈希表统计f(⋅);
 - 遍历不同的y计算并累加贡献。时间复杂度O(nlogn),空间O(n)。
 - 结果量级可达n3,在n≤2×105时不超过8×1015,用64位整型保存即可。
 
C++
#include <bits/stdc++.h>
using namespace std;
int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int n;
	if (!(cin >> n)) return 0;
	vector<long long> a(n);
	for (int i = 0; i < n; ++i) cin >> a[i];
	// 统计频次
	unordered_map<long long, long long> freq;
	freq.reserve(n * 2);
	for (auto v : a) ++freq[v];
	// 排序用于二分求 G 和 L
	vector<long long> s = a;
	sort(s.begin(), s.end());
	auto ge_count = [&](long long t) -> long long {
		// 统计 >= t 的个数
		auto it = lower_bound(s.begin(), s.end(), t);
		return (long long)s.size() - (it - s.begin());
	};
	auto le_count = [&](long long t) -> long long {
		// 统计 <= t 的个数
		auto it = upper_bound(s.begin(), s.end(), t);
		return (long long)(it - s.begin());
	};
	long long ans = 0;
	ans = 0;
	for (const auto &kv : freq) {
		long long y = kv.first;
		long long fy = kv.second;
		long long f3y = 0, fny = 0;
		auto it1 = freq.find(3 * y);
		if (it1 != freq.end()) f3y = it1->second;
		auto it2 = freq.find(-y);
		if (it2 != freq.end()) fny = it2->second;
		long long G = ge_count(-y);
		long long L = le_count(3 * y);
		long long contrib_one_j = f3y * G + fny * L - f3y * fny;
		ans += fy * contrib_one_j;
	}
	cout << ans << '\n';
	return 0;
}
Python
import sys
import bisect
def main():
	data = sys.stdin.read().strip().split()
	if not data:
		return
	it = iter(data)
	n = int(next(it))
	a = [int(next(it)) for _ in range(n)]
	# 统计频次
	from collections import Counter
	freq = Counter(a)
	# 排序用于二分
	s = sorted(a)
	def ge_count(t):
		# >= t
		idx = bisect.bisect_left(s, t)
		return len(s) - idx
	def le_count(t):
		# <= t
		idx = bisect.bisect_right(s, t)
		return idx
	ans = 0
	for y, fy in freq.items():
		f3y = freq.get(3 * y, 0)
		fny = freq.get(-y, 0)
		G = ge_count(-y)
		L = le_count(3 * y)
		contrib_one_j = f3y * G + fny * L - f3y * fny
		ans += fy * contrib_one_j
	print(ans)
if __name__ == "__main__":
	main()
Java
import java.io.*;
import java.util.*;
public class Main {
	static class FastScanner {
		BufferedInputStream in = new BufferedInputStream(System.in);
		byte[] buffer = new byte[1 << 16];
		int ptr = 0, len = 0;
		int read() throws IOException {
			if (ptr >= len) {
				len = in.read(buffer);
				ptr = 0;
				if (len <= 0) return -1;
			}
			return buffer[ptr++];
		}
		long nextLong() throws IOException {
			int c;
			do { c = read(); } while (c <= 32 && c != -1);
			boolean neg = false;
			if (c == '-') { neg = true; c = read(); }
			long x = 0;
			while (c > 32 && c != -1) {
				x = x * 10 + (c - '0');
				c = read();
			}
			return neg ? -x : x;
		}
	}
	static int lowerBound(long[] arr, long key) {
		int l = 0, r = arr.length;
		while (l < r) {
			int m = (l + r) >>> 1;
			if (arr[m] >= key) r = m; else l = m + 1;
		}
		return l;
	}
	static int upperBound(long[] arr, long key) {
		int l = 0, r = arr.length;
		while (l < r) {
			int m = (l + r) >>> 1;
			if (arr[m] <= key) l = m + 1; else r = m;
		}
		return l;
	}
	public static void main(String[] args) throws Exception {
		FastScanner fs = new FastScanner();
		int n = (int)fs.nextLong();
		long[] a = new long[n];
		for (int i = 0; i < n; i++) a[i] = fs.nextLong();
		// 统计频次
		HashMap<Long, Long> freq = new HashMap<>(n * 2);
		for (long v : a) freq.put(v, freq.getOrDefault(v, 0L) + 1);
		// 排序用于二分
		long[] s = a.clone();
		Arrays.sort(s);
		long ans = 0L;
		for (Map.Entry<Long, Long> e : freq.entrySet()) {
			long y = e.getKey();
			long fy = e.getValue();
			long f3y = freq.getOrDefault(3L * y, 0L);
			long fny = freq.getOrDefault(-y, 0L);
			long G = (long)(s.length - lowerBound(s, -y));     // >= -y
			long L = (long)upperBound(s, 3L * y);              // <= 3y
			long contrib_one_j = f3y * G + fny * L - f3y * fny;
			ans += fy * contrib_one_j;
		}
		System.out.println(ans);
	}
}
        题目内容
小红拿到一个长度为 n 的数组 {a1,a2,…,an} 。
她想要知道有多少个三元组 (i,j,k) 满足 max(ai−aj,aj−ak)=2×aj 。
请你帮她数一数。
输入描述
第一行输入一个整数 n(3≦n≦2×105) 代表数组长度。
第二行输入 n 个整数 a1,a2,...,an(−109≦ai≦109) ,表示数组中的元素。
输出描述
输出一个整数,表示满足条件的三元组个数。
样例1
输入
3
0 1 -1
输出
6
说明
在这个样例中,满足条件的三元组选取为
- 
(i,j,k)=(1,1,1) ,此时左边 max {0−0,0−0} =0 ,右边 2×a1=0 ;
 - 
(i,j,k)=(1,1,2) ,此时左边 max {0−0,0−1} =0 ,右边 2×a1=0 ;
 - 
(i,j,k)=(1,2,3) ,此时左边 max {0−1,1−(−1)} =2 ,右边 2×a2=2 ;
 - 
(i,j,k)=(2,2,3) ,此时左边 max {1−1,1−(−1)} =2 ,右边 2×a2=2 ;
 - 
(i,j,k)=(3,1,1) ,此时左边 max {−1−0,0−0} =0 ,右边 2×a1=0 ;
 - 
(i,j,k)=(3,2,3) ,此时左边 max {−1−1,1−(−1)} =2 ,右边 2×a2=2 。
 
样例2
输入
4
0 0 -1 1
输出
20