53清 posted on 2021-11-3 15:18:20

Downloading datasets with torchvision

Code
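%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# Render matplotlib figures as SVG (a function, so it must be called)
d2l.use_svg_display()

# ToTensor must be instantiated; the instance converts PIL images
# to float32 tensors with values in [0, 1]
trans = transforms.ToTensor()

# Keep "(" on the same line as FashionMNIST; otherwise the first line
# just assigns the class and the next line is a separate, invalid expression
mnist_train = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root='./data', train=False, transform=trans, download=True)

print(len(mnist_train), len(mnist_test))  # 60000 10000

# A Dataset object has no .shape attribute; inspect one sample's image tensor instead
print(mnist_train[0][0].shape)  # torch.Size([1, 28, 28])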

53清 posted on 2021-11-16 20:38:39

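It was a formatting problem: removing the whitespace (the line break between FashionMNIST and the opening parenthesis) fixes it.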
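Why the line break matters: Python ends a statement at a newline unless a bracket is still open. In the original layout, the line ending in `FashionMNIST` is already a complete statement (it assigns the class itself), and the next line starts a new parenthesised expression containing `root='./data'`, which is not valid syntax, so a notebook cell containing it will not run at all. A minimal check using the standard-library `ast` module shows the difference; the two strings below just reproduce the layouts and are never executed, so `torchvision` does not need to be installed.

```python
import ast

# Layout from the question: the class name and "(" are on different lines,
# so Python sees two separate statements.
broken = """
mnist_train = torchvision.datasets.FashionMNIST
(
    root='./data', train=True, transform=trans, download=True
    )
"""

# Fixed layout: "(" directly follows the name, so this is one call
# whose arguments may still span several lines.
fixed = """
mnist_train = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=trans, download=True
    )
"""


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


print(parses(broken))  # False
print(parses(fixed))   # True
```

Arguments can be spread over as many lines as you like, as long as the opening parenthesis stays on the same line as the callable.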