|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
废话不多说,先贴代码
- from tensorflow.examples.tutorials.mnist import input_data
- import tensorflow as tf
- #以下只要有tensorflow库就会自动下载'MNIST数据,使用Jupyter前现在本地编译器上跑一边,然后修改具体路径
- mnist = input_data.read_data_sets('MNIST_data/', one_hot = True)
- sess = tf.InteractiveSession()
- x = tf.placeholder(tf.float32, [None, 784])
- W = tf.Variable(tf.zeros([784, 10]))
- b = tf.Variable(tf.zeros([10]))
- #算法公式
- y = tf.nn.softmax(tf.matmul(x, W) + b)
- y_ = tf.placeholder(tf.float32, [None, 10])
- #损失函数
- cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
- reduction_indices = [1]))
- #优化算法
- train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
- #全局参数初始化器
- tf.global_variables_initializer().run()
- #训练过程
- for i in range(100):
- batch_xs, batch_ys = mnist.train.next_batch(100)
- train_step.run({x:batch_xs, y_:batch_ys})
- #正确率判断(用已经训练好的W和b来输出预测值)
- correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
- accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
- #可以打印看一下数据原始格式,应该是把图片灰度值转换成白色为0到黑色为1的一维矩阵图了
- #print(mnist.test.images[0])
- print(accuracy.eval({x:mnist.test.images, y_:mnist.test.labels}))
复制代码
我这边跑了一下,test结果很牛,0.917,看着很腻害~
然后发现书上,网络上几乎没有找到MNIST实际测试结果,以及怎么来检查识别的实际效果。为啥大家不试一试实际效果呢?很费解。。。
经过研究,实际验证操作方法应该如下:
- #实际测试
- from PIL import Image
- import numpy as np
- def Get_Image_np(file_in):
- width = 28
- height = 28
- image = Image.open(file_in)
- resized_image = image.resize((width, height), Image.ANTIALIAS)
- grey = np.array(resized_image.convert('L'))
- #正常灰度图片白色是255,黑色是0,以下数据表达方式转换成白色转换为0,黑色转换为1,和MNIST保持一致
- transform = 1 - grey / 255
- resized_image.save(r'D:\6661.PNG')
- return transform
- a = Get_Image_np(r'D:\666.PNG')
- #print(a.shape)
- #print(a)
- z = a.reshape((1,784))
- #转换成定义tf输入的数据类型
- z = z.astype(np.float32)
- #print(z)
- h = tf.nn.relu(tf.matmul(z, W1) + b1)
- r = tf.nn.softmax(tf.matmul(h, W2) + b2)
- #上面这么些可能大家都好理解一些,这里还可以用tf.placeholder来feed_dict输入数据,对tensorflow熟悉一些的朋友可以试一试
- print(sess.run(r))
- print(sess.run(tf.argmax(r,1)))
复制代码
识别结果居然是5!!!
识别图片是网上随便找的,压缩后效果如6661.PNG
猜测原因:
网上下的手写6大小比例和训练集差的比较多,识别错误率明细高于MNIST测试集精度结果,各位有兴趣可以试一试。
理解错误和不足之处还请各位大佬指教,下次我再跑一下加上卷积神经网络层的效果。 |
|