Patch Embedding 的本质是一个分块 + 展开 + 线性变换的过程,可以理解为对图像做一次“卷积核为 patch,步长为 patch_size 的卷积 + reshape”,再加上一个 cls token。这里我们只需要计算输出向量序列的维度,而不是具体做矩阵运算。
设输入参数为:
img_size(高 = 宽)patch_sizechannelVision Transformer(ViT) 是视觉领域应用非常广泛的基础网络结构,经典的 ViT 结构如图所示,
其包含了 Patch&Position embedding、Transformer Encoder 等多个关键模块组成。这几个模块中,将图像分割为固定大小的 patch 并进行线性嵌入是一个关键步骤,也即 Patch Embedding 层,其主要实现步骤为:
Step 1:将输入图像分割为多个非重叠的 patch ,也即将图片切分为 N∗N 个 patch ,如 3∗3 个 2D 图像块;
Step 2:将每个 patch 展平为向量,也即将每个切分后的 2D Patch 展平为 1D 向量;
Step 3:对展平的 patch 进行线性变换(嵌入),也即对每个展平后的 1D 向量做一个线性变换,使用一个可学习的权重矩阵 E 和 偏置向量 B 进行线性变换,公式为:Z=X∗E+b
Step 4:添加可学习的位置编码;
请根据以上提示步骤,实现 Patch Embedding 层。
特别注意:本实现过程中,无法使用深度学习框架,如 pytorch、tensorflow 等
输入参数包括:imp_size、patch_size、channel、embedding_dim,分别表示:
图像尺寸(图像长、宽默认相等)img_size ;
patch 大小 patch_size ;
图像通道数 channels ;
嵌入维度 embedding_dim
输出 patch_embedding 后的维度信息 embedding_shape,其中需要包含 cis token,具体可见样例。
输入
448 32 3 384
输出
197 384
说明
输入:448 32 3 384
分别表示:
图像尺寸(图像长、宽默认相等)img_size=448 ;
patch 大小 patch_size=32 ;
图像通道数 channels=3 ;
嵌入维度 embedding_dim=384
输出:197 384
分别表示:
经过 patch_embedding 层后得到的 embedding_shape ,其中第一维 197 表示 patch token+cis token ,第二维表示 patch_embedding 后的 enbedding 维度
输入
224 16 3 768
输出
197 768
说明
输入:224 16 3 768
分别表示:
图像尺寸(图像长、宽默认相等)img_size=224 ;
patch 大小 patch_size=16 ;
图像通道数 channels=3 ;
嵌入维度 embedding_dim=768
输出:197 768
分别表示:
经过 patch_embedding 层后得到的 embedding_shape ,其中第一维 197 表示 patch token+cis token ,第二维表示 patch_embedding 后的 enbedding 维度