鱼C论坛

 找回密码
 立即注册
查看: 1267|回复: 4

关于TensorFlow的tfrecord文件生成问题

[复制链接]
发表于 2019-2-17 22:20:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. import tensorflow as tf
  2. import os
  3. import random
  4. import math
  5. import sys
  6. from PIL import Image
  7. import numpy as np

  8. # 验证集数量
  9. _NUM_TEST = 500
  10. # 随机种子
  11. _RANDOM_SEED = 0
  12. # 数据集路径
  13. DATASET_DIR = "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/captcha/images"
  14. # tfrecord文件存放路径
  15. TFRECORD_DIR = "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/captcha/"


  16. # 判断tfrecord文件是不是存在
  17. def _dataset_exists(dataset_dir):
  18.     for split_name in ['train', 'test']:
  19.         output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
  20.         if not tf.gfile.Exists(output_filename):
  21.             return False
  22.     return True


  23. # 获取所有的验证码图片
  24. def _get_filenames_and_classes(dataset_dir):
  25.     photo_filenames = []
  26.     for filename in os.listdir(dataset_dir):
  27.         path = os.path.join(dataset_dir, filename)
  28.         photo_filenames.append(path)
  29.     return photo_filenames


  30. def int64_feature(values):
  31.     if not isinstance(values, (tuple, list)):
  32.         values = [values]
  33.     return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


  34. def bytes_feature(values):
  35.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


  36. def image_to_tfexample(image_data, label0, label1, label2, label3):
  37.     return tf.train.Example(features=tf.train.Features(feature={
  38.         'image': bytes_feature(image_data),
  39.         'label0': int64_feature(label0),
  40.         'label1': int64_feature(label1),
  41.         'label2': int64_feature(label2),
  42.         'label3': int64_feature(label3),
  43.     }))


  44. # 把数据转为tfrecord格式
  45. def _convert_dataset(split_name, filenames, dataset_dir):
  46.     assert split_name in ['train', 'test']

  47.     with tf.Session() as sess:
  48.         # 定义tfrecord文件的路径和名字
  49.         output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
  50.         with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  51.             for i,filename in enumerate(filenames):
  52.                 try:
  53.                     sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))
  54.                     sys.stdout.flush()

  55.                     # 读取照片
  56.                     image_data = Image.open(filename)
  57.                     # 根据模型的结构resize
  58.                     image_data = image_data.resize((224, 224))
  59.                     # 灰度化
  60.                     image_data = np.array(image_data.convert('L'))
  61.                     image_data = image_data.tobytes()

  62.                     # 获取labels
  63.                     labels = filename.split('/')[-1][0:4]
  64.                     num_labels = []
  65.                     for j in range(4):
  66.                         num_labels.append(labels[j])

  67.                     # 生成protocol数据类型
  68.                     example = image_to_tfexample(image_data, num_labels[0], num_labels[1], num_labels[2], num_labels[3])
  69.                     tfrecord_writer.write(example.SerializeToString())


  70.                 except IOError as e:
  71.                     print('could not read: ', filename)
  72.                     print('error: ', e)
  73.                     print('skip it \n')
  74.     sys.stdout.write('\n')




  75. if _dataset_exists(TFRECORD_DIR):
  76.     print("wenjiancunzai")
  77. else:
  78.     # 获取所有的图片
  79.     photo_filenames = _get_filenames_and_classes(DATASET_DIR)

  80.     # 把数据切分为训练集和测试机斌打乱
  81.     random.seed(_RANDOM_SEED)
  82.     random.shuffle(photo_filenames)
  83.     training_filenames = photo_filenames[_NUM_TEST:]
  84.     testing_filenames = photo_filenames[:_NUM_TEST]

  85.     # 数据转化
  86.     _convert_dataset('train', training_filenames, DATASET_DIR)
  87.     _convert_dataset('test', testing_filenames, DATASET_DIR)
  88.     print('生成tfrecord文件')
复制代码



报错
  1. Converting image 1/3448Traceback (most recent call last):
  2.   File "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/class9-2nsdn.py", line 111, in <module>
  3.     _convert_dataset('train', training_filenames, DATASET_DIR)
  4.   File "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/class9-2nsdn.py", line 85, in _convert_dataset
  5.     example = image_to_tfexample(image_data, num_labels[0], num_labels[1], num_labels[2], num_labels[3])
  6.   File "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/class9-2nsdn.py", line 50, in image_to_tfexample
  7.     'label0': int64_feature(label0),
  8.   File "F:/holiday_study/tensorflow_1/class3-2/mnist/mnist/class9-2nsdn.py", line 40, in int64_feature
  9.     return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  10. TypeError: 'i' has type str, but expected one of: int, long
复制代码



查了一些解决方法还是没解决。。。。求助各位大佬
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2019-2-17 22:30:44 | 显示全部楼层
for j in range(4):
                        num_labels.append(int(labels[j]))
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2019-2-18 07:33:04 | 显示全部楼层
看错误提示啊,传入int64_feature函数的values必须是int或者long类型的,不能是str类型的,可以强制转换一下。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2019-2-19 10:35:30 | 显示全部楼层
塔利班 发表于 2019-2-17 22:30
for j in range(4):
                        num_labels.append(int(labels[j]))

invalid literal for int() with base 10: 'i'

强制转换会报这个错
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2019-2-19 10:38:45 | 显示全部楼层
kidcad 发表于 2019-2-19 10:35
invalid literal for int() with base 10: 'i'

强制转换会报这个错

那是你的filename有问题,不能这么分割,你看看你的filename要不要处理下
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2026-1-14 00:54

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表