马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 alltolove 于 2018-11-29 08:14 编辑
上次我们把cifar10的数据读进内存,然后就要把数据做成一批一批的进行处理,把每一批次数据打乱顺序,代码为:labels=labels.reshape((50000,1))
all_data=np.hstack((data,labels))
test_d,test_l=load_data(test_filename)
test_l=np.array(test_l)
test_l=test_l.reshape((10000,1))
all_test_data=np.hstack((test_d,test_l))
BATCH_SIZE = 20
dataset = tf.data.Dataset.from_tensor_slices(all_data)
dataset_test = tf.data.Dataset.from_tensor_slices(all_test_data).batch(BATCH_SIZE)
dataset = dataset.shuffle(buffer_size=50000)
dataset=dataset.batch(BATCH_SIZE)
dataset=dataset.repeat()
dataset_test=dataset_test.repeat()
iter_data = dataset.make_one_shot_iterator()
iter_data_test = dataset_test.make_one_shot_iterator()
el_data = iter_data.get_next()
el_data_test = iter_data_test.get_next()
这里的BATCH_SIZE就是批次的大小,每一批是20张图片。这里的dataset和dataset_test分别是训练数据和测试数据,dataset=dataset.repeat()
dataset_test=dataset_test.repeat()这两句话是一旦数据所有批次都读到头了就从头开始重复读。 |