KeyError 发表于 2023-3-5 12:03:13

Python数字识别

本帖最后由 KeyError 于 2023-3-5 12:05 编辑

#Network.py
import matplotlib.pyplot as plt
from scipy.special import expit as s
import numpy as np
class NeuralNetwork:
    #初始化神经网络
    def __init__(self,input_nodes,hidden_nodes,output_nodes,learingrate):
      self.inodes = input_nodes
      self.hnodes = hidden_nodes
      self.onodes = output_nodes
      #学习率
      self.lr = learingrate
      #权重
      self.w_ih = ( np.random.rand(self.hnodes,self.inodes) - 0.5 )
      self.w_ho = ( np.random.rand(self.onodes,self.hnodes) - 0.5 )
      pass
    #优化权重
    def train(self,inputs,trues):
      input_list=np.array(inputs,ndmin=2).T
      hi = np.dot(self.w_ih,input_list)
      ho = s(hi)
      oi = np.dot(self.w_ho,ho)
      oo = s(oi)
      true_list = np.array(trues,ndmin=2).T
      ho_error = true_list-oo
      ih_error = np.dot(self.w_ho.T,ho_error)
      self.w_ho += self.lr * np.dot(( ho_error * oo * (1.0 - oo)),np.transpose(ho))
      self.w_ih += self.lr * np.dot(( ih_error * ho * (1.0 - ho)),np.transpose(input_list))
    #计算输出
    def query(self,inputs):
      input_list=np.array(inputs,ndmin=2).T
      hi = np.dot(self.w_ih,input_list)
      ho = s(hi)
      oi = np.dot(self.w_ho,ho)
      oo = s(oi)
      return oo
    def __repr__(self):
      return 'A NetWork'
    __str__=__repr__
    pass
#main.py
import Network
n=Network.NeuralNetwork(10000,100,10,0.1)
file=open('mnist_train.csv','r')
file_test=open('mnist_test.csv','r')
test=file_test.readlines()
data=file.readlines()
file.close()
file_test.close()
for i in range(5):
    for mr in data:
      value = mr.split(',')
      inputs = (Network.np.asfarray(value) / 255.0 * 0.99) + 0.01
      true=Network.np.zeros(10) + 0.01
      true)] = 0.99
      n.train(inputs,true)
def shi(zuo):
      value = test .split(',')
      image=Network.np.asfarray(value).reshape((28,28))
      Network.plt.imshow(image,cmap='Greys')
      inputs = (Network.np.asfarray(value) / 255.0 * 0.99) + 0.01
      a1 = n.query(inputs)
      a2 = max(a1)
      AI = list(a1).index(a2)
      true = value
      print('The AI is',AI)
      print('The true is',true)
      Network.plt.show()
      return AI
这里下载CSV
你也可以用自己的手写数字,但要先把它转换为784x1的CSV

sfqxx 发表于 2023-3-5 12:26:05

666

hveagle 发表于 2023-3-5 19:36:30

AI

歌者文明清理员 发表于 2023-3-22 20:26:12

顶一顶
页: [1]
查看完整版本: Python数字识别