#P3493. 第2题-Group卷积实现
-
1000ms
Tried: 425
Accepted: 89
Difficulty: 5
所属公司 :
华为
时间 :2025年8月28日-留学生
-
算法标签>模拟
第2题-Group卷积实现
题解
-
题面概述:给定展开后的输入张量与卷积核及其形状,和分组数 groups,实现分组卷积(包含深度卷积的特例)前向计算。默认 stride=1、padding=0、dilation=1。若形状与 groups 不合法或输出空间维度非正,则输出 −1。
-
关键条件:
- in_channels % groups = 0
- out_channels % groups = 0
- k_channels = in_channels / groups
- 输出尺寸:$H_{out}=H-K_h+1$,$W_{out}=W-K_w+1$,需要 $H_{out}>0$,$W_{out}>0$
-
深度卷积:是分组卷积的特例,此时 groups = in_channels,k_channels = 1,允许 out_channels = groups × depth_multiplier。
-
思路:
- 解析 5 行输入,校验数据长度与形状乘积一致;
- 校验分组与通道约束;
- 计算 Hout,Wout 并校验为正;
- 按 N、组 g、组内输出通道 oc、空间位置 (oh,ow)、组内输入通道 kc、核 (kh,kw) 六重循环累加;
- 扁平化顺序为 N→C→H→W 输出。
C++
#include <bits/stdc++.h>
using namespace std;
/// Splits one line on whitespace and converts every token to a 64-bit integer.
///
/// @param s  the line to parse; tokens are maximal runs of characters for
///           which isspace() is false.
/// @return   the parsed values in order of appearance (empty when the line
///           contains no token).
/// @note     Each token is converted with stoll, so a token that does not
///           start with a number makes stoll throw.
static std::vector<long long> parseLineLL(const std::string &s) {
    std::vector<long long> values;
    const std::size_t len = s.size();
    std::size_t pos = 0;
    while (pos < len) {
        // Skip the separator run in front of the next token.
        while (pos < len && std::isspace(static_cast<unsigned char>(s[pos]))) {
            ++pos;
        }
        if (pos == len) {
            break;
        }
        // Advance to one past the last character of the token.
        std::size_t stop = pos;
        while (stop < len && !std::isspace(static_cast<unsigned char>(s[stop]))) {
            ++stop;
        }
        values.push_back(std::stoll(s.substr(pos, stop - pos)));
        pos = stop;
    }
    return values;
}
/// Multiplies four strictly positive extents without signed overflow.
///
/// @param a,b,c,d  factors; the caller must have verified that each is > 0.
/// @param out      receives the product on success.
/// @return         false (leaving `out` untouched) when the product does not
///                 fit in long long.
static bool checkedVolume(long long a, long long b, long long c, long long d, long long &out) {
    const long long limit = std::numeric_limits<long long>::max();
    long long acc = a;
    for (const long long f : {b, c, d}) {
        if (acc > limit / f) {
            return false;
        }
        acc *= f;
    }
    out = acc;
    return true;
}

/// Reads the five input lines (input data, input shape, kernel data, kernel
/// shape, groups), runs a grouped convolution with stride 1, no padding and
/// no dilation, and prints the flattened output followed by its shape.
///
/// Any malformed or inconsistent input prints "-1" on two lines instead.
/// @return always 0.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Every rejection path prints the same two-line marker and exits normally.
    const auto reject = [] {
        std::cout << "-1\n-1\n";
        return 0;
    };

    std::string lines[5];
    for (std::string &ln : lines) {
        if (!std::getline(std::cin, ln)) {
            return reject();
        }
    }

    const std::vector<long long> x = parseLineLL(lines[0]);
    const std::vector<long long> inShape = parseLineLL(lines[1]);
    const std::vector<long long> k = parseLineLL(lines[2]);
    const std::vector<long long> kerShape = parseLineLL(lines[3]);
    const std::vector<long long> groups = parseLineLL(lines[4]);
    if (inShape.size() != 4 || kerShape.size() != 4 || groups.size() != 1) {
        return reject();
    }

    const long long N = inShape[0], C = inShape[1], H = inShape[2], W = inShape[3];
    const long long OC = kerShape[0], KC = kerShape[1], KH = kerShape[2], KW = kerShape[3];
    const long long G = groups[0];

    // Validate signs first so the overflow-checked products below only ever
    // see strictly positive factors.
    if (N <= 0 || C <= 0 || H <= 0 || W <= 0 || OC <= 0 || KC <= 0 || KH <= 0 || KW <= 0 || G <= 0) {
        return reject();
    }

    long long inNeed = 0;
    long long kerNeed = 0;
    if (!checkedVolume(N, C, H, W, inNeed) || !checkedVolume(OC, KC, KH, KW, kerNeed)) {
        return reject();
    }
    if (static_cast<long long>(x.size()) != inNeed || static_cast<long long>(k.size()) != kerNeed) {
        return reject();
    }

    // Group constraints: both channel counts split evenly and each kernel
    // spans exactly one group's worth of input channels.
    if (C % G != 0 || OC % G != 0) {
        return reject();
    }
    if (KC != C / G) {
        return reject();
    }

    const long long Ho = H - KH + 1;
    const long long Wo = W - KW + 1;
    if (Ho <= 0 || Wo <= 0) {
        return reject();
    }

    long long outTotal = 0;
    if (!checkedVolume(N, OC, Ho, Wo, outTotal)) {
        return reject();
    }

    const long long ocPerGroup = OC / G;  // output channels produced by each group
    const long long inPlane = H * W;      // elements in one input channel
    const long long kerPlane = KH * KW;   // elements in one kernel channel

    // Outputs are generated in N -> OC -> H -> W order, which is exactly the
    // flattened layout, so they can simply be appended.
    std::vector<long long> y;
    y.reserve(static_cast<std::size_t>(outTotal));

    for (long long n = 0; n < N; ++n) {
        for (long long oc = 0; oc < OC; ++oc) {
            const long long g = oc / ocPerGroup;  // group this output channel belongs to
            for (long long oh = 0; oh < Ho; ++oh) {
                for (long long ow = 0; ow < Wo; ++ow) {
                    // Values stay 64-bit end to end (no narrowing to int);
                    // assumes each product and the running sum fit in 64 bits.
                    long long acc = 0;
                    for (long long kc = 0; kc < KC; ++kc) {
                        const long long ic = g * KC + kc;  // absolute input channel
                        const long long inBase = (n * C + ic) * inPlane + oh * W + ow;
                        const long long kerBase = (oc * KC + kc) * kerPlane;
                        for (long long kh = 0; kh < KH; ++kh) {
                            const long long inRow = inBase + kh * W;
                            const long long kerRow = kerBase + kh * KW;
                            for (long long kw = 0; kw < KW; ++kw) {
                                acc += x[static_cast<std::size_t>(inRow + kw)] *
                                       k[static_cast<std::size_t>(kerRow + kw)];
                            }
                        }
                    }
                    y.push_back(acc);
                }
            }
        }
    }

    // First line: flattened output data.
    for (std::size_t i = 0; i < y.size(); ++i) {
        if (i != 0) {
            std::cout << ' ';
        }
        std::cout << y[i];
    }
    std::cout << '\n';
    // Second line: output shape.
    std::cout << N << ' ' << OC << ' ' << Ho << ' ' << Wo << '\n';
    return 0;
}
Python
import sys
def parse_line_to_ints(s: str):
    """Convert a whitespace-separated line into a list of ints.

    An empty (or otherwise falsy) line yields an empty list. Each token is
    passed to ``int``, which raises ``ValueError`` for non-integer text.
    """
    if not s:
        return []
    # str.split() with no argument already ignores leading/trailing whitespace
    # and never yields empty tokens, so no extra stripping or filtering is needed.
    return [int(token) for token in s.split()]
def main():
    """Read the five input lines, run the grouped convolution, print the result.

    Input lines, in order: flattened input data, input shape (N, C, H, W),
    flattened kernel data, kernel shape (OC, KC, KH, KW), and the group count.
    The convolution uses stride 1, no padding and no dilation.

    Prints the flattened output on one line and its shape on the next, or
    "-1" on two lines when the input is malformed or inconsistent.
    """
    def reject():
        # Every invalid-input path emits the same two "-1" lines.
        print("-1")
        print("-1")

    data = sys.stdin.read().strip().splitlines()
    if len(data) < 5:
        return reject()

    in_data, in_shape, ker_data, ker_shape, groups_list = (
        parse_line_to_ints(line) for line in data[:5]
    )
    if len(in_shape) != 4 or len(ker_shape) != 4 or len(groups_list) != 1:
        return reject()

    N, C, H, W = in_shape
    OC, KC, KH, KW = ker_shape
    G = groups_list[0]

    if min(N, C, H, W, OC, KC, KH, KW, G) <= 0:
        return reject()
    # Python ints are arbitrary precision, so these products cannot overflow.
    if len(in_data) != N * C * H * W or len(ker_data) != OC * KC * KH * KW:
        return reject()
    # Both channel counts must split evenly across the groups, and each kernel
    # must span exactly one group's worth of input channels.
    if C % G != 0 or OC % G != 0:
        return reject()
    if KC != C // G:
        return reject()

    Ho = H - KH + 1
    Wo = W - KW + 1
    if Ho <= 0 or Wo <= 0:
        return reject()

    oc_per_group = OC // G   # output channels produced by each group
    in_plane = H * W         # elements in one input channel
    ker_plane = KH * KW      # elements in one kernel channel

    # Outputs are produced in N -> OC -> H -> W order, which is the flattened
    # layout, so they can simply be appended.
    y = []
    for n in range(N):
        for oc in range(OC):
            g = oc // oc_per_group  # group this output channel belongs to
            for oh in range(Ho):
                for ow in range(Wo):
                    acc = 0
                    for kc in range(KC):
                        ic = g * KC + kc  # absolute input channel
                        base_in = (n * C + ic) * in_plane + oh * W + ow
                        base_ker = (oc * KC + kc) * ker_plane
                        for kh in range(KH):
                            row_in = base_in + kh * W
                            row_ker = base_ker + kh * KW
                            for kw_ in range(KW):
                                acc += in_data[row_in + kw_] * ker_data[row_ker + kw_]
                    y.append(acc)

    print(" ".join(str(v) for v in y))
    print(N, OC, Ho, Wo)


# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
Java
import java.io.*;
import java.util.*;
/**
 * Grouped 2-D convolution (stride 1, no padding, no dilation) over flattened
 * NCHW tensors read from standard input.
 */
public class Main {
    /**
     * Splits a line on whitespace and parses every token as a long.
     *
     * @param s the line to parse; {@code null} is treated as an empty line
     * @return the parsed values in order of appearance
     * @throws NumberFormatException if a token is not a valid long
     */
    static List<Long> parseLineLL(String s) {
        List<Long> res = new ArrayList<>();
        if (s == null) return res;
        for (String p : s.trim().split("\\s+")) {
            // Splitting an empty string yields one empty token; skip it.
            if (!p.isEmpty()) res.add(Long.parseLong(p));
        }
        return res;
    }

    /** Prints the two-line "-1" marker used for every invalid input. */
    private static void printInvalid() {
        System.out.println("-1");
        System.out.println("-1");
    }

    /** Copies the values into a primitive array, keeping the full 64-bit range. */
    private static long[] toArray(List<Long> values) {
        long[] out = new long[values.size()];
        for (int i = 0; i < out.length; i++) out[i] = values.get(i);
        return out;
    }

    /**
     * Reads five lines (input data, input shape, kernel data, kernel shape,
     * groups), computes the grouped convolution and prints the flattened
     * output followed by its shape, or "-1" twice when the input is invalid.
     */
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] lines = new String[5];
        for (int i = 0; i < lines.length; i++) {
            lines[i] = br.readLine();
        }
        for (String line : lines) {
            if (line == null) {
                printInvalid();
                return;
            }
        }

        List<Long> inData = parseLineLL(lines[0]);
        List<Long> inShape = parseLineLL(lines[1]);
        List<Long> kerData = parseLineLL(lines[2]);
        List<Long> kerShape = parseLineLL(lines[3]);
        List<Long> groups = parseLineLL(lines[4]);
        if (inShape.size() != 4 || kerShape.size() != 4 || groups.size() != 1) {
            printInvalid();
            return;
        }

        long N = inShape.get(0), C = inShape.get(1), H = inShape.get(2), W = inShape.get(3);
        long OC = kerShape.get(0), KC = kerShape.get(1), KH = kerShape.get(2), KW = kerShape.get(3);
        long G = groups.get(0);

        if (N <= 0 || C <= 0 || H <= 0 || W <= 0 || OC <= 0 || KC <= 0 || KH <= 0 || KW <= 0 || G <= 0) {
            printInvalid();
            return;
        }

        // Element counts implied by the shapes; overflow means the shape cannot
        // possibly match the data that was read.
        long inNeed, kerNeed;
        try {
            inNeed = Math.multiplyExact(Math.multiplyExact(Math.multiplyExact(N, C), H), W);
            kerNeed = Math.multiplyExact(Math.multiplyExact(Math.multiplyExact(OC, KC), KH), KW);
        } catch (ArithmeticException ex) {
            printInvalid();
            return;
        }
        if (inData.size() != inNeed || kerData.size() != kerNeed) {
            printInvalid();
            return;
        }

        // Both channel counts must split evenly across the groups, and each
        // kernel must span exactly one group's worth of input channels.
        if (C % G != 0 || OC % G != 0) {
            printInvalid();
            return;
        }
        if (KC != C / G) {
            printInvalid();
            return;
        }

        long Ho = H - KH + 1;
        long Wo = W - KW + 1;
        if (Ho <= 0 || Wo <= 0) {
            printInvalid();
            return;
        }

        // The output must be addressable as a Java array; reject it explicitly
        // instead of narrowing an unchecked product with a cast.
        int totalOut;
        try {
            totalOut = Math.toIntExact(
                    Math.multiplyExact(Math.multiplyExact(Math.multiplyExact(N, OC), Ho), Wo));
        } catch (ArithmeticException ex) {
            printInvalid();
            return;
        }

        // Keep the data as long so large inputs are not silently truncated.
        long[] x = toArray(inData);
        long[] k = toArray(kerData);
        long[] y = new long[totalOut];

        long ocPerGroup = OC / G;  // output channels produced by each group
        long inPlane = H * W;      // elements in one input channel
        long kerPlane = KH * KW;   // elements in one kernel channel

        // Outputs are produced in N -> OC -> H -> W order, i.e. the flattened
        // layout, so a running write position is enough.
        int outPos = 0;
        for (long n = 0; n < N; ++n) {
            for (long oc = 0; oc < OC; ++oc) {
                long g = oc / ocPerGroup;  // group this output channel belongs to
                for (long oh = 0; oh < Ho; ++oh) {
                    for (long ow = 0; ow < Wo; ++ow) {
                        long acc = 0;
                        for (long kc = 0; kc < KC; ++kc) {
                            long ic = g * KC + kc;  // absolute input channel
                            long inBase = (n * C + ic) * inPlane + oh * W + ow;
                            long kerBase = (oc * KC + kc) * kerPlane;
                            for (long kh = 0; kh < KH; ++kh) {
                                // Indices are bounded by the array lengths validated above.
                                int inRow = (int) (inBase + kh * W);
                                int kerRow = (int) (kerBase + kh * KW);
                                for (int kw = 0; kw < KW; ++kw) {
                                    acc += x[inRow + kw] * k[kerRow + kw];
                                }
                            }
                        }
                        y[outPos++] = acc;
                    }
                }
            }
        }

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < y.length; i++) {
            if (i > 0) sb.append(' ');
            sb.append(y[i]);
        }
        System.out.println(sb);
        System.out.println(N + " " + OC + " " + Ho + " " + Wo);
    }
}
题目内容
卷积(Convolution)是计算视觉中常用的计算算子,广泛应用于图像分类、检测、跟踪等多领域。
如下图所示,以 2 个三维张量卷积计算为例,取输入张量 $x$ 的通道数为 $C$、高度为 $H$、宽度为 $W$,卷积核 $k$ 的通道数为 $C$、高度为 $K_h$、宽度为 $K_w$,二者执行卷积计算要求其通道数相同。
当取卷积计算步长 stride=1,填充 padding=0,膨胀 dilation=1,无偏置项(bias)时,卷积核 $k$ 在输入张量 $x$ 上从左至右、从上至下滑动,分别与滑窗所重叠的输入张量 $x$ 切片逐元素相乘求和后,得到输出张量 $y$ 的各元素。
例如:
$y_{0,0}=x_{0,0,0}k_{0,0,0}+x_{0,0,1}k_{0,0,1}+x_{0,1,0}k_{0,1,0}+x_{0,1,1}k_{0,1,1}+x_{1,0,1}k_{1,0,1}+x_{1,1,0}k_{1,1,0}+x_{1,1,1}k_{1,1,1}+x_{2,0,0}k_{2,0,0}+x_{2,1,0}k_{2,1,0}+x_{2,1,1}k_{2,1,1}=72$
面向不同的应用需求,卷积存在多类变种。分组卷积(GroupConvolution)即是随2012年AlexNet提出的一种变种,其将输入张量和卷积核分组后,分别执行卷积计算,然后把多个输出张量进行融合。
例如,输入张量尺寸为 1×32×32×32(其中首个维度 1 为样本数),卷积核尺寸为 4×16×3×3(其中首个维度 4 为输出张量通道数,亦可理解为卷积核个数),下图为分组数 groups=2 时的分组卷积计算。
输入张量被切分为1×16×32×32 的两组张量,卷积核被切分为2×16×3×3 的两组张量,分组进行无 padding的卷积计算后,将两组尺寸为1×2×30×30 的计算结果,在第2 个维度拼接,得到尺寸为1×4×30×30 的输出张量。
请不使用PyTorch、MindSpore、PaddlePaddle 等AI 框架,使用编程语言原生库,编写一个支持分组卷积和深度卷积前向传播的函数,根据输入张量、卷积核、分组数,计算输出张量。
输入描述
-
in_data: 4 维输入张量展开后的数据序列,以空格分隔的正整数;
-
in_shape: 4 维输入张量的形状,以空格分隔的 4 个正整数,依次为
- batch size(样本数)
- in_channels(输入张量通道数)
- height(高度)
- width(宽度)
-
kernel_data: 4 维卷积核展开后的数据序列,以空格分隔的正整数;
-
kernel_shape: 4 维卷积核的形状,以空格分隔的 4 个正整数,依次为
- out_channels(输出张量通道数)
- k_channels(卷积核通道数)
- kernel_h(卷积核高度)
- kernel_w(卷积核宽度)
-
groups: 分组数,需满足
in_channels % groups = 0,out_channels % groups = 0,k_channels = in_channels / groups
输出描述
- out_data: 4维输出张量展开后的数据序列,以空格分隔的正整数;
- out_shape: 4 维输出张量的形状,以空格分隔的4个正整数,依次为
- batch_size(样本数)
- out_channels(输出张量通道数)
- height(高度)
- width(宽度)
若输入张量和卷积核的形状与 group的取值存在冲突,或出现其它取值冲突导致无法执行卷积计算,则 out_data 和 out_shape 均返回 −1。
样例1
输入
1 2 3 4 5 6 7 8
1 2 2 2
1 0 0 1 -1 0 0 -1
2 1 2 2
2
输出
5 -13
1 2 1 1
说明
输入张量为:
$\left[\left[\begin{array}{ll} 1 & 2 \\ 3 & 4 \end{array}\right],\left[\begin{array}{ll} 5 & 6 \\ 7 & 8 \end{array}\right]\right]$
输入张量形状为1×2×2×2 ;
卷积核为:
$\left[\left[\begin{array}{ll} 1 & 0 \\ 0 & 1 \end{array}\right],\left[\begin{array}{cc} -1 & 0 \\ 0 & -1 \end{array}\right]\right]$
卷积核形状为2×1×2×2 ;
分组数为2 ,输出张量为:
$\left[\left[5\right],\left[-13\right]\right]$
输出张量形状为1×2×1×1 。
样例2
输入
1 2 3 4 5 6 7 8 9
1 1 3 3
1 0 0 -1
1 1 2 2
2
输出
-1
-1
说明
解释:
由于 in_channels=1、out_channels=1,而 groups=2,不满足
in_channels%groups=0 ,out_channels%groups=0,
的条件,因此 out_data 和 out_shape 均返回 -1。