|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 KeyError 于 2023-3-5 12:05 编辑
- #Network.py
- import matplotlib.pyplot as plt
- from scipy.special import expit as s
- import numpy as np
class NeuralNetwork:
    """Fully connected input -> hidden -> output network with sigmoid activations.

    Weights are stored in ``w_ih`` (hidden x input) and ``w_ho``
    (output x hidden). ``train`` performs one gradient-descent step on a
    single sample; ``query`` runs the forward pass only.
    """

    def __init__(self, input_nodes, hidden_nodes, output_nodes, learingrate, rng=None):
        """Create a network with weights drawn uniformly from [-0.5, 0.5).

        Args:
            input_nodes: number of values in each input sample.
            hidden_nodes: number of hidden-layer neurons.
            output_nodes: number of output neurons.
            learingrate: gradient-descent step size. (The misspelt name is
                kept so existing keyword callers keep working.)
            rng: optional ``numpy.random.Generator`` for reproducible weights.
                When omitted, NumPy's global random state is used exactly as
                before, so ``np.random.seed`` still controls initialisation.
        """
        self.inodes = input_nodes
        self.hnodes = hidden_nodes
        self.onodes = output_nodes
        # Learning rate.
        self.lr = learingrate
        # Weight matrices, shaped so np.dot(weights, column) maps one layer to the next.
        self.w_ih = self._random_weights(self.hnodes, self.inodes, rng)
        self.w_ho = self._random_weights(self.onodes, self.hnodes, rng)

    @staticmethod
    def _random_weights(rows, cols, rng):
        """Return a (rows, cols) array of uniform values in [-0.5, 0.5)."""
        uniform = np.random.rand(rows, cols) if rng is None else rng.random((rows, cols))
        return uniform - 0.5

    @staticmethod
    def _column(values, size, what):
        """Return ``values`` as a float column vector with ``size`` rows.

        Accepts a flat sequence, a single row or a single column.

        Raises:
            ValueError: if the number of values differs from ``size``.
        """
        column = np.asarray(values, dtype=float).reshape(-1, 1)
        if column.shape[0] != size:
            raise ValueError(f"expected {size} {what} values, got {column.shape[0]}")
        return column

    def _forward(self, inputs):
        """Run the forward pass; return (input column, hidden outputs, final outputs)."""
        x = self._column(inputs, self.inodes, "input")
        hidden = s(np.dot(self.w_ih, x))
        output = s(np.dot(self.w_ho, hidden))
        return x, hidden, output

    # Adjust the weights.
    def train(self, inputs, trues):
        """Run one backpropagation step on a single (inputs, trues) sample.

        Args:
            inputs: ``input_nodes`` input values.
            trues: ``output_nodes`` target values. Sigmoid outputs never reach
                exactly 0 or 1, so targets are usually kept inside (0, 1).

        Raises:
            ValueError: if ``inputs`` or ``trues`` has the wrong length.
        """
        x, hidden, output = self._forward(inputs)
        targets = self._column(trues, self.onodes, "target")
        output_error = targets - output
        # Back-propagate through w_ho *before* w_ho is updated below.
        hidden_error = np.dot(self.w_ho.T, output_error)
        # o * (1 - o) is the derivative of the sigmoid evaluated at output o.
        self.w_ho += self.lr * np.dot(output_error * output * (1.0 - output), hidden.T)
        self.w_ih += self.lr * np.dot(hidden_error * hidden * (1.0 - hidden), x.T)

    # Compute the output.
    def query(self, inputs):
        """Return the network output for ``inputs`` as an (output_nodes, 1) array.

        Raises:
            ValueError: if ``inputs`` does not contain ``input_nodes`` values.
        """
        return self._forward(inputs)[2]

    def __repr__(self):
        return 'A NetWork'

    __str__ = __repr__
复制代码- #main.py
import matplotlib.pyplot as plt
import numpy as np

import Network
# Each CSV row is "label,pixel_1,...,pixel_784": one 28x28 image with the
# digit label in the first column (pixel values assumed to be 0-255 grayscale).
INPUT_NODES = 28 * 28   # must equal the number of pixel columns per row
HIDDEN_NODES = 100
OUTPUT_NODES = 10       # one output per digit 0-9
LEARNING_RATE = 0.1
EPOCHS = 5

# The input layer must match the 784 pixels per row; the previous value of
# 10000 made the very first train() call fail with a shape mismatch.
n = Network.NeuralNetwork(INPUT_NODES, HIDDEN_NODES, OUTPUT_NODES, LEARNING_RATE)

# `with` guarantees the files are closed even if reading fails.
with open('mnist_train.csv', 'r') as train_file:
    data = train_file.readlines()
with open('mnist_test.csv', 'r') as test_file:
    test = test_file.readlines()

for _epoch in range(EPOCHS):
    for record in data:
        if not record.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        value = record.split(',')
        # Scale pixels into 0.01..1.00 so no input is exactly zero.
        # (np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the replacement.)
        inputs = np.asarray(value[1:], dtype=float) / 255.0 * 0.99 + 0.01
        # Target vector: 0.99 for the labelled digit, 0.01 everywhere else,
        # because sigmoid outputs can never reach exactly 0 or 1.
        true = np.zeros(OUTPUT_NODES) + 0.01
        true[int(value[0])] = 0.99
        n.train(inputs, true)
def shi(zuo):
    """Display test image number ``zuo`` and print the network's prediction and label.

    Reads the module-level ``test`` rows and the trained network ``n``.

    Args:
        zuo: index of the row in ``test`` to evaluate.

    Returns:
        The predicted digit, i.e. the index of the largest network output
        (the first one on ties).
    """
    value = test[zuo].split(',')
    # Parse the pixel columns once; np.asfarray was removed in NumPy 2.0.
    pixels = np.asarray(value[1:], dtype=float)
    plt.imshow(pixels.reshape((28, 28)), cmap='Greys')
    # Same scaling as used during training.
    inputs = pixels / 255.0 * 0.99 + 0.01
    outputs = n.query(inputs)
    AI = int(np.argmax(outputs))
    true = value[0]
    print('The AI is', AI)
    print('The true is', true)
    plt.show()
    return AI
复制代码
在这里下载 CSV 数据文件(代码读取 mnist_train.csv 和 mnist_test.csv)。
你也可以使用自己的手写数字,但需要先把它转换成与 MNIST 相同格式的 CSV 行:第一列为数字标签,其后是 28×28 = 784 个 0–255 的灰度像素值。
|
评分
-
查看全部评分
|