鱼C论坛

 找回密码
 立即注册
查看: 3188|回复: 3

[作品展示] Python数字识别

[复制链接]
发表于 2023-3-5 12:03:13 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 KeyError 于 2023-3-5 12:05 编辑
  1. #Network.py
  2. import matplotlib.pyplot as plt
  3. from scipy.special import expit as s
  4. import numpy as np
  5. class NeuralNetwork:
  6.     #初始化神经网络
  7.     def __init__(self,input_nodes,hidden_nodes,output_nodes,learingrate):
  8.         self.inodes = input_nodes
  9.         self.hnodes = hidden_nodes
  10.         self.onodes = output_nodes
  11.         #学习率
  12.         self.lr = learingrate
  13.         #权重
  14.         self.w_ih = ( np.random.rand(self.hnodes,self.inodes) - 0.5 )
  15.         self.w_ho = ( np.random.rand(self.onodes,self.hnodes) - 0.5 )
  16.         pass
  17.     #优化权重
  18.     def train(self,inputs,trues):
  19.         input_list=np.array(inputs,ndmin=2).T
  20.         hi = np.dot(self.w_ih,input_list)
  21.         ho = s(hi)
  22.         oi = np.dot(self.w_ho,ho)
  23.         oo = s(oi)
  24.         true_list = np.array(trues,ndmin=2).T
  25.         ho_error = true_list-oo
  26.         ih_error = np.dot(self.w_ho.T,ho_error)
  27.         self.w_ho += self.lr * np.dot(( ho_error * oo * (1.0 - oo)),np.transpose(ho))
  28.         self.w_ih += self.lr * np.dot(( ih_error * ho * (1.0 - ho)),np.transpose(input_list))
  29.     #计算输出
  30.     def query(self,inputs):
  31.         input_list=np.array(inputs,ndmin=2).T
  32.         hi = np.dot(self.w_ih,input_list)
  33.         ho = s(hi)
  34.         oi = np.dot(self.w_ho,ho)
  35.         oo = s(oi)
  36.         return oo
  37.     def __repr__(self):
  38.         return 'A NetWork'
  39.     __str__=__repr__
  40.     pass
复制代码
  1. #main.py
  2. import Network
  3. n=Network.NeuralNetwork(10000,100,10,0.1)
  4. file=open('mnist_train.csv','r')
  5. file_test=open('mnist_test.csv','r')
  6. test=file_test.readlines()
  7. data=file.readlines()
  8. file.close()
  9. file_test.close()
  10. for i in range(5):
  11.     for mr in data:
  12.         value = mr.split(',')
  13.         inputs = (Network.np.asfarray(value[1:]) / 255.0 * 0.99) + 0.01
  14.         true=Network.np.zeros(10) + 0.01
  15.         true[int(value[0])] = 0.99
  16.         n.train(inputs,true)
  17. def shi(zuo):
  18.         value = test [zuo].split(',')
  19.         image=Network.np.asfarray(value[1:]).reshape((28,28))
  20.         Network.plt.imshow(image,cmap='Greys')
  21.         inputs = (Network.np.asfarray(value[1:]) / 255.0 * 0.99) + 0.01
  22.         a1 = n.query(inputs)
  23.         a2 = max(a1)
  24.         AI = list(a1).index(a2)
  25.         true = value[0]
  26.         print('The AI is',AI)
  27.         print('The true is',true)
  28.         Network.plt.show()
  29.         return AI
复制代码

这里下载CSV
你也可以用自己的手写数字,但要先把它转换为784x1的CSV

评分

参与人数 1荣誉 +5 鱼币 +5 收起 理由
歌者文明清理员 + 5 + 5 鱼C有你更精彩^_^

查看全部评分

本帖被以下淘专辑推荐:

  • · AI|主题: 2, 订阅: 0
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-3-5 12:26:05 | 显示全部楼层
666
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-3-5 19:36:30 | 显示全部楼层
AI
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-3-22 20:26:12 | 显示全部楼层
顶一顶
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-24 14:44

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表