博士生首次利用双Transformer构建GAN
最近,计算机视觉领域的研究者对 Transformer 的关注不断增加,并取得了显著进展。这表明,Transformer 可能成为视觉任务中一种强大的通用模型。
我们都很想知道:在计算机视觉的广阔领域,Transformer 的潜力究竟有多大?在更具挑战性的任务,比如生成对抗网络(GAN)中,Transformer 的表现又如何呢?
在这种探究的推动下,德克萨斯大学奥斯汀分校的 YiFan Jiang、Zhangyang Wang,以及 IBM Research 的 Shiyu Chang 等研究者进行了开创性的实验,构建了一个纯粹基于 Transformer 架构且不使用卷积的 GAN,命名为 TRanSGAN。与其他基于 Transformer 的视觉模型相比,单纯利用 Transformer 构建 GAN 的难度更大,因为真实图像生成的要求高于分类等任务,并且 GAN 的训练本身存在不稳定性的问题。
从结构上来看,TRanSGAN 包含两个部分:一个是内存友好的基于 Transformer 的生成器,能够逐步提升特征分辨率并降低嵌入维度;另一个是基于 Transformer 的 patch 级判别器。
研究者发现,TRanSGAN 的性能显著受益于数据增强、多任务协同训练策略以及强调自然图像邻域平滑的局部初始化自注意力。这些发现表明,TRanSGAN 可以有效扩展到更大的模型和高分辨率的图像数据集。
实验结果显示,与当前基于卷积骨干的最先进 GAN 相比,TRanSGAN 的表现极具竞争力。具体而言,TRanSGAN 在 STL-10 数据集上的 IS 评分为 10.10,FID 为 25.32,创造了新的最先进水平。
这项研究表明,GAN 并不一定依赖于卷积骨干和许多专用模块,纯粹的 Transformer 也具备生成图像的足够能力。
然而,部分研究者也表达了担忧:在 Transformer 技术广泛应用的背景下,小型实验室将如何生存?
如果 Transformer 真正成为社区的“必需品”,如何提升这类架构的计算效率将成为亟待解决的研究难题。
作为基础模块的 Transformer 编码器
研究者决定将 Transformer 编码器作为基础结构,并尽量减少改动。编码器由两个部分组成,第一个部分是多头自注意力模块,第二个部分是带有 GELU 非线性的前馈 MLP。此外,研究者在两个部分之间应用了层归一化,并使用了残差连接。
内存友好的生成器
在自然语言处理(NLP)中,Transformer 通常将每个词作为输入。然而,如果按照类似的方式通过堆叠 Transformer 编码器逐像素生成图像,低分辨率图像可能导致长序列和更高的自注意力开销。
因此,为了避免过高的开销,研究者受到基于卷积的 GAN 设计理念的启发,采用多个阶段逐步提升分辨率。他们的策略是逐步增加输入序列并降低嵌入维度。
如下图 1 左所示,研究者提出了一个具有多个阶段的内存友好的基于 Transformer 的生成器:
每个阶段堆叠了多个编码器块。通过分段设计,研究者逐步提高特征图的分辨率,直至达到目标分辨率。具体而言,该生成器以随机噪声作为输入,并通过一个 MLP 将随机噪声转换为长度为 H×W×C 的向量。该向量随后变形为分辨率为 H×W 的特征图,每个点为 C 维嵌入。接着,该特征图被视为长度为 64 的 C 维 Token 序列,并与可学习的位置编码结合。
与 BERT 类似,该研究提出的 Transformer 编码器以嵌入 Token 作为输入,并递归计算每个 Token 之间的关系。为了合成更高分辨率的图像,研究者在每个阶段后插入了一个由 ReshaPING 和 pixelshuffle 模块组成的上采样模块。
具体操作中,上采样模块首先将 1D 序列的 Token 嵌入变形为 2D 特征图,然后通过 pixelshuffle 模块对该特征图进行上采样,并降低嵌入维度,最终得到输出。
然后,2D 特征图 X’_0 再次变形为嵌入 Token 的 1D 序列,其中 Token 数为 4HW,嵌入维度为 C/4。因此,在每个阶段,分辨率提升至两倍,嵌入维度 C 减少至输入的四分之一。这一权衡策略有效缓和了内存和计算需求的增长。
研究者在多个阶段重复上述流程,直到分辨率达到预期。
用于判别器的 Tokenized 输入
与需要精确合成每个像素的生成器不同,该研究提出的判别器只需分辨图像的真假。这使得研究者可以在语义上将输入图像 Tokenize 为更粗糙的 patch 级别。
如上图 1 右所示,判别器以图像的 patch 作为输入。研究者将输入图像分解为 8×8 个 patch,每个 patch 可视为一个“词”。然后,这 64 个 patch 通过线性 flatten 层转化为 Token 嵌入的 1D 序列,其中 Token 数 N = 8×8 = 64,嵌入维度为 C。接下来,研究者在 1D 序列的开头添加了可学习的位置编码和一个 [cls] Token。在经过 Transformer 编码器后,分类头仅使用 [cls] Token 输出真假预测。
研究者在 CIFAR-10 数据集上对 TRanSGAN 和最近的基于卷积的 GAN 进行了比较,结果如下表 5 所示:
根据表 5,TRanSGAN 在 IS 评分上超越了 AutoGAN,并在许多竞争者中表现更佳,仅次于 ProgReSSive GAN 和 styleGAN v2。
在 FID 结果对比中,研究发现,TRanSGAN 甚至优于 ProgReSSive GAN,略低于 styleGAN v2。CIFAR-10 上生成的可视化示例如下图 4 所示:
研究者将 TRanSGAN 应用于另一个流行的 48×48 分辨率基准 STL-10。为了适应目标分辨率,该研究将第一阶段的输入特征图从(8×8)=64 增加到(12×12)=144,然后将 TRanSGAN-XL 与自动搜索的 ConvNets 和手工设计的 ConvNets 进行了比较,结果如下表 6 所示:
与 CIFAR-10 的结果不同,该研究发现 TRanSGAN 优于所有现有模型,并在 IS 和 FID 得分方面达到了新的最先进性能。
由于 TRanSGAN 在标准基准 CIFAR-10 和 STL-10 上取得了良好表现,研究者将其应用于更具挑战性的 CelebA 64×64 数据集,结果如下表 10 所示:
TRanSGAN-XL 的 FID 评分为 12.23,表明其适用于高分辨率任务。可视化结果如图 4 所示。
尽管 TRanSGAN 已取得良好成绩,但与最优秀的手工设计 GAN 相比,仍存在较大改进空间。在论文的最后,作者指出了以下几个具体改进方向:
对生成器和判别器进行更复杂的 Token 化操作。使用代理任务预训练 Transformer,这可能会改善现有的多任务协同训练策略。探索更强大的注意力机制。开发更高效的自注意力形式,以提升模型效率并降低内存开销,从而支持生成更高分辨率的图像。
研究一作 YiFan Jiang 是德克萨斯大学奥斯汀分校电子与计算机工程系的一年级博士生,毕业于华中科技大学,研究兴趣集中于计算机视觉和深度学习等领域。目前,他主要从事神经架构搜索、视频理解和高级表征学习的研究,师从该系助理教授 Zhangyang Wang。
在本科期间,YiFan Jiang 曾在字节跳动 AI Lab 实习,今年夏天将进入 Google Research 实习。
一作主页:
