鱼C论坛

 找回密码
 立即注册
查看: 3545|回复: 0

[学习笔记] 使用TensorDataset()和DataLoader()来构建数据集和dataloader

[复制链接]
发表于 2022-11-24 09:44:25 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 Handsome_zhou 于 2022-11-24 10:09 编辑

在深度学习中,需要将数据分批的放入训练网络中,批数量的大小称为batch_size, 通过torch提供的dataloader方法,
可以实现一个迭代器。每次返回一组batch_size个样本和标签来训练。

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

#给训练集创建DataLoader
train_data = TensorDataset(train_inputs, train_masks, trian_labels)
train_sampler = RandomSampler(train_data)  # 训练集随机采样
train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size=batch_size)

#给验证集创建 DataLoader
valid_data = TensorDataset(test_inputs, test_masks, test_labels)
valid_sampler = SequentialSampler(valid_data)  # 验证集按同样的顺序采样
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=batch_size)
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-12-25 03:50

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表