鱼C论坛

 找回密码
 立即注册
查看: 1080|回复: 1

[技术交流] 通过tensorflow读取数据

[复制链接]
发表于 2018-11-27 19:11:11 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 alltolove 于 2018-11-27 19:17 编辑

        tensorflow读取数据的方法很多,有多线程队列的方法:
  1. import tensorflow as tf  
  2. import sys  
  3. q = tf.FIFOQueue(10, "float")  
  4. counter = tf.Variable(0.0)  #计数器
  5. # 给计数器加一
  6. increment_op = tf.assign_add(counter, 1.0)
  7. # 将计数器加入队列
  8. enqueue_op = q.enqueue(counter)

  9. # 创建QueueRunner
  10. # 用多个线程向队列添加数据
  11. # 这里实际创建了4个线程,两个增加计数,两个执行入队
  12. qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)

  13. # 主线程
  14. sess = tf.InteractiveSession()
  15. tf.global_variables_initializer().run()
  16. # 启动入队线程
  17. qr.create_threads(sess, start=True)
  18. for i in range(20):
  19.     print (sess.run(q.dequeue()))
复制代码

        还有随机读取数据的方法:
  1. import tensorflow as tf

  2. q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=0, dtypes="string")
  3. enqueue_op = q.enqueue(['aaa','bbb'])
  4. qr = tf.train.QueueRunner(q, enqueue_ops=[enqueue_op] * 1)
  5. sess = tf.Session()
  6. coord = tf.train.Coordinator()
  7. enqueue_threads = qr.create_threads(sess, start=True,coord=coord)
  8. try:
  9.     for i in range(100):
  10.         print(sess.run(q.dequeue()))
  11.         if i>=50:
  12.             coord.request_stop()
  13.             coord.join(enqueue_threads)
  14.             print(i)
  15. except tf.errors.OutOfRangeError:
  16.     print('Done training -- epoch limit reached')
  17. finally:
  18.     coord.request_stop()
  19.     coord.join(enqueue_threads)
复制代码

        还有读取文件的方法:
  1. import tensorflow as tf

  2. train_filenames=['data_batch_%d'%i for i in range(1,6)]

  3. q = tf.FIFOQueue(capacity=5, dtypes=tf.string, name="queue")


  4. init = q.enqueue_many((train_filenames,))


  5. x = q.dequeue()

  6. with tf.Session() as sess:

  7.     # 运行初始化队列的操作。
  8.     init.run()
  9.     for i in range(5):

  10.         # 运行q_inc将执行数据出队列,出队的元素值加1,重新加入队列的整个过程。
  11.         v= sess.run(x)

  12.         # 打印出队元素的值。
  13.         print(v)
复制代码

        还有反正是各种用线程读数据方法:
  1. import tensorflow as tf
  2. import numpy as np



  3. queue = tf.FIFOQueue(100, tf.float32)

  4. enqueue_op = queue.enqueue([[1]])


  5. qr = tf.train.QueueRunner(queue, [enqueue_op]*5)


  6. tf.train.add_queue_runner(qr)


  7. out_tensor = queue.dequeue()


  8. with tf.Session() as sess:

  9.    
  10.     coord = tf.train.Coordinator()


  11.     threads = tf.train.start_queue_runners(sess=sess, coord=coord)


  12.     for i in range(6):


  13.         print(sess.run(out_tensor)[0])


  14.     coord.request_stop()
  15.     coord.join(threads)
复制代码

       
        以上这些方法都是太深奥了,没有多少人会用,对于我这个连入门都不算的新手来说更是没法用。
        现在我主要用的是dataset方法,代码如下:
  1. import tensorflow as tf
  2. import numpy as np
  3. BATCH_SIZE = 5
  4. x = np.array([
  5.     [100,2,2,3],
  6.     [200,1,1,1],
  7.     [300,2,2,2],
  8.     [400,21,1,1]
  9. ])
  10. y = np.array([
  11.     [100,2,2,3],
  12.     [200,1,1,1],
  13.     [300,2,2,2],
  14.     [400,21,1,1]
  15. ])
  16. dataset = tf.data.Dataset.from_tensor_slices({'x':x,'y':y})
  17. dataset = dataset.shuffle(buffer_size=4)
  18. dataset=dataset.batch(BATCH_SIZE)
  19. dataset=dataset.repeat()
  20. iter = dataset.make_one_shot_iterator()
  21. el = iter.get_next()

  22. with tf.Session() as sess:
  23.     for i in range(6):
  24.         print('x',sess.run(el)['x'])
  25.         print('y',sess.run(el)['y'])
复制代码

        这种方法有个坑,就是x和y随机shuffle的数据不统一,x和y都是各自随机的。

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2018-11-27 21:04:32 | 显示全部楼层
都是版主纯原创的吗
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-23 21:25

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表