鱼C论坛

 找回密码
 立即注册
查看: 2470|回复: 3

[作品展示] Python数字识别

[复制链接]
发表于 2023-3-5 12:03:13 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 KeyError 于 2023-3-5 12:05 编辑
#Network.py
import matplotlib.pyplot as plt
from scipy.special import expit as s
import numpy as np
class NeuralNetwork:
    #初始化神经网络
    def __init__(self,input_nodes,hidden_nodes,output_nodes,learingrate):
        self.inodes = input_nodes
        self.hnodes = hidden_nodes
        self.onodes = output_nodes
        #学习率
        self.lr = learingrate
        #权重
        self.w_ih = ( np.random.rand(self.hnodes,self.inodes) - 0.5 )
        self.w_ho = ( np.random.rand(self.onodes,self.hnodes) - 0.5 )
        pass
    #优化权重
    def train(self,inputs,trues):
        input_list=np.array(inputs,ndmin=2).T
        hi = np.dot(self.w_ih,input_list)
        ho = s(hi)
        oi = np.dot(self.w_ho,ho)
        oo = s(oi)
        true_list = np.array(trues,ndmin=2).T
        ho_error = true_list-oo
        ih_error = np.dot(self.w_ho.T,ho_error)
        self.w_ho += self.lr * np.dot(( ho_error * oo * (1.0 - oo)),np.transpose(ho))
        self.w_ih += self.lr * np.dot(( ih_error * ho * (1.0 - ho)),np.transpose(input_list)) 
    #计算输出
    def query(self,inputs):
        input_list=np.array(inputs,ndmin=2).T
        hi = np.dot(self.w_ih,input_list)
        ho = s(hi)
        oi = np.dot(self.w_ho,ho)
        oo = s(oi)
        return oo
    def __repr__(self):
        return 'A NetWork'
    __str__=__repr__
    pass
#main.py
import Network
n=Network.NeuralNetwork(10000,100,10,0.1)
file=open('mnist_train.csv','r')
file_test=open('mnist_test.csv','r')
test=file_test.readlines()
data=file.readlines()
file.close()
file_test.close()
for i in range(5):
    for mr in data:
        value = mr.split(',')
        inputs = (Network.np.asfarray(value[1:]) / 255.0 * 0.99) + 0.01
        true=Network.np.zeros(10) + 0.01
        true[int(value[0])] = 0.99
        n.train(inputs,true)
def shi(zuo):
        value = test [zuo].split(',')
        image=Network.np.asfarray(value[1:]).reshape((28,28))
        Network.plt.imshow(image,cmap='Greys')
        inputs = (Network.np.asfarray(value[1:]) / 255.0 * 0.99) + 0.01
        a1 = n.query(inputs)
        a2 = max(a1)
        AI = list(a1).index(a2)
        true = value[0]
        print('The AI is',AI)
        print('The true is',true)
        Network.plt.show()
        return AI
这里下载CSV
你也可以用自己的手写数字,但要先把它转换为784x1的CSV

评分

参与人数 1荣誉 +5 鱼币 +5 收起 理由
歌者文明清理员 + 5 + 5 鱼C有你更精彩^_^

查看全部评分

本帖被以下淘专辑推荐:

  • · AI|主题: 2, 订阅: 0
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2023-3-5 12:26:05 | 显示全部楼层
666
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2023-3-5 19:36:30 | 显示全部楼层
AI
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2023-3-22 20:26:12 | 显示全部楼层
顶一顶
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-9-22 17:33

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表