在计算机视觉领域,卷积神经网络(CNN)一直是主流技术。然而,越来越多的研究者开始探索将自然语言处理(NLP)领域的 Transformer 架构应用于视觉任务,其中一些研究已取得显著成果。最近,一篇匿名的 ICLR 2021 投稿论文将标准 Transformer 模型直接应用于图像处理,提出了一种新的 Vision Transformer 模型,并在多项图像识别基准测试中表现出接近甚至超越当前最优方法的性能。
在 10 月 2 日,深度学习领域顶级会议 ICLR 2021 的论文投稿结束,其中一篇将 Transformer 应用于图像识别的研究引起了广泛关注。
特斯拉 AI 的负责人 Andrej Karpathy 转发了这篇论文,并表示「乐于看到计算机视觉与 NLP 领域的进一步融合」。

之前,Facebook曾将 Transformer 应用于目标检测任务,OpenAI也尝试使用 GPT-2 进行图像分类。这篇研究在「跨界」应用上又带来了哪些新的尝试呢?
虽然 Transformer 架构在自然语言处理任务中已被广泛应用,但在计算机视觉领域仍面临一定的障碍。通常情况下,注意力机制要么与卷积网络结合使用,要么取代卷积网络的某些组件,而整体架构保持不变。
这项研究表明,依赖 CNN 并非必需。当 Transformer 直接应用于图像块序列时,其在图像分类任务中的表现同样出色。研究团队基于大量数据进行模型的预训练,并将其迁移至多个图像识别基准数据集(如 ImageNet、CIFAR-100、VTAB 等),结果显示,Vision Transformer(ViT)模型的性能可与当前最优的卷积网络相媲美,同时所需的计算资源显著减少。
NLP 领域的 Transformer 与计算机视觉领域的 CNN 比较
基于自注意力机制的架构,尤其是 Transformer,已成为 NLP 领域的主流模型。这种方法通过在大型文本语料库上进行预训练,然后针对较小的特定任务数据集进行微调,展示了优越的性能。得益于 Transformer 的计算效率和可扩展性,研究者们甚至能够训练出参数超过 100B 的模型。随着模型和数据集规模的增加,其性能仍未达到饱和状态。
然而,在计算机视觉领域,卷积架构依然主导。受到 NLP 成功的启发,许多计算机视觉研究尝试将 CNN 架构与自注意力结合,甚至完全替代卷积。尽管理论上有效,由于使用了特定的注意力模式,这些方法尚未在现代硬件加速器上得到有效扩展。因此,在大规模图像识别任务中,传统的 ResNet 架构依然保持领先。
Transformer 在视觉领域的跨界融合
受 NLP 领域中 Transformer 成功缩放的启发,研究团队尝试将标准 Transformer 直接应用于图像,并尽量减少修改。他们将图像分割成多个图像块,并将这些图像块的线性嵌入序列作为 Transformer 的输入。然后,研究团队用 NLP 中处理 token 的方式来处理图像块,并以监督方式训练图像分类模型。
在中等规模的数据集(如 ImageNet)上进行训练时,这种模型的表现并不理想,准确率比同等规模的 ResNet 低几个百分点。这个看似令人失望的结果是可以预见的:Transformer 缺乏一些 CNN 固有的归纳偏置,如平移不变性和局部性,因此在数据量不足时,其泛化能力受到限制。
但是,当在大型数据集(14M-300M 张图像)上训练模型时,情况大为不同。研究发现,大规模训练的优势超过了归纳偏置。在足够的数据规模上进行预训练后,Transformer 可以在少量任务数据上取得优异的结果。
研究提出的 Vision Transformer 在 JFT-300M 数据集上进行预训练,结果在多个图像识别基准上接近或超过了 SOTA 水平:在 ImageNet 上达到了 88.36% 的准确率,在 ImageNet ReaL 上达到了 90.77% 的准确率,在 CIFAR-100 上达到了 94.55% 的准确率,并在 VTAB 基准的 19 个任务中达到了 77.16% 的准确率。
模型与方法
研究团队尽量遵循原始 Transformer 的设计。这种简单的设置具有可扩展性,使得 NLP Transformer 架构及其高效实现几乎可以开箱即用。研究者旨在证明,在适当地扩展的情况下,该方法足以超越当前最佳的卷积神经网络。
Vision Transformer(ViT)
该研究提出的 Vision Transformer 架构遵循原始 Transformer 的设计。下图 1 展示了模型的架构图。
标准 Transformer 接收 1D 序列的 token 嵌入作为输入。为了处理 2D 图像,研究团队将图像 x ∈ R^H×W×C 转换为一系列扁平化的 2D patch x_p ∈ R^N×(P^2·C),其中 (H, W) 表示原始图像的分辨率,(P, P) 表示每个图像 patch 的分辨率。N = HW/P^2 成为 Vision Transformer 的有效序列长度。
Vision Transformer 在所有层使用相同的宽度,因此一个可训练的线性投影将每个向量化的 patch 映射到模型的维度 D(公式 1),相应的输出称为 patch 嵌入。

与 BERT 的 [class] token 类似,研究团队在一系列嵌入 patch(z_0^0 = x_class)之前添加了一个可学习的嵌入,这个嵌入在 Transformer 编码器(z_0^L)输出中的状态可以作为图像表示 y(公式 4)。在预训练和微调阶段,分类头(head)附加在 z_L^0 之后。
位置嵌入被加到 patch 嵌入中,以保留位置信息。研究者尝试了不同的 2D 感知位置嵌入变体,但与标准的 1D 位置嵌入相比,未能取得显著提升。因此,编码器以联合嵌入作为输入。
Transformer 编码器由多个交互层的多头自注意力(MSA)和 MLP 块组成(公式 2、3)。每个块之前施加 LayerNorm(LN),而残差连接则在每个块之后应用。MLP 包含两个施加了 GELU 非线性的层。

作为将图像分割成 patch 的一种替代方案,输出序列可以通过 ResNet 的中间特征图来形成。在这个混合模型中,patch 嵌入投影(公式 1)被早期阶段的 ResNet 取代。ResNet 的一个中间 2D 特征图被扁平化为一个序列,映射到 Transformer 的维度,然后作为 Transformer 的输入序列。最后,像前面所述的那样,将分类输入嵌入和位置嵌入添加到 Transformer 输入中。
微调与更高分辨率
研究者在大型数据集上预训练 ViT 模型,并针对较小规模的下游任务进行微调。为此,研究者移除了预训练的预测头,并添加了一个零初始化的 D×K 前馈层,其中 K 表示下游类的数量。与预训练阶段相比,在更高分辨率时进行微调通常更为有效。当输入更高分辨率的图像时,研究者保持 patch 大小不变,从而获得更大的有效序列长度。
ViT 模型可以处理任意序列长度(受内存限制),但预训练的位置信息可能不再适用。因此,研究者根据预训练位置嵌入在原始图像中的位置,对其进行 2D 插值。需要注意的是,只有在分辨率调整和 patch 提取时,才能将 2D 图像的归纳偏置手动注入到 ViT 模型中。
实验
该研究进行了大量实验,并使用了多个 ViT 模型变体,详见下表 1:

与 SOTA 模型的性能对比
研究团队首先将最大的 ViT 模型(在 JFT-300M 数据集上预训练的 ViT-H/14 和 ViT-L/16)与 SOTA CNN 模型进行对比,结果详见下表 2。

表 2:ViT 模型与 SOTA 模型在流行图像分类基准数据集上的性能对比。
从上表中可以看出,规模较小的 ViT-L/16 模型在所有数据集上的性能与 BIT-L 相当,且所需的计算资源显著更少。更大的 ViT-H/14 模型进一步提升了性能,尤其在更具挑战性的数据集上,如 ImageNet、CIFAR-100 和 VTAB。ViT-H/14 模型在所有数据集上的性能匹配或超越 SOTA,甚至在某些情况下大幅超越 SOTA 模型(如在 CIFAR-100 数据集上高出 1%)。在 ImageNet 数据集上,ViT 模型的性能比 Noisy Student 低约 0.1%,但在具有更干净标签的 ImageNet ReaL 数据集上,ViT 的表现超过了 SOTA 模型。
下图 2 将 VTAB 任务分解为多个组,并对比了 ViT 与 SOTA 方法的性能,这些方法包括 BIT、VIVI 和 S4L。
在 Natural 任务中,ViT-H/14 的性能略低于 BIT-R152x4;在 Specialized 任务中,ViT 的性能超过 BIT 等方法;而在 Structured 任务中,ViT 显著优于其他方法。

预训练数据要求
Vision Transformer 在大型 JFT-300M 数据集上经过预训练后表现出色。在 ViT 的归纳偏置低于 ResNet 的情况下,数据集规模的重要性如何呢?该研究进行了相关实验。
研究者首先在逐渐增大规模的数据集(ImageNet、ImageNet-21k 和 JFT300M)上预训练 ViT 模型。下图 3 展示了模型在 ImageNet 数据集上的性能:

下表 3 展示了模型在 ImageNet、ImageNet-21k 和 JFT300M 数据集上的性能对比情况。在前两个较小的数据集上,ViT-LaRge 模型的表现不如 ViT-Base,但在更大规模的 JFT300M 数据集上,大模型则展现出明显优势。这表明,随着数据集规模的增加,较大的 ViT 模型变体优于较小模型。
