马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 KeyError 于 2023-3-5 12:05 编辑
#Network.py
import matplotlib.pyplot as plt
from scipy.special import expit as s
import numpy as np
class NeuralNetwork:
    """A three-layer (input -> hidden -> output) feed-forward network.

    Both layers use the logistic sigmoid (``scipy.special.expit``, imported
    as ``s``) as activation. ``train`` performs one gradient-descent step on
    the squared error for a single (input, target) pair.

    Args:
        input_nodes: number of input features (length of each input vector).
        hidden_nodes: number of hidden-layer units.
        output_nodes: number of output units.
        learingrate: learning rate (step size) for the weight updates. The
            original spelling is kept so keyword callers keep working.
    """

    def __init__(self, input_nodes, hidden_nodes, output_nodes, learingrate):
        self.inodes = input_nodes
        self.hnodes = hidden_nodes
        self.onodes = output_nodes
        # Learning rate.
        self.lr = learingrate
        # Weights are drawn uniformly from [-0.5, 0.5). Rows index the
        # receiving layer and columns the sending layer, so ``W @ x``
        # propagates a column vector forward.
        self.w_ih = np.random.rand(self.hnodes, self.inodes) - 0.5
        self.w_ho = np.random.rand(self.onodes, self.hnodes) - 0.5

    def _forward(self, inputs):
        """Run one forward pass.

        Returns a tuple ``(inputs_col, hidden_out, final_out)`` where each
        element is a column vector (shape ``(k, 1)``).
        """
        inputs_col = np.array(inputs, ndmin=2).T
        hidden_out = s(np.dot(self.w_ih, inputs_col))
        final_out = s(np.dot(self.w_ho, hidden_out))
        return inputs_col, hidden_out, final_out

    def train(self, inputs, trues):
        """Update both weight matrices in place from one training sample.

        Args:
            inputs: 1-D sequence of ``input_nodes`` values.
            trues: 1-D sequence of ``output_nodes`` target values.
        """
        inputs_col, hidden_out, final_out = self._forward(inputs)
        targets_col = np.array(trues, ndmin=2).T
        output_errors = targets_col - final_out
        # Back-propagate through the *current* hidden->output weights; this
        # must happen before self.w_ho is modified below.
        hidden_errors = np.dot(self.w_ho.T, output_errors)
        # For a sigmoid layer with output y, dy/dx = y * (1 - y).
        self.w_ho += self.lr * np.dot(
            output_errors * final_out * (1.0 - final_out), hidden_out.T
        )
        self.w_ih += self.lr * np.dot(
            hidden_errors * hidden_out * (1.0 - hidden_out), inputs_col.T
        )

    def query(self, inputs):
        """Return the network output for ``inputs`` as an ``(output_nodes, 1)`` array."""
        return self._forward(inputs)[2]

    def __repr__(self):
        return 'A NetWork'

    __str__ = __repr__
#main.py
import Network
# Each MNIST CSV row is "label,p0,...,p783": 784 = 28x28 pixel inputs. The
# input layer must match this, otherwise np.dot inside train() raises.
n = Network.NeuralNetwork(784, 100, 10, 0.1)
# Context managers guarantee both files are closed even if reading fails.
with open('mnist_train.csv', 'r') as file:
    data = file.readlines()
with open('mnist_test.csv', 'r') as file_test:
    test = file_test.readlines()
# 5 passes (epochs) over the full training set.
for _ in range(5):
    for mr in data:
        # Skip blank lines (e.g. a trailing empty line), which would
        # otherwise fail in int('').
        if not mr.strip():
            continue
        value = mr.split(',')
        # Scale pixels from 0..255 to 0.01..1.00; a zero input would give
        # its input->hidden weights a zero update in train().
        # np.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the
        # equivalent conversion.
        inputs = (Network.np.asarray(value[1:], dtype=float) / 255.0 * 0.99) + 0.01
        # One-hot-style target kept inside the sigmoid's open range (0, 1).
        true = Network.np.zeros(10) + 0.01
        true[int(value[0])] = 0.99
        n.train(inputs, true)
def shi(zuo):
    """Display test sample ``zuo`` and print the network's prediction.

    Reads row ``zuo`` of the module-level ``test`` lines (label in column 0,
    then 784 pixel values) and shows it as a 28x28 greyscale image. Queries
    the module-level network ``n`` and prints the predicted digit next to
    the true label.

    Args:
        zuo: index of the row in ``test`` to evaluate.

    Returns:
        int: index of the largest network output (the predicted digit).
    """
    value = test[zuo].split(',')
    # Parse the pixel columns once and reuse them for display and query.
    # (np.asfarray was removed in NumPy 2.0; this is the equivalent.)
    pixels = Network.np.asarray(value[1:], dtype=float)
    Network.plt.imshow(pixels.reshape((28, 28)), cmap='Greys')
    # 0..255 -> 0.01..1.00, the same scaling the training loop applies.
    inputs = pixels / 255.0 * 0.99 + 0.01
    outputs = n.query(inputs)
    # argmax returns the first maximal index, like max() + list.index().
    AI = int(Network.np.argmax(outputs))
    true = value[0]
    print('The AI is', AI)
    print('The true is', true)
    Network.plt.show()
    return AI
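在这里下载 CSV 数据集(mnist_train.csv 和 mnist_test.csv)。
你也可以使用自己的手写数字,但要先把它转换成与 MNIST 相同格式的一行 CSV:第一列是标签(0–9),后面依次是 784 个像素值(28x28 灰度图,每个值为 0–255)。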
|