#P4481. 第2题-Vision Transformer中的Patch Embdding层实现
-
1000ms
Tried: 4
Accepted: 4
Difficulty: 5
所属公司 :
华为
时间 :2025年11月20日-AI方向
-
算法标签>机器学习算法
第2题-Vision Transformer中的Patch Embdding层实现
解题思路
Patch Embedding 的本质是一个分块 + 展开 + 线性变换的过程,可以理解为对图像做一次“卷积核为 patch,步长为 patch_size 的卷积 + reshape”,再加上一个 cls token。这里我们只需要计算输出向量序列的维度,而不是具体做矩阵运算。
设输入参数为:
- 图像尺寸:
img_size(高 = 宽) - patch 大小:
patch_size - 通道数:
channel - 嵌入维度:
embedding_dim
1. 计算 patch 的个数
图像被均匀切成大小为 patch_size × patch_size 的不重叠 patch:
-
每一维上的 patch 个数:
N=patch_sizeimg_size -
总 patch 数目:
$$\text{num\_patches} = N \times N = \left(\frac{\text{img\_size}}{\text{patch\_size}}\right)^2 $$
题目默认输入是合法的,所以可以认为 img_size 能整除 patch_size。
2. 展开并线性变换
每个 patch 的原始维度为:
$$\text{patch\_dim} = \text{patch\_size} \times \text{patch\_size} \times \text{channel} $$线性嵌入用一个权重矩阵 E 和偏置 b:
- E 的形状:patch_dim×embedding_dim
- b 的形状:embedding_dim
对每个 patch 展开后的向量 X 做线性变换:
Z=X×E+b因此每个 patch 最终变成一个长度为 embedding_dim 的向量。
所有 patch 经过线性变换后得到:
(num_patches, embedding_dim)3. 添加 CLS Token
ViT 会额外添加一个可学习的 cls token,其维度和单个 patch 的嵌入相同,为 (embedding_dim,)。
拼接到序列前面之后,序列长度变为:
因此最终的 patch embedding 输出维度(不含 batch 维)为:
$$\text{embedding\_shape} = (\text{num\_patches} + 1,\ \text{embedding\_dim}) $$4. 套用样例
样例输入:448 32 3 384
- 每边 patch 数:N=448/32=14
- 总 patch 数:num_patches=142=196
- 加 cls token:num_tokens=196+1=197
- 每个 token 维度:384
所以输出为:
embedding_shape=(197, 384)即:197 384
代码实现
Python
# 计算 Patch Embedding 输出维度的函数
def get_embedding_shape(img_size, patch_size, channel, embedding_dim):
# 每一维上的 patch 个数
num_per_dim = img_size // patch_size
# 总的 patch 个数
num_patches = num_per_dim * num_per_dim
# 加上一个 cls token
num_tokens = num_patches + 1
# 返回 (序列长度, 嵌入维度)
return num_tokens, embedding_dim
def main():
# 读取输入:img_size patch_size channel embedding_dim
img_size, patch_size, channel, embedding_dim = map(int, input().split())
# 调用函数计算结果
tokens, dim = get_embedding_shape(img_size, patch_size, channel, embedding_dim)
# 按题目要求输出
print(tokens, dim)
if __name__ == "__main__":
main()
Java
import java.util.Scanner;
public class Main {
// 计算 Patch Embedding 输出维度的函数
public static int[] getEmbeddingShape(int imgSize, int patchSize, int channel, int embeddingDim) {
// 每一维上的 patch 个数
int numPerDim = imgSize / patchSize;
// 总的 patch 个数
int numPatches = numPerDim * numPerDim;
// 加上一个 cls token
int numTokens = numPatches + 1;
// 返回 (序列长度, 嵌入维度)
return new int[]{numTokens, embeddingDim};
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// 读取输入:img_size patch_size channel embedding_dim
int imgSize = sc.nextInt();
int patchSize = sc.nextInt();
int channel = sc.nextInt();
int embeddingDim = sc.nextInt();
sc.close();
// 计算结果
int[] res = getEmbeddingShape(imgSize, patchSize, channel, embeddingDim);
// 按题目要求输出
System.out.println(res[0] + " " + res[1]);
}
}
C++
#include <iostream>
using namespace std;
// 计算 Patch Embedding 输出维度的函数
void getEmbeddingShape(int img_size, int patch_size, int channel, int embedding_dim,
int &num_tokens, int &dim_out) {
// 每一维上的 patch 个数
int num_per_dim = img_size / patch_size;
// 总的 patch 个数
int num_patches = num_per_dim * num_per_dim;
// 加上一个 cls token
num_tokens = num_patches + 1;
// 嵌入维度不变
dim_out = embedding_dim;
}
int main() {
int img_size, patch_size, channel, embedding_dim;
// 读取输入:img_size patch_size channel embedding_dim
cin >> img_size >> patch_size >> channel >> embedding_dim;
int tokens, dim;
// 调用函数计算输出维度
getEmbeddingShape(img_size, patch_size, channel, embedding_dim, tokens, dim);
// 按题目要求输出
cout << tokens << " " << dim << endl;
return 0;
}
题目内容
Vision Transformer(ViT) 是视觉领域应用非常广泛的基础网络结构,经典的 ViT 结构如图所示,
其包含了 Patch&Position embedding、Transformer Encoder 等多个关键模块组成。这几个模块中,将图像分割为固定大小的 patch 并进行线性嵌入是一个关键步骤,也即 Patch Embedding 层,其主要实现步骤为:
Step 1:将输入图像分割为多个非重叠的 patch ,也即将图片切分为 N∗N 个 patch ,如 3∗3 个 2D 图像块;
Step 2:将每个 patch 展平为向量,也即将每个切分后的 2D Patch 展平为 1D 向量;
Step 3:对展平的 patch 进行线性变换(嵌入),也即对每个展平后的 1D 向量做一个线性变换,使用一个可学习的权重矩阵 E 和 偏置向量 B 进行线性变换,公式为:Z=X∗E+b
Step 4:添加可学习的位置编码;
请根据以上提示步骤,实现 Patch Embedding 层。
特别注意:本实现过程中,无法使用深度学习框架,如 pytorch、tensorflow 等
输入描述
输入参数包括:imp_size、patch_size、channel、embedding_dim,分别表示:
图像尺寸(图像长、宽默认相等)img_size ;
patch 大小 patch_size ;
图像通道数 channels ;
嵌入维度 embedding_dim
输出描述
输出 patch_embedding 后的维度信息 embedding_shape,其中需要包含 cis token,具体可见样例。
样例1
输入
448 32 3 384
输出
197 384
说明
输入:448 32 3 384
分别表示:
图像尺寸(图像长、宽默认相等)img_size=448 ;
patch 大小 patch_size=32 ;
图像通道数 channels=3 ;
嵌入维度 embedding_dim=384
输出:197 384
分别表示:
经过 patch_embedding 层后得到的 embedding_shape ,其中第一维 197 表示 patch token+cis token ,第二维表示 patch_embedding 后的 enbedding 维度
样例2
输入
224 16 3 768
输出
197 768
说明
输入:224 16 3 768
分别表示:
图像尺寸(图像长、宽默认相等)img_size=224 ;
patch 大小 patch_size=16 ;
图像通道数 channels=3 ;
嵌入维度 embedding_dim=768
输出:197 768
分别表示:
经过 patch_embedding 层后得到的 embedding_shape ,其中第一维 197 表示 patch token+cis token ,第二维表示 patch_embedding 后的 enbedding 维度