Handsome_zhou posted on 2022-11-24 09:44:25

Building datasets and dataloaders with TensorDataset() and DataLoader()

Last edited by Handsome_zhou on 2022-11-24 10:09

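In deep learning, data is fed into the network in batches, and the number of samples per batch is called batch_size. PyTorch's DataLoader wraps a dataset in an iterable that, on each step, returns one batch of batch_size samples together with their labels for training. By default the final batch is smaller when the dataset size is not a multiple of batch_size.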
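A minimal sketch of that batching behaviour, assuming a hypothetical toy dataset of 10 samples with 3 features each (the names `features`, `labels`, `dataset` and `loader` are illustrative, not from the original post). Indexing a TensorDataset returns a tuple of the i-th row of each tensor, and iterating the DataLoader yields stacked batches:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: 10 samples with 3 features each, plus one label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

dataset = TensorDataset(features, labels)  # dataset[i] == (features[i], labels[i])
loader = DataLoader(dataset, batch_size=4)  # no sampler/shuffle given: sequential order

batch_shapes = []
for x, y in loader:
    # x holds up to 4 feature rows, y the matching labels
    batch_shapes.append((tuple(x.shape), tuple(y.shape)))

print(len(loader))   # 3 batches
print(batch_shapes)  # [((4, 3), (4,)), ((4, 3), (4,)), ((2, 3), (2,))]
```

The last batch has only 2 samples because `drop_last` defaults to `False`; pass `drop_last=True` to discard an incomplete final batch.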


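from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# train_inputs, train_masks, train_labels and test_inputs, test_masks, test_labels
# are assumed to be tensors that share the same first dimension (number of samples);
# batch_size is an int.

# Create the DataLoader for the training set
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)  # sample the training set in random order
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for the validation set
valid_data = TensorDataset(test_inputs, test_masks, test_labels)
valid_sampler = SequentialSampler(valid_data)  # sample the validation set in its original, fixed order
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=batch_size)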
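To try the code above end to end, here is a self-contained sketch that fills in the undefined tensors with hypothetical random data (8 training and 6 validation samples, sequence length 5, binary labels, batch_size = 4; all of these values are assumptions for illustration). It shows how each batch unpacks into the three tensors, that SequentialSampler preserves the original order, and that RandomSampler visits every index exactly once per epoch:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

torch.manual_seed(0)  # make the random data and the RandomSampler order reproducible

# Hypothetical stand-ins for the tensors used in the post (e.g. tokenized text)
num_train, num_valid, seq_len, batch_size = 8, 6, 5, 4
train_inputs = torch.randint(0, 100, (num_train, seq_len))
train_masks = torch.ones(num_train, seq_len, dtype=torch.long)
train_labels = torch.randint(0, 2, (num_train,))
test_inputs = torch.randint(0, 100, (num_valid, seq_len))
test_masks = torch.ones(num_valid, seq_len, dtype=torch.long)
test_labels = torch.arange(num_valid) % 2

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

valid_data = TensorDataset(test_inputs, test_masks, test_labels)
valid_dataloader = DataLoader(valid_data, sampler=SequentialSampler(valid_data), batch_size=batch_size)

# One pass over the training set: each batch unpacks in the order given to TensorDataset
train_batch_sizes = []
for b_inputs, b_masks, b_labels in train_dataloader:
    # a real training step (forward, loss, backward, optimizer step) would go here
    train_batch_sizes.append(b_labels.shape[0])

# One pass over the validation set, collecting labels in the order they arrive
valid_batch_sizes = []
seen_labels = []
for b_inputs, b_masks, b_labels in valid_dataloader:
    valid_batch_sizes.append(b_labels.shape[0])
    seen_labels.append(b_labels)
valid_labels_seen = torch.cat(seen_labels)

# RandomSampler yields every index exactly once per epoch, in shuffled order
train_order = list(RandomSampler(train_data))

print(train_batch_sizes)  # [4, 4]
print(valid_batch_sizes)  # [4, 2]
print(torch.equal(valid_labels_seen, test_labels))    # True: sequential order is kept
print(sorted(train_order) == list(range(num_train)))  # True
```

As a design note, `DataLoader(train_data, shuffle=True, batch_size=batch_size)` builds a RandomSampler internally and is an equivalent shorthand; `shuffle=True` and an explicit `sampler` are mutually exclusive, so pass only one of them.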