【深度学习实战】利用多层感知机解决手写数字识别问题
本帖最后由 剑魂 于 2024-6-7 11:37 编辑本文使用MNIST数据集(6w张训练图片,1w张测试图片,数据集为黑白图,图片大小为28*28)
使用Pytorch搭建神经网络
搭建神经网络7步法:1. 数据 2. 网络结构 3. 损失函数 4. 优化器 5. 训练 6. 训练 7. 保存
【说明】:
[*]如果数据集总共有60,000条数据,那么使用 batch_size 为100的加载器(loader)将每次加载600条数据。这样,模型会分批次地处理整个数据集,每批次处理600条数据。总共需要60,000 / 100 = 600个批次来处理完整个数据集。
# 0. 导包
import torch
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# 1. 数据准备
train_data = datasets.MNIST(root = "./data", train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = "./data",train = False, download = True, transform = transforms.ToTensor())
batch_size = 100
train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = False)
# 2. 定义网络模型
class MLP(nn.Module):
def __init__(self,input_size,hidden_size,output_classes):
"""
:param input_size: 神经网络的输入数据维度
:param hidden_size: 隐藏层的大小
:param output_classes: 输出分类的大小
"""
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size,hidden_size) # 定义第一个全连接层
self.relu = nn.ReLU() # 定义激活函数为ReLu
self.fc2 = nn.Linear(hidden_size,hidden_size) # 定义第二个全连接层
self.fc3 = nn.Linear(hidden_size,output_classes) # 定义第三个全连接层
# 前向传播
def forward(self,x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
# 定义参数
input_size = 28*28
hidden_size = 512
output_classes = 10 # 因为手写数字识别有0~9共10个数字所以输入分类为10
model = MLP(input_size,hidden_size,output_classes) # 初始化MLP
# 3. 定义损失函数(本文使用交叉熵损失函数)
criterion = nn.CrossEntropyLoss()
# 4. 定义优化器(用于反向传播更新权重参数)
learning_rata = 1e-3 # 定义学习率为0.001
optimizer = optim.Adam(model.parameters(), lr = learning_rata)# parameters()返回参数,用于更新权重参数
# 5. 训练
num_epoch = 10
for epoch in range(num_epoch):
for i,(images,labels) in enumerate(train_loader):
images = images.reshape(-1,28*28)
outputs = model(images)
loss = criterion(outputs,labels)
# 每次在进入反向传播的时候,都要将梯度进行清零
optimizer.zero_grad()
loss.backward() # 反向传播
optimizer.step() # 更新参数
if (i+1) % 100 == 0:
print(f"Epoch [{epoch+1}/{num_epoch}],Step[{i+1}/{len(train_loader)}] ,loss = {loss.item()}") # .item() 将这个数量转换为一个标量整数
# 6. 测试
# 这里是测试集了,因此我们不需要在进行梯度和参数的更新, 所以我们用no_grad()
with torch.no_grad():
correct = 0
total = 0
# 从 test_loader中循环读取测试数据
for images, labels in test_loader:
images = images.reshape(-1,28*28)# -1 表示自动计算这个维度的大小
outputs = model(images)
_,predicted = torch.max(outputs.data,1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Accuracy of the network on the 10000 test images: {100*correct/total}%")
# 7. 保存模型
torch.save(model,'./mnist.pth')
这段代码使用Pytorch搭建了一个多层感知机(MLP)来解决手写数字识别问题。下面我来逐步解释一下代码的每一部分。
1. 数据准备
首先我们导入需要的库,包括torch、torchvision以及一些常用的模块。然后使用torchvision中的datasets和transforms模块加载MNIST数据集,并对数据进行预处理。接着我们定义batch_size来指定每批次加载的图片数量,并使用DataLoader来创建训练集和测试集的加载器。
2. 定义网络模型
我们定义了一个MLP类来表示我们的神经网络模型。该类继承自nn.Module,并在构造函数中定义了模型的各个层,包括两个全连接层和一个ReLU激活函数。forward方法指定了前向传播的过程,即通过各个层的计算获取输出。
3. 定义损失函数
这里我们使用交叉熵损失函数来衡量模型在训练过程中的性能。
4. 定义优化器
我们使用Adam优化器来更新模型的权重参数。学习率被设定为0.001。
5. 训练
我们使用一个循环来迭代训练集。每次迭代时,我们将一批次的图片和标签从加载器中读取出来,并转换成合适的形状。然后通过模型的forward方法计算输出,并使用交叉熵损失函数计算损失。之后我们进行反向传播,清零梯度并更新参数。每100次迭代,我们打印出当前的epoch、step和损失值。
6. 测试
在测试阶段,我们不需要对梯度和参数进行更新,所以我们使用torch.no_grad()来关闭autograd的功能。然后我们通过循环从测试集加载器中读取图片和标签,并使用模型进行前向传播。最后我们统计出模型在测试集上的准确率。
7. 保存
这部分代码没有涉及到,可以根据需求选择是否保存训练好的模型。
希望这样解释可以帮到你!
以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。 _,predicted = torch.max(outputs.data,1)
total += labels.size(0)
关于这两行代码的实现,我依然有点疑惑,有明白的大佬帮忙看看该如何进行解释? 剑魂 发表于 2024-06-07 11:36
_,predicted = torch.max(outputs.data,1)
total +=...
当我们使用神经网络进行分类任务时,模型的输出是一个向量,每个元素代表对应类别的概率。在这里,outputs是模型的输出,每一行是一个样本,每一列是一个类别的概率。而我们通常会选择概率最大的类别作为模型的预测结果。
`torch.max(outputs.data, 1)`中的`torch.max()`函数会返回两个值。第一个值是每一行中最大值的值,而第二个值是每一行最大值所在的列(类别)的索引。由于我们只关心类别的索引,所以用一个下划线`_`来接收第一个值,而用变量predicted来接收第二个值。
那么,predicted中存储的就是模型对输入图片的预测结果,即预测的类别标签。
而对于total += labels.size(0),labels表示真实的类别标签,labels.size(0)表示labels的行数,也就是当前批次的图片数量。我们将这个数量加上total,用于统计测试集中总共的图片数量。
希望这样解释能够帮助你理解这两行代码的含义!如果还有疑惑,请随时提问。
以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。 很不错的帖子
页:
[1]