#P3493. 第2题-Group卷积实现
          
                        
                                    
                      
        
              - 
          
          
                      1000ms
            
          
                      Tried: 417
            Accepted: 87
            Difficulty: 5
            
          
          
          
                       所属公司 : 
                              华为
                                
            
                        
              时间 :2025年8月28日-留学生
                              
                      
          
 
- 
                        算法标签>模拟          
 
第2题-Group卷积实现
题解
- 
题面概述:给定展开后的输入张量与卷积核及其形状,和分组数 groups,实现分组卷积(包含深度卷积的特例)前向计算。默认 stride=1、padding=0、dilation=1。若形状与 groups 不合法或输出空间维度非正,则输出 −1。
 - 
关键条件:
- in_channels % groups = 0
 - out_channels % groups = 0
 - k_channels = in_channels / groups
 - 输出尺寸:$H_{out}=H-K_h+1$,$W_{out}=W-K_w+1$,需要 $H_{out}>0$,$W_{out}>0$
 
 - 
深度卷积:是分组卷积的特例,groups = in_channels,k_channels = 1,允许 out_channels = groups × depth_multiplier。
 - 
思路:
- 解析 5 行输入,校验数据长度与形状乘积一致;
 - 校验分组与通道约束;
 - 计算 Hout,Wout 并校验为正;
 - 按 N、组 g、组内输出通道 oc、空间位置 (oh,ow)、组内输入通道 kc、核 (kh,kw) 六重循环累加;
 - 扁平化顺序为 N→C→H→W 输出。
 
 
C++
#include <bits/stdc++.h>
using namespace std;
// Splits one line on whitespace and converts every token to a long long.
// Tokens are maximal runs of non-whitespace characters; each one is handed to
// std::stoll, so a token that is not numeric (or does not fit) makes this
// function throw. An empty or all-blank line yields an empty vector.
static std::vector<long long> parseLineLL(const std::string &s) {
    // Cast through unsigned char: passing a negative char to isspace is undefined.
    const auto isBlank = [](char ch) {
        return std::isspace(static_cast<unsigned char>(ch)) != 0;
    };

    std::vector<long long> values;
    auto cursor = s.begin();
    const auto last = s.end();
    while (cursor != last) {
        // Skip the separator run, then locate the end of the next token.
        cursor = std::find_if_not(cursor, last, isBlank);
        if (cursor == last) {
            break;
        }
        const auto tokenEnd = std::find_if(cursor, last, isBlank);
        values.push_back(std::stoll(std::string(cursor, tokenEnd)));
        cursor = tokenEnd;
    }
    return values;
}
// Returns true when every entry of `dims` is positive and their product equals
// `count`. The running product is bounded by `count` via division before each
// multiplication, so it cannot overflow no matter how large the declared
// dimensions are.
static bool productEquals(std::initializer_list<long long> dims, unsigned long long count) {
    unsigned long long prod = 1;
    for (const long long d : dims) {
        if (d <= 0) {
            return false;
        }
        const auto ud = static_cast<unsigned long long>(d);
        if (prod > count / ud) {
            return false;  // prod * ud would already exceed count
        }
        prod *= ud;
    }
    return prod == count;
}

// Reads five lines from stdin (in_data, in_shape, kernel_data, kernel_shape,
// groups), runs a grouped convolution with stride 1, no padding and no
// dilation, and prints the flattened output followed by its shape.
// Any malformed or inconsistent input is answered with "-1" on both lines.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    const auto reject = [] {
        std::cout << "-1\n-1\n";
        return 0;
    };

    std::string lines[5];
    for (std::string &line : lines) {
        if (!std::getline(std::cin, line)) {
            return reject();
        }
    }

    std::vector<long long> x, inShape, k, kerShape, groups;
    try {
        x = parseLineLL(lines[0]);
        inShape = parseLineLL(lines[1]);
        k = parseLineLL(lines[2]);
        kerShape = parseLineLL(lines[3]);
        groups = parseLineLL(lines[4]);
    } catch (const std::exception &) {
        // parseLineLL converts tokens with stoll, which throws on non-numeric
        // or out-of-range tokens; treat that as invalid input, not a crash.
        return reject();
    }
    if (inShape.size() != 4 || kerShape.size() != 4 || groups.size() != 1) {
        return reject();
    }

    const long long N = inShape[0], C = inShape[1], H = inShape[2], W = inShape[3];
    const long long OC = kerShape[0], KC = kerShape[1], KH = kerShape[2], KW = kerShape[3];
    const long long G = groups[0];

    // Shapes must be positive and describe exactly the data that was supplied.
    // After this check every flat index below is bounded by a vector size.
    if (!productEquals({N, C, H, W}, x.size()) || !productEquals({OC, KC, KH, KW}, k.size())) {
        return reject();
    }
    // Group constraints: both channel counts split evenly, and each kernel
    // spans exactly one group's worth of input channels.
    if (G <= 0 || C % G != 0 || OC % G != 0 || KC != C / G) {
        return reject();
    }

    const long long Ho = H - KH + 1;
    const long long Wo = W - KW + 1;
    if (Ho <= 0 || Wo <= 0) {
        return reject();
    }

    const long long OCg = OC / G;  // output channels per group

    // The loops visit outputs in N -> C -> H -> W order, which is the flattened
    // output order, so each value is printed as soon as it is computed.
    bool first = true;
    for (long long n = 0; n < N; ++n) {
        for (long long oc = 0; oc < OC; ++oc) {
            const long long g = oc / OCg;  // group this output channel belongs to
            const long long *kerBase = k.data() + oc * KC * KH * KW;
            for (long long oh = 0; oh < Ho; ++oh) {
                for (long long ow = 0; ow < Wo; ++ow) {
                    long long acc = 0;
                    for (long long kc = 0; kc < KC; ++kc) {
                        const long long ic = g * KC + kc;  // absolute input channel
                        for (long long kh = 0; kh < KH; ++kh) {
                            const long long *xRow = x.data() + ((n * C + ic) * H + (oh + kh)) * W + ow;
                            const long long *kRow = kerBase + (kc * KH + kh) * KW;
                            for (long long kw = 0; kw < KW; ++kw) {
                                acc += xRow[kw] * kRow[kw];
                            }
                        }
                    }
                    if (!first) {
                        std::cout << ' ';
                    }
                    first = false;
                    std::cout << acc;
                }
            }
        }
    }
    std::cout << '\n';
    std::cout << N << ' ' << OC << ' ' << Ho << ' ' << Wo << '\n';
    return 0;
}
Python
import sys
def parse_line_to_ints(s: str):
    """Split a line on whitespace and convert every token to int.

    A falsy input (empty string or None) and an all-blank line both give an
    empty list. A non-numeric token raises ValueError from int().
    """
    # str.split() with no argument already ignores leading/trailing blanks and
    # never produces empty tokens.
    return list(map(int, s.split())) if s else []
def main():
    """Read a grouped-convolution problem from stdin and print the result.

    Input is five lines: in_data, in_shape (N C H W), kernel_data,
    kernel_shape (OC KC KH KW) and groups. The convolution uses stride 1,
    no padding and no dilation. On success two lines are printed: the
    flattened output in N -> C -> H -> W order, then its shape. Any malformed
    or inconsistent input prints "-1" on both lines instead.
    """
    def reject():
        print("-1")
        print("-1")

    data = sys.stdin.read().strip().splitlines()
    if len(data) < 5:
        return reject()

    try:
        in_data, in_shape, ker_data, ker_shape, groups_list = [
            parse_line_to_ints(line) for line in data[:5]
        ]
    except ValueError:
        # int() rejects non-numeric tokens; that is invalid input, not a crash.
        return reject()

    if len(in_shape) != 4 or len(ker_shape) != 4 or len(groups_list) != 1:
        return reject()
    N, C, H, W = in_shape
    OC, KC, KH, KW = ker_shape
    G = groups_list[0]

    if min(N, C, H, W, OC, KC, KH, KW, G) <= 0:
        return reject()
    # The shapes must describe exactly the data that was supplied.
    if len(in_data) != N * C * H * W or len(ker_data) != OC * KC * KH * KW:
        return reject()
    # Both channel counts split evenly across groups, and each kernel spans
    # exactly one group's worth of input channels.
    if C % G != 0 or OC % G != 0 or KC != C // G:
        return reject()

    Ho = H - KH + 1
    Wo = W - KW + 1
    if Ho <= 0 or Wo <= 0:
        return reject()

    OCg = OC // G  # output channels per group
    out = []
    # Visiting n, oc, oh, ow in this order produces values already in the
    # flattened output order, so they are simply appended.
    for n in range(N):
        for oc in range(OC):
            g = oc // OCg  # group this output channel belongs to
            ker_base = oc * KC * KH * KW
            for oh in range(Ho):
                for ow in range(Wo):
                    acc = 0
                    for kc in range(KC):
                        # Row offset of input channel (g * KC + kc) of sample n.
                        chan_rows = (n * C + g * KC + kc) * H
                        for kh in range(KH):
                            row_in = (chan_rows + oh + kh) * W + ow
                            row_ker = ker_base + (kc * KH + kh) * KW
                            for kw_ in range(KW):
                                acc += in_data[row_in + kw_] * ker_data[row_ker + kw_]
                    out.append(acc)

    print(" ".join(map(str, out)))
    print(N, OC, Ho, Wo)
# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Java
import java.io.*;
import java.util.*;
/**
 * Grouped convolution (stride 1, no padding, no dilation) over flattened
 * tensors read from standard input.
 */
public class Main {
    /**
     * Splits a line on whitespace and parses every token as a long.
     *
     * @param s the line to parse; {@code null} yields an empty list
     * @return the parsed values in input order
     * @throws NumberFormatException if a token is not a valid long
     */
    static List<Long> parseLineLL(String s) {
        List<Long> res = new ArrayList<>();
        if (s == null) return res;
        // A blank line still splits into one empty token; skip it.
        for (String p : s.trim().split("\\s+")) {
            if (!p.isEmpty()) res.add(Long.parseLong(p));
        }
        return res;
    }

    /** Parses a line into a primitive array so the hot loops avoid boxing. */
    private static long[] parseLine(String s) {
        List<Long> boxed = parseLineLL(s);
        long[] out = new long[boxed.size()];
        for (int i = 0; i < out.length; i++) out[i] = boxed.get(i);
        return out;
    }

    /**
     * Returns true when every dimension is positive and their product equals
     * {@code count}. The running product is bounded by {@code count} via
     * division before each multiplication, so it cannot overflow.
     */
    private static boolean productEquals(long count, long... dims) {
        long prod = 1;
        for (long d : dims) {
            if (d <= 0) return false;
            if (prod > count / d) return false; // prod * d would exceed count
            prod *= d;
        }
        return prod == count;
    }

    /** Prints the "-1" answer used for every kind of invalid input. */
    private static void reject() {
        System.out.println("-1");
        System.out.println("-1");
    }

    /**
     * Reads in_data, in_shape, kernel_data, kernel_shape and groups (one per
     * line), then prints the flattened output and its shape, or "-1" twice
     * when the input is malformed or inconsistent.
     */
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] lines = new String[5];
        for (int i = 0; i < lines.length; i++) {
            lines[i] = br.readLine();
            if (lines[i] == null) {
                reject();
                return;
            }
        }

        long[] x, inShape, k, kerShape, groups;
        try {
            x = parseLine(lines[0]);
            inShape = parseLine(lines[1]);
            k = parseLine(lines[2]);
            kerShape = parseLine(lines[3]);
            groups = parseLine(lines[4]);
        } catch (NumberFormatException ex) {
            // A non-numeric token is invalid input, not a crash.
            reject();
            return;
        }
        if (inShape.length != 4 || kerShape.length != 4 || groups.length != 1) {
            reject();
            return;
        }

        long N = inShape[0], C = inShape[1], H = inShape[2], W = inShape[3];
        long OC = kerShape[0], KC = kerShape[1], KH = kerShape[2], KW = kerShape[3];
        long G = groups[0];

        // Shapes must be positive and describe exactly the supplied data.
        if (!productEquals(x.length, N, C, H, W) || !productEquals(k.length, OC, KC, KH, KW)) {
            reject();
            return;
        }
        // Both channel counts split evenly across groups, and each kernel
        // spans exactly one group's worth of input channels.
        if (G <= 0 || C % G != 0 || OC % G != 0 || KC != C / G) {
            reject();
            return;
        }
        long Ho = H - KH + 1;
        long Wo = W - KW + 1;
        if (Ho <= 0 || Wo <= 0) {
            reject();
            return;
        }

        // Every dimension is now bounded by an array length, so the int
        // conversions and the int index arithmetic below cannot overflow.
        int batch = (int) N, inCh = (int) C, height = (int) H, width = (int) W;
        int outCh = (int) OC, kerCh = (int) KC, kerH = (int) KH, kerW = (int) KW;
        int outH = (int) Ho, outW = (int) Wo;
        int outChPerGroup = (int) (OC / G);

        // Outputs are produced in N -> C -> H -> W order, which is the
        // flattened order, so each value is appended as soon as it is known.
        StringBuilder sb = new StringBuilder();
        for (int n = 0; n < batch; n++) {
            for (int oc = 0; oc < outCh; oc++) {
                int g = oc / outChPerGroup; // group of this output channel
                int kerBase = oc * kerCh * kerH * kerW;
                for (int oh = 0; oh < outH; oh++) {
                    for (int ow = 0; ow < outW; ow++) {
                        long acc = 0;
                        for (int kc = 0; kc < kerCh; kc++) {
                            // Row offset of input channel (g * kerCh + kc) of sample n.
                            int chanRows = (n * inCh + g * kerCh + kc) * height;
                            for (int kh = 0; kh < kerH; kh++) {
                                int rowIn = (chanRows + oh + kh) * width + ow;
                                int rowKer = kerBase + (kc * kerH + kh) * kerW;
                                for (int kw = 0; kw < kerW; kw++) {
                                    acc += x[rowIn + kw] * k[rowKer + kw];
                                }
                            }
                        }
                        if (sb.length() > 0) sb.append(' ');
                        sb.append(acc);
                    }
                }
            }
        }
        System.out.println(sb);
        System.out.println(N + " " + OC + " " + Ho + " " + Wo);
    }
}
        题目内容
卷积(Convolution)是计算视觉中常用的计算算子,广泛应用于图像分类、检测、跟踪等多领域。
如下图所示,以 2 个三维张量卷积计算为例,取输入张量 $X$ 为通道数 $C$、高度 $H$、宽度 $W$,卷积核 $K$ 为通道数 $C$、高度 $K_h$、宽度 $K_w$,二者执行卷积计算要求其通道数相同。
当取卷积计算步长 $stride=1$,填充 $padding=0$,膨胀 $dilation=1$,无偏置项(bias)时,卷积核 $K$ 在输入张量 $X$ 上从左至右,从上至下滑动,分别与滑窗所重叠的输入张量 $X$ 切片,逐元素相乘求和后,得到输出张量 $Y$ 的各元素。
例如:
$y_{0,0}=x_{0,0,0}k_{0,0,0}+x_{0,0,1}k_{0,0,1}+x_{0,1,0}k_{0,1,0}+x_{0,1,1}k_{0,1,1}+x_{1,0,1}k_{1,0,1}+x_{1,1,0}k_{1,1,0}+x_{1,1,1}k_{1,1,1}+x_{2,0,0}k_{2,0,0}+x_{2,1,0}k_{2,1,0}+x_{2,1,1}k_{2,1,1}=72$
面向不同的应用需求,卷积存在多类变种。分组卷积(GroupConvolution)即是随2012年AlexNet提出的一种变种,其将输入张量和卷积核分组后,分别执行卷积计算,然后把多个输出张量进行融合。
例如,输入张量尺寸为 1×32×32×32(其中首个维度 1 为样本数),卷积核尺寸为 4×16×3×3(其中首个维度 4 为输出张量通道数,亦可理解为卷积核个数),下图为分组数 $groups=2$ 时分组卷积计算。
输入张量被切分为1×16×32×32 的两组张量,卷积核被切分为2×16×3×3 的两组张量,分组进行无 padding的卷积计算后,将两组尺寸为1×2×30×30 的计算结果,在第2 个维度拼接,得到尺寸为1×4×30×30 的输出张量。
请不使用PyTorch、MindSpore、PaddlePaddle 等AI 框架,使用编程语言原生库,编写一个支持分组卷积和深度卷积前向传播的函数,根据输入张量、卷积核、分组数,计算输出张量。
输入描述
- 
in_data: 4 维输入张量展开后的数据序列,以空格分隔的正整数;
 - 
in_shape: 4 维输入张量的形状,以空格分隔的 4 个正整数,依次为
- batch size(样本数)
 - in_channels(输入张量通道数)
 - height(高度)
 - width(宽度)
 
 - 
kernel_data: 4 维卷积核展开后的数据序列,以空格分隔的整数;
 - 
kernel_shape: 4 维卷积核的形状,以空格分隔的 4 个正整数,依次为
- out_channels(输出张量通道数)
 - k_channels(卷积核通道数)
 - kernel_h(卷积核高度)
 - kernel_w(卷积核宽度)
 
 - 
groups: 分组数,需满足
 
in_channels % groups = 0,out_channels % groups = 0,k_channels = in_channels / groups
输出描述
- out_data: 4 维输出张量展开后的数据序列,以空格分隔的整数;
 - out_shape: 4 维输出张量的形状,以空格分隔的4个正整数,依次为
- batch_size(样本数)
 - out_channels(输出张量通道数)
 - height(高度)
 - width(宽度)
 
 
若输入张量和卷积核的形状与 group的取值存在冲突,或出现其它取值冲突导致无法执行卷积计算,则 out_data 和 out_shape 均返回 −1。
样例1
输入
1 2 3 4 5 6 7 8
1 2 2 2
1 0 0 1 -1 0 0 -1
2 1 2 2
2
输出
5 -13
1 2 1 1
说明
输入张量为:
$\left[\left[\begin{array}{ll} 1 & 2 \\ 3 & 4 \end{array}\right],\left[\begin{array}{ll} 5 & 6 \\ 7 & 8 \end{array}\right]\right]$
输入张量形状为1×2×2×2 ;
卷积核为:
$\left[\left[\begin{array}{ll} 1 & 0 \\ 0 & 1 \end{array}\right],\left[\begin{array}{cc} -1 & 0 \\ 0 & -1 \end{array}\right]\right]$
卷积核形状为2×1×2×2 ;
分组数为2 ,输出张量为:
[[5], [−13]]
输出张量形状为1×2×1×1 。
样例2
输入
1 2 3 4 5 6 7 8 9
1 1 3 3
1 0 0 -1
1 1 2 2
2
输出
-1
-1
说明
解释:
由于 in_channels=1、out_channels=1、groups=2,不满足
in_channels%groups=0 ,out_channels%groups=0,
的条件,因此 out_data 和 out_shape 均返回 -1。