#P4101. 颜色分类
-
ID: 2338
Tried: 53
Accepted: 20
Difficulty: 5
颜色分类
题目描述
给定一个长度为 n 的数组 nums
,其中包含 红色、白色 和 蓝色 三种颜色,分别用整数 0
、1
和 2
表示。请对 nums
进行 原地排序,使得相同颜色的元素相邻,并按 红色、白色、蓝色 顺序排列。
要求 不使用 内置排序函数。
输入描述
输入包含两行:
- 第一行输入一个整数 n(1≤n≤300),表示数组的长度。
- 第二行输入 n 个整数,表示数组
nums
,其中nums[i] ∈ {0, 1, 2}
。
输出描述
输出一行,表示排序后的数组,元素之间用 空格 分隔。
样例输入 1
6
2 0 2 1 1 0
样例输出 1
0 0 1 1 2 2
样例输入 2
3
2 0 1
样例输出 2
0 1 2
解题思路:荷兰国旗问题(双指针)
我们使用 三路快排(荷兰国旗算法) 进行原地排序,主要思想是:
-
维护 三个指针:
low
指向 0 的最右边界(即 0 应该放置的位置)。mid
当前遍历的元素索引。high
指向 2 的最左边界(即 2 应该放置的位置)。
-
遍历数组:
- 如果
nums[mid] == 0
,交换nums[mid]
和nums[low]
,然后low++
,mid++
。 - 如果
nums[mid] == 1
,mid++
继续遍历。 - 如果
nums[mid] == 2
,交换nums[mid]
和nums[high]
,然后high--
,但 不移动mid
(因为交换后的值还未检查)。
- 如果
-
终止条件:当
mid > high
时,排序完成。
时间 & 空间复杂度
- 时间复杂度:O(n),仅需 一次遍历 数组。
- 空间复杂度:O(1),原地排序,不使用额外空间。
Python 代码
import sys
def sort_colors(nums):
""" 使用双指针(荷兰国旗问题)对数组进行排序 """
low, mid, high = 0, 0, len(nums) - 1
while mid <= high:
if nums[mid] == 0: # 0 交换到左侧
nums[low], nums[mid] = nums[mid], nums[low]
low += 1
mid += 1
elif nums[mid] == 1: # 1 直接跳过
mid += 1
else: # 2 交换到右侧
nums[mid], nums[high] = nums[high], nums[mid]
high -= 1
def main():
""" 读取输入并调用排序函数 """
n = int(sys.stdin.readline().strip()) # 读取数组长度
nums = list(map(int, sys.stdin.readline().strip().split())) # 读取数组
sort_colors(nums) # 排序
print(" ".join(map(str, nums))) # 输出排序后的数组
if __name__ == "__main__":
main()
Java 代码
import java.util.*;
public class Main {
public static void sortColors(int[] nums) {
int low = 0, mid = 0, high = nums.length - 1;
while (mid <= high) {
if (nums[mid] == 0) { // 0 交换到左侧
int temp = nums[low];
nums[low] = nums[mid];
nums[mid] = temp;
low++;
mid++;
} else if (nums[mid] == 1) { // 1 直接跳过
mid++;
} else { // 2 交换到右侧
int temp = nums[mid];
nums[mid] = nums[high];
nums[high] = temp;
high--;
}
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt(); // 读取数组长度
int[] nums = new int[n];
for (int i = 0; i < n; i++) {
nums[i] = sc.nextInt();
}
sortColors(nums); // 排序
for (int i = 0; i < n; i++) {
System.out.print(nums[i] + (i == n - 1 ? "\n" : " "));
}
sc.close();
}
}
C++ 代码
#include <iostream>
#include <vector>
using namespace std;
void sortColors(vector<int>& nums) {
int low = 0, mid = 0, high = nums.size() - 1;
while (mid <= high) {
if (nums[mid] == 0) { // 0 交换到左侧
swap(nums[low], nums[mid]);
low++;
mid++;
} else if (nums[mid] == 1) { // 1 直接跳过
mid++;
} else { // 2 交换到右侧
swap(nums[mid], nums[high]);
high--;
}
}
}
int main() {
int n;
cin >> n; // 读取数组长度
vector<int> nums(n);
for (int i = 0; i < n; i++) {
cin >> nums[i];
}
sortColors(nums); // 排序
for (int i = 0; i < n; i++) {
cout << nums[i] << (i == n - 1 ? "\n" : " ");
}
return 0;
}