#P3433. 第2题-最大池化操作
-
ID: 2775
Tried: 33
Accepted: 7
Difficulty: 2
所属公司 :
阿里
时间 :2025年8月22日-菜鸟
-
算法标签>模拟
第2题-最大池化操作
思路
- 解析输入:将输入的字符串转换为实际的二维列表和一维列表。
- 确定参数:设输入特征图大小为 n×n,池化窗口大小为 w×h。
- 计算输出尺寸:输出特征图的高度为 out_h=⌊(n−h)/stride⌋+1,宽度为 out_w=⌊(n−w)/stride⌋+1。通常步长(stride)等于池化窗口的尺寸(即非重叠池化),因此这里步长在高度方向为 h,在宽度方向为 w。
- 滑动窗口取最大值:遍历每个窗口,计算窗口内的最大值,并填入输出矩阵。
代码实现
C++
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <algorithm>
#include <cmath>
using namespace std;
// 解析一行列表字符串,返回一维整数向量
vector<int> parseList(const string& s) {
vector<int> res;
stringstream ss(s);
char ch;
int num;
while (ss >> ch) {
if (ch == '[' || ch == ',' || ch == ' ') {
continue;
} else if (ch == ']') {
break;
} else {
ss.putback(ch);
ss >> num;
res.push_back(num);
}
}
return res;
}
// 解析二维列表字符串,返回二维整数向量
vector<vector<int>> parse2DList(const string& s) {
vector<vector<int>> res;
stringstream ss(s);
char ch;
// 跳过开头的'['
ss >> ch;
while (true) {
// 读取下一个字符,可能是'['或']'
ss >> ch;
if (ch == ']') {
break;
} else if (ch == '[') {
// 解析一行
string rowStr;
getline(ss, rowStr, ']');
rowStr = "[" + rowStr + "]";
vector<int> row = parseList(rowStr);
res.push_back(row);
}
}
return res;
}
int main() {
string featureMapStr;
getline(cin, featureMapStr);
string poolSizeStr;
getline(cin, poolSizeStr);
// 解析特征图和池化窗口大小
vector<vector<int>> featureMap = parse2DList(featureMapStr);
vector<int> poolSize = parseList(poolSizeStr);
int n = featureMap.size();
int w = poolSize[0];
int h = poolSize[1];
// 计算输出特征图的尺寸
int out_h = n / h;
int out_w = n / w;
vector<vector<int>> result(out_h, vector<int>(out_w));
for (int i = 0; i < out_h; i++) {
for (int j = 0; j < out_w; j++) {
int start_i = i * h;
int start_j = j * w;
int maxVal = featureMap[start_i][start_j];
// 遍历池化窗口内的所有元素
for (int ki = 0; ki < h; ki++) {
for (int kj = 0; kj < w; kj++) {
int cur = featureMap[start_i + ki][start_j + kj];
if (cur > maxVal) {
maxVal = cur;
}
}
}
result[i][j] = maxVal;
}
}
// 输出结果
cout << "[";
for (int i = 0; i < out_h; i++) {
cout << "[";
for (int j = 0; j < out_w; j++) {
cout << result[i][j];
if (j < out_w - 1) {
cout << ",";
}
}
cout << "]";
if (i < out_h - 1) {
cout << ",";
}
}
cout << "]" << endl;
return 0;
}
Python
def main():
# 读取输入
feature_map_str = input().strip()
pool_size_str = input().strip()
# 解析特征图字符串为二维列表
feature_map = eval(feature_map_str)
# 解析池化窗口大小
pool_size = eval(pool_size_str)
n = len(feature_map)
w, h = pool_size
# 计算输出尺寸
out_h = n // h
out_w = n // w
result = []
for i in range(out_h):
row = []
for j in range(out_w):
start_i = i * h
start_j = j * w
# 获取当前窗口内的最大值
max_val = feature_map[start_i][start_j]
for ki in range(h):
for kj in range(w):
cur = feature_map[start_i + ki][start_j + kj]
if cur > max_val:
max_val = cur
row.append(max_val)
result.append(row)
# 使用自定义格式输出,确保没有多余空格
print(str(result).replace(", ", ","))
if __name__ == "__main__":
main()
Java
import java.util.*;
import java.util.regex.*;
public class Main {
// 解析一维列表
public static List<Integer> parseList(String s) {
List<Integer> res = new ArrayList<>();
s = s.replaceAll("[\\[\\]\\s]", "");
String[] parts = s.split(",");
for (String part : parts) {
if (!part.isEmpty()) {
res.add(Integer.parseInt(part));
}
}
return res;
}
// 解析二维列表
public static List<List<Integer>> parse2DList(String s) {
List<List<Integer>> res = new ArrayList<>();
// 去除最外层的括号
s = s.substring(1, s.length() - 1);
// 按行分割
String[] rows = s.split("\\],\\[");
for (int i = 0; i < rows.length; i++) {
String rowStr = rows[i];
// 处理首尾的括号
if (i == 0) {
rowStr = rowStr + "]";
} else if (i == rows.length - 1) {
rowStr = "[" + rowStr;
} else {
rowStr = "[" + rowStr + "]";
}
res.add(parseList(rowStr));
}
return res;
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
String featureMapStr = scanner.nextLine();
String poolSizeStr = scanner.nextLine();
// 解析输入
List<List<Integer>> featureMap = parse2DList(featureMapStr);
List<Integer> poolSize = parseList(poolSizeStr);
int n = featureMap.size();
int w = poolSize.get(0);
int h = poolSize.get(1);
int out_h = n / h;
int out_w = n / w;
List<List<Integer>> result = new ArrayList<>();
for (int i = 0; i < out_h; i++) {
List<Integer> row = new ArrayList<>();
for (int j = 0; j < out_w; j++) {
int start_i = i * h;
int start_j = j * w;
int maxVal = featureMap.get(start_i).get(start_j);
for (int ki = 0; ki < h; ki++) {
for (int kj = 0; kj < w; kj++) {
int cur = featureMap.get(start_i + ki).get(start_j + kj);
if (cur > maxVal) {
maxVal = cur;
}
}
}
row.add(maxVal);
}
result.add(row);
}
// 使用替换方法去除多余空格
String output = result.toString().replace(", ", ",");
System.out.println(output);
}
}
题目内容
某医疗诊断公司在进行医疗图像识别的过程中,使用了基于卷积神经网络(CNN)的机器学习模型。其中,卷积层后的最大池化操作是关键步骤之一,它有助于减少模型的计算负担并提取特征的最重要部分。请根据输入描述和输出描述中的要求,编程实现与 CNN 中最大池化操作相关的子功能。
最大池化:最大池化是一种子采样方法,它在卷积操作后进行,用于减少模型的计算负担并提取特征的最重要部分。最大池化操作的步骤是:在卷积后的特征图中,选取一个固定大小的窗口,以固定的步长滑动窗口,并在每个窗口中选取最大的值作为该窗口的输出。
输入描述
输入第一行是一个 2 维的 list ,用于表示卷积后的特征图,形如 [[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]] ,且所有的 list 组成的都是方阵。
第二行输入是一个 1 维的 list 参数,长度固定为 2 ,形如 [2,2] ,这两个元素分别表示池化窗口的宽度和高度。
输出描述
返回值为一个 2 维的 list ,表示经过最大池化操作后的特征图。
样例1
输入
[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]]
[2,2]
输出
[[4, 4],[4, 4]]
样例2
输入
[[1,3,2,4],[5,7,6,8],[1,3,2,4],[5,7,6,8]]
[3,3]
输出
[[7]]