Building datasets and DataLoaders with TensorDataset() and DataLoader()
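Last edited by Handsome_zhou on 2022-11-24 10:09
In deep learning, data is fed to the network in batches, and the number of samples in each batch is called batch_size. The DataLoader class provided by torch wraps a dataset in an iterable. Each iteration returns one batch of batch_size samples together with their labels for training. The final batch can be smaller unless drop_last=True is set.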
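To show what "returning batch_size samples and labels on each iteration" means, here is a simplified pure-Python sketch. The helper `iterate_batches` is hypothetical and is not how torch implements DataLoader. It only mimics the observable behaviour: an optional shuffle of the sample indices (the job RandomSampler does), then consecutive slices of batch_size indices, with a smaller last batch (DataLoader's default drop_last=False).

```python
import random
from typing import Iterator, List, Sequence, Tuple


def iterate_batches(
    samples: Sequence, labels: Sequence, batch_size: int,
    shuffle: bool = False, seed: int = 0,
) -> Iterator[Tuple[List, List]]:
    """Hypothetical helper: yield (samples, labels) batches of at most batch_size items."""
    if len(samples) != len(labels):
        raise ValueError("samples and labels must have the same length")
    indices = list(range(len(samples)))
    if shuffle:
        # Random order, similar in spirit to RandomSampler.
        random.Random(seed).shuffle(indices)
    # Walk the (possibly shuffled) indices in chunks of batch_size.
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # Samples and labels are picked with the same indices, so pairs stay aligned.
        yield [samples[i] for i in batch_idx], [labels[i] for i in batch_idx]


xs = [10, 11, 12, 13, 14]
ys = [0, 1, 0, 1, 0]
batches = list(iterate_batches(xs, ys, batch_size=2))
for batch_x, batch_y in batches:
    print(batch_x, batch_y)
# [10, 11] [0, 1]
# [12, 13] [0, 1]
# [14] [0]
```

With 5 samples and batch_size=2, the result is two full batches and one batch of size 1. A DataLoader with its default settings behaves the same way.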
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
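# Assumes train_inputs/train_masks/train_labels and test_inputs/test_masks/test_labels are tensors
# that share the same first (sample) dimension, and that batch_size is an int defined earlier.
# Create a DataLoader for the training set
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)  # draw training samples in random order
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
# Create a DataLoader for the validation set (built from the test_* tensors)
valid_data = TensorDataset(test_inputs, test_masks, test_labels)
valid_sampler = SequentialSampler(valid_data)  # read validation samples in the same fixed order every time
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=batch_size)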
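The sketch below runs end to end and shows how the resulting dataloader might be consumed. The random toy tensors, the sample count (10), the sequence length (4), and batch_size=4 are all assumptions standing in for the post's real data. Each batch holds one tensor for each tensor passed to TensorDataset, in the same order, so it can be unpacked directly in the loop.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Hypothetical toy stand-ins for the post's tensors: 10 samples, sequence length 4.
num_samples, seq_len, batch_size = 10, 4, 4
train_inputs = torch.randint(0, 100, (num_samples, seq_len))      # e.g. token ids
train_masks = torch.ones(num_samples, seq_len, dtype=torch.long)  # e.g. attention masks
train_labels = torch.randint(0, 2, (num_samples,))                # binary labels

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

batch_shapes = []
for b_inputs, b_masks, b_labels in train_dataloader:
    # A training step would consume the batch here (model not shown).
    batch_shapes.append((tuple(b_inputs.shape), tuple(b_masks.shape), tuple(b_labels.shape)))

# 10 samples with batch_size=4 give batches of 4, 4 and 2 (drop_last defaults to False).
print(batch_shapes)

# SequentialSampler visits indices in order; RandomSampler yields a permutation of them.
print(list(SequentialSampler(train_data)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

You can also write `DataLoader(train_data, batch_size=batch_size, shuffle=True)`. With no explicit sampler, DataLoader creates a RandomSampler internally when shuffle=True and a SequentialSampler otherwise. Passing both `sampler=` and `shuffle=True` raises a ValueError, because the two options are mutually exclusive.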