鱼C论坛

 找回密码
 立即注册
查看: 2433|回复: 4

[作品展示] 卷积神经网络0-9数字识别

[复制链接]
发表于 2020-12-30 13:50:01 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. import numpy as np
  4. from PIL import Image

  5. #算法超参数
  6. batch_size=100
  7. training_epochs=10
  8. learning_rate_init=0.01
  9. display_step=100

  10. #网络参数
  11. n_input=784
  12. n_classes=10

  13. #根据指定的维数返回指定的权数
  14. def WeightsVariable(shape,name_str,stddev=0.1):
  15.     """
  16.     :param shape:               形状,n行n列
  17.     :param name_str:            变量名称
  18.     :param stddev:              正态分布标准差
  19.     :return:
  20.     """
  21.     initial=tf.random_normal(shape=shape,stddev=stddev,dtype=tf.float32)
  22.     return tf.Variable(initial_value=initial,dtype=tf.float32,name=name_str)

  23. #根据指定维数返回指定的偏置
  24. def BiasesVariable(shape,name_str,stddev=0.1):
  25.     """
  26.     :param shape:               偏置的形状,nxn
  27.     :param name_str:            变量名称
  28.     :param stddev:              正态分布标准差
  29.     :return:
  30.     """
  31.     initial=tf.random_normal(shape=shape,stddev=stddev,dtype=tf.float32)
  32.     return tf.Variable(initial_value=initial,dtype=tf.float32,name=name_str)

  33. #二维卷积层的封装(conv2d+bias)
  34. def Conv2d(X,w,b,stride=1,padding="SAME"):
  35.     """
  36.     :param X:                   特征图输入n*n*None
  37.     :param w:                   滤波器(权重参数)
  38.     :param b:                   修正(偏置参数)
  39.     :param stride:              滑动步长(默认为1)
  40.     :param padding:             外填充
  41.     :return:
  42.     """
  43.     with tf.name_scope("Conv2d"):
  44.         y=tf.nn.conv2d(X,w,strides=[1,stride,stride,1],padding=padding)
  45.         y=tf.nn.bias_add(y,b)
  46.         return y

  47. #非线性激活层的封装
  48. def Activation(x,activation=tf.nn.relu,name="relu"):
  49.     """
  50.     :param x:                   特征图输入
  51.     :param activation:          激活函数,类似于看待事物的一种模式
  52.     :param name:                名称
  53.     :return:
  54.     """
  55.     with tf.name_scope(name):
  56.         y=activation(x)
  57.         return y

  58. #最大池化层的封装
  59. def Pool2d(x,pool=tf.nn.max_pool,k=2,stride=2):
  60.     return pool(x,ksize=[1,k,k,1],strides=[1,stride,stride,1],padding="VALID")

  61. #全连接层封装
  62. def FullyConnect(x,w,b,activate=tf.identity,act_name="identity"):
  63.     with tf.name_scope("Wx_b"):
  64.         y=tf.add(tf.matmul(x,w),b)
  65.     with tf.name_scope("SoftMax"):
  66.         y=tf.nn.softmax(logits=y)
  67.     with tf.name_scope(act_name):
  68.         y=activate(y)
  69.     return y

  70. #通用的评估函数,用来评估模型在给定的数据集上的损失和准确率
  71. def EvaluateMode10nDataset(sess,images,labels):
  72.     n_samples=images.shape[0]  #传入的样本数量
  73.     per_batch_size=batch_size
  74.     loss=0
  75.     acc=0
  76.     if(n_samples<=per_batch_size):
  77.         batch_count=1
  78.         loss,acc=sess.run([cross_entropy_loss,accuracy],
  79.                           feed_dict={X_origin:images,
  80.                                      Y_true:labels,
  81.                                      learning_rate:learning_rate_init})
  82.     else:
  83.         batch_count=int(n_samples/per_batch_size)
  84.         batch_start=0
  85.         for idx in range(batch_count):
  86.             batch_loss,batch_acc=sess.run([cross_entropy_loss,accuracy],
  87.                                           feed_dict={X_origin:images[batch_start:batch_start+per_batch_size,:],
  88.                                                      Y_true:labels[batch_start:batch_start+per_batch_size,:],
  89.                                                      learning_rate:learning_rate_init})
  90.             batch_start+=per_batch_size
  91.             #累计所有批次上的损失和准确率
  92.             loss+=batch_loss
  93.             acc+=batch_acc
  94.     return loss/batch_count,acc/batch_count

  95. '''计算图绘制'''
  96. with tf.Graph().as_default():
  97.     #输入层:28*28*1
  98.     with tf.name_scope("Input"):
  99.         X_origin=tf.placeholder(tf.float32,shape=[None,n_input],name="X_origin")
  100.         Y_true=tf.placeholder(tf.float32,shape=[None,n_classes],name="Y_true")
  101.         X_image=tf.reshape(X_origin,[-1,28,28,1],name="X_image")

  102.     #前向推断过程
  103.     with tf.name_scope("Inference"):

  104.         #卷积层第一层:24*24*16
  105.         with tf.name_scope("Conv2d_1"):
  106.             weight_1=WeightsVariable(shape=[5,5,1,16],name_str="weight_1")
  107.             bias_1=BiasesVariable(shape=[16],name_str="bias_1")
  108.             conv1_out=Conv2d(X_image,weight_1,bias_1,stride=1,padding="VALID")

  109.         #非线性激活层
  110.         with tf.name_scope("Active"):
  111.             activation_out=Activation(x=conv1_out,activation=tf.nn.relu,name="relu")

  112.         #池化层(默认采用最大池化)12*12*16
  113.         with tf.name_scope("Pool2d_1"):
  114.             pool1_out=Pool2d(x=activation_out,pool=tf.nn.max_pool,k=2,stride=2)

  115.         #将二维特征图转变为1维,2304
  116.         with tf.name_scope("FeatsReshape"):
  117.             features=tf.reshape(pool1_out,[-1,12*12*16])

  118.         #全连接层
  119.         with tf.name_scope("FC_linear1"):
  120.             weight_fc=WeightsVariable(shape=[12*12*16,n_classes],name_str="weight_fc")
  121.             bias_fc=BiasesVariable(shape=[n_classes],name_str="bias_fc")
  122.             Ypre_logits=FullyConnect(features,weight_fc,bias_fc,activate=tf.identity,act_name="identity")

  123.     #定义损失层
  124.     with tf.name_scope("Loss"):
  125.         cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
  126.             labels=Y_true, logits=Ypre_logits
  127.         ))

  128.     #定义训练优化层
  129.     with tf.name_scope("Train"):
  130.         global_step = tf.Variable(0, name='global_step', trainable=False)
  131.         learning_rate=tf.placeholder(tf.float32)
  132.         optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate)
  133.         trainer=optimizer.minimize(cross_entropy_loss,global_step=global_step)

  134.     #定义模型评估层
  135.     with tf.name_scope("Evaluate"):
  136.         correct_pred=tf.equal(tf.argmax(Ypre_logits,1),tf.argmax(Y_true,1))
  137.         accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))

  138.     #定义初始化所有变量节点
  139.     init=tf.initialize_all_variables()

  140.     print("将计算图写入log事件文件中,并在tensorboard中查看!!")
  141.     writer=tf.summary.FileWriter("logs/",graph=tf.get_default_graph())
  142.     writer.close()

  143.     # 导入mnist数据集
  144.     mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

  145.     '''启动会话'''
  146.     with tf.Session() as sess:
  147.         sess.run(init)                  #初始化variable变量
  148.         total_batches=int(mnist.train.num_examples/batch_size)
  149.         print("每批次的样本数量:",batch_size)
  150.         print("总共的批次数量:",total_batches)
  151.         print("总共的训练数据:",mnist.train.num_examples)

  152.         #保存和载入网络
  153.         saver=tf.train.Saver()
  154.         checkpoint=tf.train.get_checkpoint_state("saver_network")
  155.         if checkpoint and checkpoint.model_checkpoint_path:
  156.             saver.restore(sess,checkpoint.model_checkpoint_path)
  157.             print("Successfully loaded:", checkpoint.model_checkpoint_path)
  158.         else:
  159.             print("Could not find old network weights")

  160.         training_step=0             #记录模型被训练的步数

  161.         #指定训练轮数,将所有的样本都训练一遍
  162.         for epoch in range(training_epochs):
  163.             #把一轮所有的batch都跑一遍
  164.             for batch_idx in range(total_batches):
  165.                 # 取出数据
  166.                 batch_x, batch_y = mnist.train.next_batch(batch_size)
  167.                 #训练优化器训练节点
  168.                 sess.run(trainer,feed_dict={
  169.                     X_origin:batch_x,
  170.                     Y_true:batch_y,
  171.                     learning_rate:learning_rate_init,
  172.                 })

  173.                 #每调用一次训练节点,training_step就加1
  174.                 training_step=training_step+1

  175.                 # 每训练display_step次,就计算当前模型的损失和准确率
  176.                 if training_step % display_step == 0:
  177.                     start_idx = max(0, (batch_idx - display_step) * batch_size)
  178.                     end_idx = batch_idx * batch_size
  179.                     train_loss, train_acc = EvaluateMode10nDataset(sess,
  180.                                                                    mnist.train.images[start_idx:end_idx, :],
  181.                                                                    mnist.train.labels[start_idx:end_idx, :])
  182.                     print("Training Step:" + str(training_step) +
  183.                           ",Training_Loss={:.6f}".format(train_loss) +
  184.                           ",Training_accuracy={:.5f}".format(train_acc))

  185.                     # 计算当前模型在验证集上的损失和准确率
  186.                     validation_loss, validation_acc = EvaluateMode10nDataset(sess,
  187.                                                                              mnist.validation.images,
  188.                                                                              mnist.validation.labels)
  189.                     print("Training Step:" + str(training_step) +
  190.                           ",Validation_Loss={:.6f}".format(validation_loss) +
  191.                           ",Validation_accuracy={:.5f}".format(validation_acc))

  192.                 if training_step%display_step==0:
  193.                     saver.save(sess,"./saver_network/check_mnist",global_step=global_step)

  194.         print("训练完毕!!!")

  195.         img = Image.open("5.png")
  196.         img = img.convert('L')  # 灰度化
  197.         img = np.reshape(img, (784, 1)).reshape(1, 784)
  198.         img = ((255 - np.array(img, dtype=np.uint8)) / 255.0).reshape(1, 784)
  199.         logits=sess.run(Ypre_logits,feed_dict={
  200.             X_origin:img
  201.         })
  202.         print("预测的图像数字为:",np.argmax(logits))

  203. if __name__ == '__main__':
  204.     pass

复制代码
识别.png
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-12-30 13:53:31 | 显示全部楼层
因为gui界面功能尚未完善,因此将主要逻辑上传至此(可以运行),后续持续更新代码完善
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2020-12-30 16:45:47 | 显示全部楼层
高手啊
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2021-1-9 15:13:48 | 显示全部楼层
为什么我运行提示没有名叫“tensorflow.examples”的模块
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2021-1-12 22:09:14 | 显示全部楼层
江晓夜 发表于 2021-1-9 15:13
为什么我运行提示没有名叫“tensorflow.examples”的模块

因为你用的2.0,要用1.0的版本
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-6-29 01:31

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表