【Week-G4】手势图像生成(CGAN)- PyTorch【使用生成好的生成器生成指定手势】
文章目录
- 一、准备数据 & 环境
- 2. 数据可视化
- 二、构建模型
- 三、训练模型
- 3. 训练模型
- 四、模型分析
- 4.1 绘制Loss图
- 2. 导入generator_epoch_300.pth,生成图像
- 五、predictions.permute(0,2,3,1)
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
本文环境:
系统环境:
语言:Python3.7.8
编译器:VSCode
深度学习框架:torch 1.13.1
是否GPU训练:CPU
本周代码基本和G3周一致,个别地方有修改:
(1)epoch=300
(2)绘制Loss图
(3)选择训练好的epoch=300时的生成器模型来生成手势
(4)数据集可视化部分选择显示第一个批次的图片
训练持续时间:7.17日23:00-7月18日08:12,耗时9h12min
本次学习的重点:predictions.permute(0,2,3,1),该函数用于对张量的维度进行重新排列,生成指定图像就是靠这个函数,Pytorch框架才有的
一、准备数据 & 环境
2. 数据可视化
# 可视化第一个batch的数据 def show_images(dl): for images, _ in dl: fig, ax = plt.subplots(figsize=(10, 10)) ax.set_xticks([]); ax.set_yticks([]) ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0)) break show_images(train_loader) # def show_batch(dl): # for images, _ in dl: # show_images(images) # break # show_batch(train_loader) #显示图片 print("===============1.2 数据可视化 End============\n")
二、构建模型
print("===============2. 构建模型================") latent_dim = 100 n_classes = 3 embedding_dim = 100 # 自定义权重初始化函数,用于初始化生成器和判别器的权重 def weights_init(m): # 获取当前层的类名 classname = m.__class__.__name__ # 如果当前层是卷积层(类名中包含 'Conv' ) if classname.find('Conv') != -1: # 使用正态分布随机初始化权重,均值为0,标准差为0.02 torch.nn.init.normal_(m.weight, 0.0, 0.02) # 如果当前层是批归一化层(类名中包含 'BatchNorm' ) elif classname.find('BatchNorm') != -1: # 使用正态分布随机初始化权重,均值为1,标准差为0.02 torch.nn.init.normal_(m.weight, 1.0, 0.02) # 将偏置项初始化为全零 torch.nn.init.zeros_(m.bias) print("===============2. 构建模型 End============\n")
构建生成器与鉴别器的代码与G3一致【此处省略】
三、训练模型
定义损失函数、定义优化器与G3一致
训练时 epoch值改为300
3. 训练模型
print("===============3.3 训练模型================") # 设置训练的总轮数 num_epochs = 300 # 初始化用于存储每轮训练中判别器和生成器损失的列表 D_loss_plot, G_loss_plot = [], [] # 循环进行训练 for epoch in range(1, num_epochs + 1): # 初始化每轮训练中判别器和生成器损失的临时列表 D_loss_list, G_loss_list = [], [] # 遍历训练数据加载器中的数据 for index, (real_images, labels) in enumerate(train_loader): # 清空判别器的梯度缓存 D_optimizer.zero_grad() # 将真实图像数据和标签转移到GPU(如果可用) real_images = real_images.to(device) labels = labels.to(device) # 将标签的形状从一维向量转换为二维张量(用于后续计算) labels = labels.unsqueeze(1).long() # 创建真实目标和虚假目标的张量(用于判别器损失函数) real_target = Variable(torch.ones(real_images.size(0), 1).to(device)) fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device)) # 计算判别器对真实图像的损失 D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target) # 从噪声向量中生成假图像(生成器的输入) noise_vector = torch.randn(real_images.size(0), latent_dim, device=device) noise_vector = noise_vector.to(device) generated_image = generator((noise_vector, labels)) # 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图) output = discriminator((generated_image.detach(), labels)) D_fake_loss = discriminator_loss(output, fake_target) # 计算判别器总体损失(真实图像损失和假图像损失的平均值) D_total_loss = (D_real_loss + D_fake_loss) / 2 D_loss_list.append(D_total_loss) # 反向传播更新判别器的参数 D_total_loss.backward() D_optimizer.step() # 清空生成器的梯度缓存 G_optimizer.zero_grad() # 计算生成器的损失 G_loss = generator_loss(discriminator((generated_image, labels)), real_target) G_loss_list.append(G_loss) # 反向传播更新生成器的参数 G_loss.backward() G_optimizer.step() # 打印当前轮次的判别器和生成器的平均损失 print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % ( (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), torch.mean(torch.FloatTensor(G_loss_list)))) # 将当前轮次的判别器和生成器的平均损失保存到列表中 D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list))) G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list))) if epoch%10 == 0: # 将生成的假图像保存为图片文件 save_image(generated_image.data[:50], './GAN/G4/out_images/sample_%d' % epoch + '.png', nrow=5, normalize=True) # 将当前轮次的生成器和判别器的权重保存到文件 torch.save(generator.state_dict(), './GAN/G4/training_weights/generator_epoch_%d.pth' % (epoch)) torch.save(discriminator.state_dict(), './GAN/G4/training_weights/discriminator_epoch_%d.pth' % (epoch)) print("===============3.3 训练模型 End============\n")
训练过程:
训练保存的/out_images:
训练保存的/train_weights:
四、模型分析
4.1 绘制Loss图
print("===============4.1 Loss图============\n") G_loss_list = [i.item() for i in G_loss_list] D_loss_list = [i.item() for i in D_loss_list] import matplotlib.pyplot as plt import warnings warnings.filterwarnings("ignore") plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False plt.rcParams['figure.dpi'] = 100 plt.figure(figsize=(8,4)) plt.title("Generator and Descriminator Loss During Training") plt.plot(G_loss_list,label="G") plt.plot(D_loss_list,label="D") plt.xlabel("iteration") plt.ylabel("Loss") plt.legend() plt.savefig("./GAN/G4/G&D Loss Figure.png") plt.show()
【咋感觉这个图有点奇怪】
2. 导入generator_epoch_300.pth,生成图像
print("===============4.2 生成指定图像============\n") from numpy.random import randint,randn from numpy import linspace from matplotlib import pyplot,gridspec #导入生成器模型 generator.load_state_dict(torch.load('./GAN/G4/training_weights/generator_epoch_300.pth'), strict=False) generator.eval() # 生成两个潜在空间的点 interpolated = randn(100) # 将数据转换为torch张量并将其移至GPU(假设device已正确声明为GPU) interpolated = torch.tensor(interpolated).to(device).type(torch.float32) label = 0 labels = torch.ones(1) * label labels = labels.to(device) labels = labels.unsqueeze(1).long() #使用生成器生成插值结果 predictions = generator((interpolated, labels)) predictions = predictions.permute(0,2,3,1).detach().cpu() # predictions.permute(0,2,3,1):用于对张量的维度进行重新排列,生成指定图像就是靠这个函数 import warnings warnings.filterwarnings("ignore") plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False plt.rcParams['figure.dpi'] = 100 plt.figure(figsize=(8,3)) pred = (predictions[0,:,:,:]+1)*127.5 pred = np.array(pred) plt.imshow(pred.astype(np.uint8)) # plt.savefig("./GAN/G4/G&D Loss Figure.png") plt.show() print("===============4.2 生成指定图像 End============\n")
五、predictions.permute(0,2,3,1)
predictions = predictions.permute(0,2,3,1):Pytorch中的操作,用于按照指定顺序重新排列张量。
假设predictions 是Pytorch中的一个张量,它的维度为(batch_size,height,width,channels),
permute(arg1,arg2,arg3,arg4)是Pytorch中的一个用于重新排列张量的函数,arg1,arg2,arg3,arg4是一个用于指定新维度的顺序的元组
如0,2,3,1表示:
- 原始维度的第0维—>新维度的第0维(此处0表示元组0,2,3,1中元素0的索引)
- 原始维度的第2维—>新维度的第1维(此处1表示元组0,2,3,1中元素2的索引)
- 原始维度的第3维—>新维度的第2维(此处2表示元组0,2,3,1中元素3的索引)
- 原始维度的第1维—>新维度的第3维(此处3表示元组0,2,3,1中元素1的索引)
所以predicitons经过(0,2,3,1)的排列后,从(batch_size,height,width,channels) 的顺序变为 (batch_size,width,channels,height)。