|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 alltolove 于 2018-11-27 19:17 编辑
tensorflow读取数据的方法很多,有多线程队列的方法:
- import tensorflow as tf
- import sys
- q = tf.FIFOQueue(10, "float")
- counter = tf.Variable(0.0) #计数器
- # 给计数器加一
- increment_op = tf.assign_add(counter, 1.0)
- # 将计数器加入队列
- enqueue_op = q.enqueue(counter)
- # 创建QueueRunner
- # 用多个线程向队列添加数据
- # 这里实际创建了4个线程,两个增加计数,两个执行入队
- qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)
- # 主线程
- sess = tf.InteractiveSession()
- tf.global_variables_initializer().run()
- # 启动入队线程
- qr.create_threads(sess, start=True)
- for i in range(20):
- print (sess.run(q.dequeue()))
复制代码
还有随机读取数据的方法:
- import tensorflow as tf
-
- q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=0, dtypes="string")
- enqueue_op = q.enqueue(['aaa','bbb'])
- qr = tf.train.QueueRunner(q, enqueue_ops=[enqueue_op] * 1)
- sess = tf.Session()
- coord = tf.train.Coordinator()
- enqueue_threads = qr.create_threads(sess, start=True,coord=coord)
- try:
- for i in range(100):
- print(sess.run(q.dequeue()))
- if i>=50:
- coord.request_stop()
- coord.join(enqueue_threads)
- print(i)
- except tf.errors.OutOfRangeError:
- print('Done training -- epoch limit reached')
- finally:
- coord.request_stop()
- coord.join(enqueue_threads)
复制代码
还有读取文件的方法:
- import tensorflow as tf
- train_filenames=['data_batch_%d'%i for i in range(1,6)]
- q = tf.FIFOQueue(capacity=5, dtypes=tf.string, name="queue")
- init = q.enqueue_many((train_filenames,))
- x = q.dequeue()
- with tf.Session() as sess:
- # 运行初始化队列的操作。
- init.run()
- for i in range(5):
- # 运行q_inc将执行数据出队列,出队的元素值加1,重新加入队列的整个过程。
- v= sess.run(x)
- # 打印出队元素的值。
- print(v)
复制代码
还有反正是各种用线程读数据方法:
- import tensorflow as tf
- import numpy as np
- queue = tf.FIFOQueue(100, tf.float32)
- enqueue_op = queue.enqueue([[1]])
- qr = tf.train.QueueRunner(queue, [enqueue_op]*5)
- tf.train.add_queue_runner(qr)
- out_tensor = queue.dequeue()
- with tf.Session() as sess:
-
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- for i in range(6):
- print(sess.run(out_tensor)[0])
- coord.request_stop()
- coord.join(threads)
复制代码
以上这些方法都是太深奥了,没有多少人会用,对于我这个连入门都不算的新手来说更是没法用。
现在我主要用的是dataset方法,代码如下:
- import tensorflow as tf
- import numpy as np
- BATCH_SIZE = 5
- x = np.array([
- [100,2,2,3],
- [200,1,1,1],
- [300,2,2,2],
- [400,21,1,1]
- ])
- y = np.array([
- [100,2,2,3],
- [200,1,1,1],
- [300,2,2,2],
- [400,21,1,1]
- ])
- dataset = tf.data.Dataset.from_tensor_slices({'x':x,'y':y})
- dataset = dataset.shuffle(buffer_size=4)
- dataset=dataset.batch(BATCH_SIZE)
- dataset=dataset.repeat()
- iter = dataset.make_one_shot_iterator()
- el = iter.get_next()
- with tf.Session() as sess:
- for i in range(6):
- print('x',sess.run(el)['x'])
- print('y',sess.run(el)['y'])
复制代码
这种方法有个坑,就是x和y随机shuffle的数据不统一,x和y都是各自随机的。 |
|