|
|

楼主 |
发表于 2019-9-1 16:27:34
|
显示全部楼层
本帖最后由 facevoid 于 2019-9-1 16:29 编辑
我按照你的说法把代码改了一下
- import numpy
- import torch
- from torch.utils.data import DataLoader
- import torchvision.datasets as datasets
- import torchvision.transforms as transforms
- import matplotlib.pyplot as plot
- def main():
- batch_size = 100
- train_dataset = datasets.MNIST('ml/pymnist',train = True, transform = None, download = True)
- train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
- data_train = train_loader.dataset.data[ : 100].numpy()
- for epoch in range (10):
- for num_iter, data_batch in enumerate(data_train):
- if num_iter == 1:
- plot.imshow(data_batch, cmap = plot.cm.binary)
- plot.show()
- else:
- pass
-
- main()
复制代码
就是对每个epoch只画出来num_iter= 1的那个图片
可是图片还是都一样呀。。。
不过还是多谢你热心回答! |
|