鱼C论坛

 找回密码
 立即注册
查看: 3638|回复: 8

关于mnist导入测试集并打乱的问题

[复制链接]
发表于 2019-8-31 12:44:04 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 facevoid 于 2019-8-31 13:01 编辑
  1. import numpy
  2. import torch
  3. from torch.utils.data import DataLoader
  4. import torchvision.datasets as datasets
  5. import torchvision.transforms as transforms
  6. import matplotlib.pyplot as plot
  7. from classify_lib import kNN_classify as classify


  8. class Image_identify_mnist():

  9.     def __init__(self):

  10.         self.batch_size = 100

  11.         self.test_dataset = datasets.MNIST('ml/pymnist',train = False, transform = None, download = True)

  12.         self.test_loader = torch.utils.data.DataLoader(dataset = self.test_dataset, batch_size = self.batch_size, shuffle = True)
复制代码



上面这个就是要导入mnist测试集,最后一行有个参数shuffle = True,意思是把数据打乱,可是我并没有发现有任何打乱效果,比如self.test_loader.dataset.data[0],反复运行始终是一张图片。

有哪位大神知道什么原因吗。。。多谢!!
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2019-8-31 12:48:20 | 显示全部楼层
你打乱的不是测试集吗,怎么还看训练集
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-8-31 13:02:29 | 显示全部楼层
塔利班 发表于 2019-8-31 12:48
你打乱的不是测试集吗,怎么还看训练集

复制的时候复制错了。。。改过来啦,还是测试集
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2019-8-31 17:16:39 | 显示全部楼层
shuffle的作用是
  1. set to True to have the data reshuffled at every epoch
复制代码

你可以试试打印第一个epoch第一轮iteration的数据,和第二个epoch第一轮iteration的数据,他们就是不一样的。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-9-1 15:45:05 | 显示全部楼层
Charles未晞 发表于 2019-8-31 17:16
shuffle的作用是
你可以试试打印第一个epoch第一轮iteration的数据,和第二个epoch第一轮iteration的数据 ...

我的完整代码大概这个样子
  1. import numpy
  2. import torch
  3. from torch.utils.data import DataLoader
  4. import torchvision.datasets as datasets
  5. import torchvision.transforms as transforms
  6. import matplotlib.pyplot as plot


  7. def main():

  8.     batch_size = 100

  9.     train_dataset = datasets.MNIST('ml/pymnist',train = True, transform = None, download = True)

  10.     train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)

  11.     data_train = train_loader.dataset.data[ : 100].numpy()

  12.     digit = train_loader.dataset.data[0]

  13.     plot.imshow(digit, cmap = plot.cm.binary)

  14.     plot.show()


  15. main()
复制代码


我把这个跑了几遍画出来的第一张图片都是一样的,我不太理解你说的那个epoch跟iteration的意思
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 1 反对 0

使用道具 举报

发表于 2019-9-1 16:02:33 | 显示全部楼层
本帖最后由 Charles未晞 于 2019-9-1 17:31 编辑
facevoid 发表于 2019-9-1 15:45
我的完整代码大概这个样子


一个epoch代表迭代一遍数据集,一次iteration一般代表一个batch。
  1. for epoch in range(1, 20):
  2.     for num_iter, data_batch in enumerate(train_loader):
  3.         ...
复制代码

第一个epoch时,num_iter=1的数据才和第二个epoch时, num_iter=1的数据不一样。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-9-1 16:27:34 | 显示全部楼层
本帖最后由 facevoid 于 2019-9-1 16:29 编辑
Charles未晞 发表于 2019-9-1 16:02
一个epoch代表迭代一遍数据集,一次iteration一般代表一个batch。

第一个epoch时,num_iter=1的数据才 ...


我按照你的说法把代码改了一下

  1. import numpy
  2. import torch
  3. from torch.utils.data import DataLoader
  4. import torchvision.datasets as datasets
  5. import torchvision.transforms as transforms
  6. import matplotlib.pyplot as plot


  7. def main():

  8.     batch_size = 100

  9.     train_dataset = datasets.MNIST('ml/pymnist',train = True, transform = None, download = True)

  10.     train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)

  11.     data_train = train_loader.dataset.data[ : 100].numpy()

  12.     for epoch in range (10):

  13.         for num_iter, data_batch in enumerate(data_train):

  14.             if num_iter == 1:

  15.                 plot.imshow(data_batch, cmap = plot.cm.binary)

  16.                 plot.show()

  17.             else:

  18.                 pass
  19.    
  20. main()
复制代码


就是对每个epoch只画出来num_iter= 1的那个图片
可是图片还是都一样呀。。。

不过还是多谢你热心回答!
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2019-9-1 17:32:08 | 显示全部楼层
facevoid 发表于 2019-9-1 16:27
我按照你的说法把代码改了一下

train_loader不是data_train...
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-9-1 19:06:40 | 显示全部楼层
Charles未晞 发表于 2019-9-1 17:32
train_loader不是data_train...
  1. for num_iter, data_batch in enumerate(data_train):
复制代码


这一行的data_train改成train_loader运行会报错的

改成train_loader.dataset.data可以运行但是跑出来还是一样的图片 >_<
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2026-1-17 22:48

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表