鱼C论坛

 找回密码
 立即注册
查看: 2273|回复: 4

[基础知识] 【深度学习实战】利用多层感知机解决手写数字识别问题

[复制链接]
发表于 2024-6-7 11:28:26 | 显示全部楼层
这段代码使用Pytorch搭建了一个多层感知机(MLP)来解决手写数字识别问题。下面我来逐步解释一下代码的每一部分。

1. 数据准备
首先我们导入需要的库,包括torch、torchvision以及一些常用的模块。然后使用torchvision中的datasets和transforms模块加载MNIST数据集,并对数据进行预处理。接着我们定义batch_size来指定每批次加载的图片数量,并使用DataLoader来创建训练集和测试集的加载器。

2. 定义网络模型
我们定义了一个MLP类来表示我们的神经网络模型。该类继承自nn.Module,并在构造函数中定义了模型的各个层,包括两个全连接层和一个ReLU激活函数。forward方法指定了前向传播的过程,即通过各个层的计算获取输出。

3. 定义损失函数
这里我们使用交叉熵损失函数来衡量模型在训练过程中的性能。

4. 定义优化器
我们使用Adam优化器来更新模型的权重参数。学习率被设定为0.001。

5. 训练
我们使用一个循环来迭代训练集。每次迭代时,我们将一批次的图片和标签从加载器中读取出来,并转换成合适的形状。然后通过模型的forward方法计算输出,并使用交叉熵损失函数计算损失。之后我们进行反向传播,清零梯度并更新参数。每100次迭代,我们打印出当前的epoch、step和损失值。

6. 测试
在测试阶段,我们不需要对梯度和参数进行更新,所以我们使用torch.no_grad()来关闭autograd的功能。然后我们通过循环从测试集加载器中读取图片和标签,并使用模型进行前向传播。最后我们统计出模型在测试集上的准确率。

7. 保存
这部分代码没有涉及到,可以根据需求选择是否保存训练好的模型。

希望这样解释可以帮到你!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-6-7 11:42:10 | 显示全部楼层
剑魂 发表于 2024-06-07 11:36
_,predicted = torch.max(outputs.data,1)
total +=  ...


当我们使用神经网络进行分类任务时,模型的输出是一个向量,每个元素代表对应类别的概率。在这里,outputs是模型的输出,每一行是一个样本,每一列是一个类别的概率。而我们通常会选择概率最大的类别作为模型的预测结果。

`torch.max(outputs.data, 1)`中的`torch.max()`函数会返回两个值。第一个值是每一行中最大值的值,而第二个值是每一行最大值所在的列(类别)的索引。由于我们只关心类别的索引,所以用一个下划线`_`来接收第一个值,而用变量predicted来接收第二个值。

那么,predicted中存储的就是模型对输入图片的预测结果,即预测的类别标签。

而对于total += labels.size(0),labels表示真实的类别标签,labels.size(0)表示labels的行数,也就是当前批次的图片数量。我们将这个数量加上total,用于统计测试集中总共的图片数量。

希望这样解释能够帮助你理解这两行代码的含义!如果还有疑惑,请随时提问。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-9-19 04:14

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表