|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 alltolove 于 2018-11-29 08:14 编辑
上次我们把cifar10的数据读进内存,然后就要把数据做成一批一批的进行处理,把每一批次数据打乱顺序,代码为:
- labels=labels.reshape((50000,1))
- all_data=np.hstack((data,labels))
- test_d,test_l=load_data(test_filename)
- test_l=np.array(test_l)
- test_l=test_l.reshape((10000,1))
- all_test_data=np.hstack((test_d,test_l))
- BATCH_SIZE = 20
- dataset = tf.data.Dataset.from_tensor_slices(all_data)
- dataset_test = tf.data.Dataset.from_tensor_slices(all_test_data).batch(BATCH_SIZE)
- dataset = dataset.shuffle(buffer_size=50000)
- dataset=dataset.batch(BATCH_SIZE)
- dataset=dataset.repeat()
- dataset_test=dataset_test.repeat()
- iter_data = dataset.make_one_shot_iterator()
- iter_data_test = dataset_test.make_one_shot_iterator()
- el_data = iter_data.get_next()
- el_data_test = iter_data_test.get_next()
复制代码
这里的BATCH_SIZE就是批次的大小,每一批是20张图片。这里的dataset和dataset_test分别是训练数据和测试数据,dataset=dataset.repeat()
dataset_test=dataset_test.repeat()这两句话是一旦数据所有批次都读到头了就从头开始重复读。 |
|