MindSpore 25-Day Learning Camp, Day 15 | Vision Transformer Image Classification
Overview:
In recent years, models built on the self-attention mechanism, and the Transformer in particular, have greatly advanced natural language processing. Thanks to their computational efficiency and scalability, Transformers have made it possible to train models of unprecedented size, with more than 100B parameters.
ViT is the product of bringing these ideas from natural language processing into computer vision: without relying on convolution operations, it still achieves strong results on image classification.
Model features
The ViT model is used mainly for image classification. Compared with a conventional Transformer, its architecture has the following characteristics; a quick shape check follows the list:
- The input image is divided into multiple patches; each two-dimensional patch (ignoring the channel dimension) is flattened into a one-dimensional vector, and a class token and position embeddings are added to form the model input.
- The main Block follows the Transformer Encoder structure, but the position of Normalization is adjusted (it is applied before, not after, each sub-layer); the core component is still Multi-head Attention.
- After the stacked Blocks comes a fully connected layer that takes the class token's output as its input and performs the classification. This final fully connected layer is usually called the Head, and the Transformer Encoder part the backbone.
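As a quick sanity check of the shapes just described, here is the token arithmetic for the ViT-B/16 configuration used throughout this post (all numbers follow from the settings in the code below):

```python
# Token bookkeeping for ViT-B/16.
image_size, patch_size, embed_dim = 224, 16, 768

num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                      # plus one class token -> 197
patch_values = patch_size * patch_size * 3     # 16 * 16 * 3 = 768 pixel values per patch,
                                               # which happens to equal embed_dim
print(num_patches, seq_len, patch_values)      # 196 197 768
```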
Below, we walk through the code for an ImageNet classification task built on ViT.
Walkthrough:
1. Imports
```python
import os
import pathlib
from enum import Enum
from typing import Optional

import cv2
import numpy as np
from PIL import Image
from scipy import io

import mindspore as ms
from mindspore import nn, ops, train, Parameter
from mindspore.common.initializer import Normal, initializer
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from download import download
```
2. Download the data
```python
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)
```
3. Load the data
```python
data_path = './dataset/'
# The normalization statistics are given in the 0-255 pixel range, hence the * 255 scaling.
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
```
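To confirm that the pipeline produces what the model expects, we can peek at one batch (a quick check of my own; the tuple order follows ImageFolderDataset's default columns):

```python
# Fetch a single batch to verify the output of the augmentation pipeline.
image, label = next(dataset_train.create_tuple_iterator())
print(image.shape, label.shape)  # expected: (16, 3, 224, 224) (16,)
```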
4. Multi-Head Attention
```python
class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0 - attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0 - keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        # Project to Q, K, V in a single pass: (b, n, 3 * c).
        qkv = self.qkv(x)
        # Split heads: (3, b, num_heads, n, c // num_heads).
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)

        # Scaled dot-product attention.
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge heads back to (b, n, c).
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out
```
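As a quick smoke test (my own addition; the shapes are arbitrary), attention should preserve the (batch, sequence, dim) shape:

```python
import numpy as np
import mindspore as ms

# 2 samples, 197 tokens (196 patches + 1 class token), embedding dim 768, 8 heads.
dummy = ms.Tensor(np.random.randn(2, 197, 768), ms.float32)
attention = Attention(dim=768, num_heads=8)
print(attention(dummy).shape)  # (2, 197, 768): shape is preserved
```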
5. Encoder
```python
class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x


class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x
```
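ResidualCell is what lets the encoder express the pre-norm pattern x + f(norm(x)) declaratively. A minimal sketch of one such branch (shapes are illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

# One pre-norm residual branch: x + FeedForward(LayerNorm(x)).
branch = ResidualCell(nn.SequentialCell([nn.LayerNorm((768,)),
                                         FeedForward(in_features=768, hidden_features=3072)]))
x = ms.Tensor(np.random.randn(2, 197, 768), ms.float32)
print(branch(x).shape)  # (2, 197, 768): a residual branch must keep the input shape
```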
6. Model
```python
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            # Pre-norm residual blocks: LayerNorm runs before Attention / FeedForward,
            # unlike the post-norm ordering of the original Transformer.
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)
```
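Stacking the blocks changes only the representation, never the tensor shape. A two-layer encoder (depth reduced purely to keep the check fast) behaves the same way:

```python
import numpy as np
import mindspore as ms

encoder = TransformerEncoder(dim=768, num_layers=2, num_heads=12, mlp_dim=3072)
x = ms.Tensor(np.random.randn(2, 197, 768), ms.float32)
print(encoder(x).shape)  # (2, 197, 768): same shape in, same shape out
```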
7. ViT model inputs
A conventional Transformer is designed for word embeddings (word vectors) from natural language. The main difference between word vectors and image data is that word vectors are one-dimensional vectors stacked together, whereas an image is a stack of two-dimensional matrices. When multi-head attention processes a stack of one-dimensional word vectors, it extracts the relationships between them, i.e. the contextual semantics; this is what makes the Transformer so effective in natural language processing. How to convert a two-dimensional image matrix into something like one-dimensional word vectors therefore became the small hurdle the Transformer had to clear before entering image processing.
In the ViT model:
- The input image is divided, on each channel, into patches of 16 x 16 pixels. This step is performed by a convolution; it could also be done by hand, but the convolution achieves the same goal while also applying a learned projection in the same pass. For example, a 224 x 224 input image processed by the convolution (kernel size 16, stride 16) yields a 14 x 14 grid of patches, i.e. (224 / 16)² = 196 patches in total.
- Each patch matrix is then flattened into a one-dimensional vector, which approximates a stack of word vectors. The 14 x 14 patch grid from the previous step thus becomes a sequence of 196 vectors.
```python
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size cuts the image into
        # non-overlapping patches and projects each one to embed_dim in one step.
        self.conv = nn.Conv2d(input_channels,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size,
                              has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        # Flatten the 14 x 14 patch grid into a sequence of h * w = 196 tokens.
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x
```
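Feeding a dummy 224 x 224 image through PatchEmbedding confirms the numbers from the list above (196 patch tokens, each of dimension 768):

```python
import numpy as np
import mindspore as ms

patch_embedding = PatchEmbedding(image_size=224, patch_size=16,
                                 embed_dim=768, input_channels=3)
img = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
print(patch_embedding.num_patches)  # 196
print(patch_embedding(img).shape)   # (1, 196, 768)
```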
After the image has been divided into patches, two further steps follow: pos_embedding and class_embedding.
- class_embedding borrows the idea BERT uses for text classification: a class value is added in front of each sequence of word vectors, usually in the first position. After class_embedding is prepended, the sequence of 196 vectors from the previous step becomes a sequence of 197.
- The added class_embedding is a learnable parameter. Over the course of training, the output at the first position of the sequence comes to determine the final class: for classification, only this single class token output is used (the first of the 197 output vectors), not the 196 patch outputs.
- pos_embedding is also a set of learnable parameters, which is added to the sequence of processed patch vectors.
- Because pos_embedding is learnable, its effect is comparable to the bias in a fully connected or convolutional layer. This step creates a trainable embedding of sequence length 197 and adds it to the sequence that already carries class_embedding.
In fact, the paper considers four pos_embedding schemes in total. The authors' ablations show that only the presence or absence of pos_embedding makes a clear difference; whether pos_embedding is one-dimensional or two-dimensional has little effect on the classification result. Our code therefore uses a one-dimensional pos_embedding. Because class_embedding is prepended before pos_embedding is added, the sequence length of pos_embedding is the number of flattened patches plus one (197).
Overall, ViT exploits the Transformer's strength at modeling contextual semantics: it converts the image into a kind of "variant word embedding" and processes it as such. What makes this conversion meaningful is that the patches carry spatial relationships with one another, a kind of "spatial semantics", and attending over them captures that structure well.
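The following sketch (using zero placeholders in place of the learnable parameters) walks through the bookkeeping just described: prepend the class token, then add the position embedding:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

b, n, dim = 2, 196, 768
patches = ms.Tensor(np.random.randn(b, n, dim), ms.float32)       # patch tokens
cls_token = ms.Tensor(np.zeros((1, 1, dim)), ms.float32)          # learnable in the real model
pos_embedding = ms.Tensor(np.zeros((1, n + 1, dim)), ms.float32)  # length 196 + 1 = 197

cls_tokens = ops.tile(cls_token, (b, 1, 1))    # repeat the class token across the batch
x = ops.concat((cls_tokens, patches), axis=1)  # (2, 197, 768): class token in position 0
x = x + pos_embedding                          # broadcasts over the batch dimension
print(x.shape)                                 # (2, 197, 768)
```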
8. Building the full ViT
```python
def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 1000) -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        # Learnable class token and position embedding (sequence length: patches + 1).
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0 - keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        # Take only the class token's output for classification.
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)
        return x
```
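An end-to-end forward pass with a shallow variant (num_layers=2, purely to keep the check fast) shows the shapes from image to logits:

```python
import numpy as np
import mindspore as ms

net = ViT(num_layers=2)  # shallow variant just for a shape check
img = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
print(net(img).shape)    # (1, 1000): one logit per ImageNet class
```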
9. Model training
Before training starts, we need to set up the loss function, the optimizer, and the callbacks.
Training the full ViT model takes a long time; in practice, adjust epoch_size to the needs of your project. As long as the per-step information for each epoch is printed normally, training is in progress, and the model output reports metrics such as the current loss value and the elapsed time.
```python
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)
```
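To see what CrossEntropySmooth actually does to the targets, the smoothing arithmetic for smooth_factor=0.1 and 1000 classes works out as follows:

```python
# Label smoothing: the true class gets 1 - smooth_factor,
# the rest of the probability mass is spread evenly over the other 999 classes.
smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                 # 0.9 for the true class
off_value = smooth_factor / (num_classes - 1)  # ~0.0001 for every other class

print(on_value, off_value)
print(on_value + (num_classes - 1) * off_value)  # 1.0: still a valid distribution
```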
10. Model evaluation
```python
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# construct model
network = ViT()

# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
                        metrics=eval_metrics, amp_level="O0")

# evaluate model
result = model.eval(dataset_val)
print(result)
```
11. Model inference
```python
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
```
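With the dataset ready, a minimal inference loop (my own sketch, reusing the `network` with the checkpoint loaded above) takes the argmax of the logits as the predicted class index:

```python
import numpy as np
import mindspore as ms

# Run the trained network over the inference set, one image at a time.
network.set_train(False)
for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = ms.Tensor(data["image"])
    logits = network(image)
    pred = np.argmax(logits.asnumpy(), axis=1)[0]
    print(f"sample {i}: predicted class index {pred}")
```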
12. Show the results
```python
class Color(Enum):
    """Define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    raise TypeError(f'Invalid type for color: {type(color)}')
```
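With these utilities in place, a hypothetical show_result-style helper (my own sketch; the name, signature, and text placement are assumptions, not the tutorial's exact code) could draw the predicted label onto an image with OpenCV:

```python
# Hypothetical helper: overlay a prediction string on an image and save the result.
def show_result(img_path: str,
                text: str,
                out_path: str = "./result.png",
                color: str = "green") -> None:
    """Draw `text` on the image at `img_path` and write it to `out_path`."""
    check_file_exist(img_path)
    img = cv2.imread(img_path)
    cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                0.8, color_val(color), 2)
    cv2.imwrite(out_path, img)
```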