#P4027. 数组中的第K个最大元素
-
ID: 2243
Tried: 92
Accepted: 36
Difficulty: 5
数组中的第K个最大元素
题目
描述
给定整数数组 nums
和整数 k
,请返回数组中第 k
个最大的元素。
请注意,你需要找的是数组排序后的第 k
个最大的元素,而不是第 k
个不同的元素。
你必须设计并实现时间复杂度为 O(n)
的算法解决此问题。
输入
- 第一行输入
n
和k
,其中:n
代表数组长度 (1 ≤ n ≤ 10⁵)k
代表要找到的第k
个最大元素 (1 ≤ k ≤ n)
- 第二行输入
n
个整数,表示数组nums
(-10⁴ ≤ nums[i] ≤ 10⁴)
输出
- 输出数组
nums
中第k
个最大的元素。
样例输入 1
6 2
3 2 1 5 6 4
样例输出 1
5
样例输入 2
9 4
3 2 3 1 2 4 5 5 6
样例输出 2
4
题解
本题要求在 O(n)
时间复杂度内找到第 k
个最大的元素,适合使用 快速选择算法(QuickSelect)。快速选择是一种类似快速排序的算法,使用 分区(Partition) 思想,可以在 O(n)
的期望时间复杂度下找到第 k
大的元素。
思路
- 选择一个 基准元素(pivot) 进行 分区(Partition):
- 所有大于
pivot
的元素放在左侧; - 所有小于
pivot
的元素放在右侧; pivot
归位,确定其在 排序后的位置。
- 所有大于
- 递归进行:
- 如果
pivot
位置正好是k-1
,返回pivot
; - 如果
pivot
位置大于k-1
,在左侧继续查找; - 否则,在右侧继续查找。
- 如果
- 使用 随机化的 pivot 以减少最坏情况
O(n^2)
发生的概率。
代码实现
Python
import sys
import random
def partition(nums, left, right):
"""分区操作,返回 pivot 的最终位置"""
pivot_idx = random.randint(left, right)
pivot = nums[pivot_idx]
nums[pivot_idx], nums[right] = nums[right], nums[pivot_idx] # 把 pivot 移到最后
store_idx = left
for i in range(left, right):
if nums[i] > pivot: # 降序排列
nums[i], nums[store_idx] = nums[store_idx], nums[i]
store_idx += 1
nums[store_idx], nums[right] = nums[right], nums[store_idx] # pivot 归位
return store_idx
def quick_select(nums, left, right, k):
"""快速选择,寻找第 k 大的元素"""
if left <= right:
pivot_idx = partition(nums, left, right)
if pivot_idx == k:
return nums[pivot_idx]
elif pivot_idx > k:
return quick_select(nums, left, pivot_idx - 1, k)
else:
return quick_select(nums, pivot_idx + 1, right, k)
return -1 # 不应该到这里
if __name__ == "__main__":
# 读取输入
n, k = map(int, sys.stdin.readline().strip().split())
nums = list(map(int, sys.stdin.readline().strip().split()))
# 计算第 k 大元素
result = quick_select(nums, 0, n - 1, k - 1)
# 输出结果
print(result)
C++
#include <iostream>
#include <vector>
#include <cstdlib>
#include <ctime>
using namespace std;
// 分区函数
int partition(vector<int>& nums, int left, int right) {
int pivot_idx = left + rand() % (right - left + 1);
int pivot = nums[pivot_idx];
swap(nums[pivot_idx], nums[right]); // 把 pivot 移到末尾
int store_idx = left;
for (int i = left; i < right; i++) {
if (nums[i] > pivot) { // 降序排列
swap(nums[i], nums[store_idx]);
store_idx++;
}
}
swap(nums[store_idx], nums[right]); // pivot 归位
return store_idx;
}
// 快速选择算法
int quickSelect(vector<int>& nums, int left, int right, int k) {
if (left <= right) {
int pivot_idx = partition(nums, left, right);
if (pivot_idx == k)
return nums[pivot_idx];
else if (pivot_idx > k)
return quickSelect(nums, left, pivot_idx - 1, k);
else
return quickSelect(nums, pivot_idx + 1, right, k);
}
return -1;
}
int main() {
srand(time(0)); // 设置随机种子
int n, k;
cin >> n >> k;
vector<int> nums(n);
for (int i = 0; i < n; i++)
cin >> nums[i];
cout << quickSelect(nums, 0, n - 1, k - 1) << endl;
return 0;
}
Java
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt(), k = scanner.nextInt();
int[] nums = new int[n];
for (int i = 0; i < n; i++) {
nums[i] = scanner.nextInt();
}
scanner.close();
System.out.println(quickSelect(nums, 0, n - 1, k - 1));
}
private static int partition(int[] nums, int left, int right) {
Random rand = new Random();
int pivotIdx = left + rand.nextInt(right - left + 1);
int pivot = nums[pivotIdx];
swap(nums, pivotIdx, right);
int storeIdx = left;
for (int i = left; i < right; i++) {
if (nums[i] > pivot) {
swap(nums, i, storeIdx);
storeIdx++;
}
}
swap(nums, storeIdx, right);
return storeIdx;
}
private static int quickSelect(int[] nums, int left, int right, int k) {
if (left <= right) {
int pivotIdx = partition(nums, left, right);
if (pivotIdx == k) return nums[pivotIdx];
else if (pivotIdx > k) return quickSelect(nums, left, pivotIdx - 1, k);
else return quickSelect(nums, pivotIdx + 1, right, k);
}
return -1;
}
private static void swap(int[] nums, int i, int j) {
int temp = nums[i];
nums[i] = nums[j];
nums[j] = temp;
}
}