Token Labeling(NeurIPS 2021, ByteDance)论文解读
paper:All Tokens Matter: Token Labeling for Training Better Vision Transformers
official implementation:https://github.com/zihangJiang/TokenLabeling
出发点
- ViTs的局限性:尽管ViTs在捕捉长距离依赖方面表现出色, 但它通常依赖于额外的可训练class token来计算分类损失,这可能会忽略其他patch token所包含的丰富局部信息。
- 局部信息的重要性:最近的工作表明,对图像分类任务而言,良好地建模和利用局部信息可以避免模型偏向偏颇和不可泛化的模式,从而显著提高模型性能。
创新点
本文提出了一种新的训练目标——token labeling,旨在利用所有的图像patch token进行密集的训练损失计算,而不仅仅依赖于额外的class token。通过这种方式,每个patch token都能获得由machine annotator生成的单独的、位置特定的监督,从而提升模型的性能。 具体包括:
- Token Labeling:提出了token labeling方法,通过对所有patch token进行位置特定的监督,提高了图像分类的准确性和对象识别能力。
- MixToken:改进了传统的CutMix数据增强方法,使其在token层面上操作,避免了图像patch中混合内容的问题,从而提高了模型的训练效果。
- Patch Embedding:对ViT的patch embedding模块进行了修改,采用了4层卷积层来更好地对输入图像进行token化和整合局部信息。
方法介绍
Token Labeling
在传统的ViT中,给定输入图片 \(I\),最后一个transformer block的输出可以表示为 \([X^{cls},X^1,...,X^N]\),其中 \(N\) 表示patch token的数量,\(X^{cls}\) 和 \(X^1,...,X^N\) 分别对应 class token和patch tokens。则图片 \(I\) 的分类损失可以按下式计算
其中 \(H(\cdot,\cdot)\) 是softmax cross-entropy loss,\(y^{cls}\) 是类别标签。
这种方式只采用了image-level的标签作为监督,而忽略了每个image patch中包含的丰富信息。因此本文提出了一种新的训练目标,token labeling,它利用了patch tokens和class token之间的互补信息。具体来说,作者认为每个输出token都应该和一个单独的、位置特定的label联系起来,因此token labeling的ground truth标签是一个 \(K\times N\) score map矩阵,表示为 \([y^1,...,y^N]\),其中 \(N\) 是patch token的数量,\(K\) 是类别数。
最终计算每个patch token和score map中对应的标签之间的交叉熵损失,如下
完整的损失包含原始的class token损失和token labeling损失,如下
图2是整个过程的一个直观展示
里dense score map是通过machine annotator离线得到的,可参考Re-labeling ImageNet(CVPR 2021, Naver)-CSDN博客。简单地说,machine annotator是一个在额外的大数据集(例如JFT-300M)上训练好的性能强大的分类模型,然后对ImageNet进行推理,去掉全连接层之前的全局平均池化,全连接层改为一个1x1卷积层,因此经过softmax后输出的是一个 \(H\times W\times C\) 的score map,这里的 \(H,W\) 是模型最后一层的分辨率大小,而不像传统的分类模型一样输出的是一个 \(1\times 1\times C\) 的向量。score map是提前计算得到并保存到本地的,在训练token labeling时,只需要加载score map并根据patch token的空间位置对应的在score map上crop和插值对齐空间坐标,然后再进行全局平均池化并经过softmax和argmax得到每个patch token对应的标签。和需要在线生成target的知识蒸馏不同,token labeling额外增加的计算量可以忽略不计。
MixToken
在训练分类模型时,数据增强方法例如MixUp和CutMix可以有效地提高模型性能。但vision transformer依赖patch-based tokenization来将输入图片映射为token序列,本文提出的token labeling也是基于每个patch的标签,如果我们直接在原始图像上应用CutMix,可能导致一个patch中包含来自两个图像的内容,如图3左所示。
这样token labeling很难为每个token分配一个干净正确的标签,因此本文提出了CutMix的变体MixToken,即在patch embedding层后得到的token上进行cutmix操作,如图3右所示,这样就保证了每个token只包含一张图片的内容。
实验结果
结合token labeling和mixtoken作者提出了LV-ViT,网络配置如下
两个components的消融实验如下表所示,可以看到单独将cutmix换成mixtoken精度提升了0.1%,而结合mixtoken和token labeling时,精度提升了0.9%。
和其它分类模型在ImageNet上的结果对比如下表所示