pytorch跑手写体实验
目录
(图片来源网络,侵删)
1、环境条件
2、代码实现
3、总结
1、环境条件
- pycharm编译器
- pytorch依赖
- matplotlib依赖
- numpy依赖等等
2、代码实现
import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 定义数据变换 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) # 加载 MNIST 数据集 trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) # 定义 LeNet-5 模型 class LeNet5(nn.Module): def __init__(self): super(LeNet5, self).__init__() self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x # 初始化模型、损失函数和优化器 model = LeNet5().to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练模型 epochs = 5 for epoch in range(epochs): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 100 == 99: print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}') running_loss = 0.0 print('Finished Training') # 保存模型 torch.save(model.state_dict(), 'lenet5.pth') print('Model saved to lenet5.pth') # 加载模型 model = LeNet5() model.load_state_dict(torch.load('lenet5.pth')) model.to(device) model.eval() # 在测试集上评估模型 correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy on the test set: {100 * correct / total:.2f}%') # 加载并预处理本地图片进行预测 from PIL import Image def load_and_preprocess_image(image_path): img = Image.open(image_path).convert('L') # 转为灰度图 img = img.resize((28, 28)) img = np.array(img, dtype=np.float32) img = (img / 255.0 - 0.5) / 0.5 # 归一化到[-1, 1] img = torch.tensor(img).unsqueeze(0).unsqueeze(0) # 添加批次和通道维度 return img.to(device) # 预测本地图片 image_path = '4.png' # 替换为你的本地图片路径 img = load_and_preprocess_image(image_path) # 使用加载的模型进行预测 model.eval() with torch.no_grad(): outputs = model(img) _, predicted = torch.max(outputs, 1) # 打印预测结果 predicted_label = predicted.item() print(f'预测结果: {predicted_label}') # 显示图片及预测结果 img_np = img.cpu().numpy().squeeze() plt.imshow(img_np, cmap='gray') plt.title(f'预测结果: {predicted_label}') plt.show()
解释:torch.save()方法完成模型的保存,image_path为本地图片,用于测试
3、总结
安装环境是比较难的点,均使用pip install 。。指令进行依赖环境的安装,其他的比较简单。
学习之所以会想睡觉,是因为那是梦开始的地方。
ଘ(੭ˊᵕˋ)੭ (开心) ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)ଘ(੭ˊᵕˋ)੭ (开心)
------不写代码不会凸的小刘
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。