#P3733. 第1题-构造一元二次函数
-
1000ms
Tried: 56
Accepted: 22
Difficulty: 3
所属公司 :
美团
时间 :2025年9月20日-开发岗
-
算法标签>暴力枚举
第1题-构造一元二次函数
解题思路
本题的关键在于:由于 n=3 且要求 a,b,c 三个变量都至少被累加一次,因此三次累加恰好分别落在 a,b,c 上——也就是把 (a1,a2,a3) 以某种排列分配为 (p,q,r),得到
f(x)=p⋅x2+q⋅x+r,其中 (p,q,r) 是 (a1,a2,a3) 的六种排列之一。
因此每个询问仅需在 6 个排列中取
$$\max \big( (p \bmod m)\cdot (x^2 \bmod m) + (q \bmod m)\cdot (x \bmod m) + (r \bmod m) \big)\bmod m . $$利用同余运算性质先取模,避免大数溢出,计算高效。
算法类型:枚举/暴力 + 数学同余。 实现要点:
-
预处理 aimodm 为 αi;
-
对每个询问 x,计算 xm=xmodm、x2m=xm2modm;
-
枚举 (αi1,αi2,αi3) 的 6 个排列,计算
$$\text{val} = (\alpha_{i_1}\cdot x2_m + \alpha_{i_2}\cdot x_m + \alpha_{i_3}) \bmod m $$取最大值即可。
复杂度分析
- 每个询问固定枚举 6 种排列,常数时间;
- 时间复杂度:O(q);
- 空间复杂度:O(1)(仅常数级临时变量)。
代码实现
Python
import sys
# 给定 a_mod(三个数均已取模)、m 和 x,返回最大 f(x) mod m
def solve_one(a_mod, m, x):
xm = x % m # x mod m
x2m = (xm * xm) % m # x^2 mod m
# 六种排列(用下标表示)
perms = [
(0, 1, 2), (0, 2, 1),
(1, 0, 2), (1, 2, 0),
(2, 0, 1), (2, 1, 0)
]
best = 0
for i, j, k in perms:
# 逐步取模避免大数溢出
val = (a_mod[i] * x2m) % m
val = (val + (a_mod[j] * xm) % m) % m
val = (val + a_mod[k]) % m
if val > best:
best = val
return best
def main():
data = list(map(int, sys.stdin.read().strip().split()))
# 输入均为空格分隔,直接切分即可(无需使用 literal_eval)
it = iter(data)
n = next(it); m = next(it); q = next(it)
a = [next(it) for _ in range(n)]
xs = [next(it) for _ in range(q)]
# 预处理 a_i mod m
a_mod = [ai % m for ai in a]
ans = []
for x in xs:
ans.append(str(solve_one(a_mod, m, x)))
print(" ".join(ans))
if __name__ == "__main__":
main()
Java
import java.util.*;
public class Main {
// 返回最大 f(x) mod m
static long solveOne(long[] aMod, long m, long x) {
long xm = x % m; // x mod m
long x2m = (xm * xm) % m; // x^2 mod m
int[][] perms = {
{0,1,2},{0,2,1},{1,0,2},
{1,2,0},{2,0,1},{2,1,0}
};
long best = 0;
for (int[] p : perms) {
int i = p[0], j = p[1], k = p[2];
long val = (aMod[i] * x2m) % m;
val = (val + (aMod[j] * xm) % m) % m;
val = (val + aMod[k]) % m;
if (val > best) best = val;
}
return best;
}
public static void main(String[] args) {
// 数据规模很小,使用 Scanner 即可
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long m = sc.nextLong();
int q = sc.nextInt();
long[] a = new long[n];
for (int i = 0; i < n; i++) a[i] = sc.nextLong();
long[] xs = new long[q];
for (int i = 0; i < q; i++) xs[i] = sc.nextLong();
// 预处理 a_i % m
long[] aMod = new long[3];
for (int i = 0; i < 3; i++) aMod[i] = ((a[i] % m) + m) % m;
StringBuilder sb = new StringBuilder();
for (int i = 0; i < q; i++) {
long res = solveOne(aMod, m, xs[i]);
if (i > 0) sb.append(' ');
sb.append(res);
}
System.out.println(sb.toString());
sc.close();
}
}
C++
#include <bits/stdc++.h>
using namespace std;
// 返回最大 f(x) mod m
long long solveOne(const array<long long,3>& aMod, long long m, long long x) {
long long xm = x % m; // x mod m
long long x2m = (xm * xm) % m; // x^2 mod m
int perms[6][3] = {
{0,1,2},{0,2,1},{1,0,2},
{1,2,0},{2,0,1},{2,1,0}
};
long long best = 0;
for (int t = 0; t < 6; ++t) {
int i = perms[t][0], j = perms[t][1], k = perms[t][2];
long long val = (aMod[i] * x2m) % m;
val = (val + (aMod[j] * xm) % m) % m;
val = (val + aMod[k]) % m;
if (val > best) best = val;
}
return best;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n; long long m; int q;
if (!(cin >> n >> m >> q)) return 0;
vector<long long> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
vector<long long> xs(q);
for (int i = 0; i < q; ++i) cin >> xs[i];
// 预处理 a_i % m
array<long long,3> aMod;
for (int i = 0; i < 3; ++i) aMod[i] = (a[i] % m + m) % m;
// 处理每个询问并输出
for (int i = 0; i < q; ++i) {
long long res = solveOne(aMod, m, xs[i]);
if (i) cout << ' ';
cout << res;
}
cout << "\n";
return 0;
}
题目内容
给定一个长度为 n 的数组 a1,a2,...,an ,和一个整数 m ,你需要回答 q 次询问,每一次询问给定一个 x ,随后,按照下方步骤进行一元二次函数的构造:
1.初始化 a=0,b=0,c=0 。
2.遍历序列 a1,a2,...,an ,每次遍历将 ai 累加到 a、b、c 三个值中的其中任意一个值上。
3.保证 a、b、c 都要被至少累加一次。
于是你构造出了一个二次函数 f(x)=a⋅x2+b⋅x+c 。对于每一个询问,请你计算出最大的 f(x) mod m (注意,是取模后的最大值)。
注意,每一次询问构造独立,互不干扰。
输入描述
第一次输入三个整数 n,m,q(n=3;2≤m≤648;1≤q≤100) 。
第二行输入 n 个整数 a1,a2,...,an(1≤ai≤106) 。
第三行输入 q 个整数 x1,x2,...,xq(0≤xi≤106) 。
输出描述
对于每一个询问,输出一个整数,表示 f(x) mod m 的最大值。
样例1
输入
3 4 2
1 2 5
2 0
输出
3 2