鱼C论坛

 找回密码
 立即注册
查看: 721|回复: 1

[已解决]求助 代码出现值错误

[复制链接]
发表于 2020-6-14 02:09:03 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
这是一个关于BP神经网络实现手写数字识别的代码 在运行的最后几步出现了未对齐的错误 求助大佬们应该怎么改啊

  1. #导入模块
  2. import numpy as np
  3. from sklearn.datasets import load_digits
  4. from sklearn.preprocessing import LabelBinarizer
  5. from sklearn.model_selection import train_test_split
  6. from sklearn.metrics import classification_report,confusion_matrix
  7. import matplotlib.pyplot as plt
  8. #载入数据
  9. digits = load_digits()
  10. print(digits.images.shape)
  11. #显示图片
  12. plt.imshow(digits.images[0],cmap = 'gray')
  13. plt.show()
  14. plt.imshow(digits.images[1],cmap = 'gray')
  15. plt.show()
  16. plt.imshow(digits.images[2],cmap = 'gray')
  17. plt.show
  18. #数据
  19. X = digits.data
  20. #标签
  21. y = digits.target
  22. print(X.shape)
  23. print(y.shape)
  24. print(X[:3])
  25. print(y[:3])
  26. # 定义一个神经网络,结构:64-100-10
  27. # 定义输入层到隐藏层之间的权值矩阵
  28. V = np.random.random((64,100))*2-1
  29. # 定义隐藏层到输出层之间的权值矩阵
  30. V = np.random.random((100,10))*2-1
  31. # 数据切分
  32. # 1/4为测试集,3/4为训练集
  33. X_train,X_test,y_train,y_test = train_test_split(X,y)
  34. #标签二值化
  35. lables_train = LabelBinarizer().fit_transform(y_train)
  36. print(y_train[:5])
  37. print('哈哈')
  38. print(lables_train[:5])
  39. # 激活函数
  40. def sigmoid(x):
  41.     return 1/(1 + np.exp(-x))
  42. #激活函数的导数
  43. def dsigmoid(x):
  44.     return x*(1-x)
  45. #训练模型
  46. def train(X,y,steps = 10000,lr = 0.11):
  47.     global V,W
  48.     for n in range(steps + 1):
  49.         #随机选取一个数据
  50.         i = np.random.randint(X.shape[0])
  51.         #获取一个数据
  52.         x = X[i]
  53.         x = np.atleast_2d(x)
  54.         # BP算法公式
  55.         # 计算隐藏层的输出
  56.         L1 = sigmoid(np.dot(x,V))
  57.         #计算输出层的输出
  58.         L2 = sigmoid(np.dot(L1,W))
  59.         #计算L2_delta,L1_delta
  60.         L2_delta = (y[i] - L2) * dsigmoid(L2)
  61.         L1_delta = L2_delta.dot(W.t) * dsigmoid(L1)
  62.         #更新权值
  63.         W += lr * L1.T.dot(L2_delta)
  64.         V += lr*x.T.dot(L1_delta)
  65.         #每训练1000次预测一次准确率
  66.         if n % 1000 == 0:
  67.             output = predict(X_test)
  68.             predictions = np.argmax(output,axis = 1)
  69.             acc = np.mean(np.equal(predictions,y_test))
  70.             print('steps:',n,'accuracy:',acc)
  71. def predict(x):
  72.     #计算隐藏的输出
  73.     L1 = sigmoid(np.dot(x,V))
  74.     #计算输出的输出
  75.     L2 = sigmoid(np.dot(L1,W))
  76.     return L2
  77. # 进行30000次的模型训练,完成1000次训练输出一次训练结果
  78. train(X_train,lables_train,30000)
  79. # 用测试数据对已训练的模型进行测试
  80. output = predict(X_test)
  81. # 按行查找训练结果的最大元素
  82. predictions = np.argmax(output,axis = 1)
  83. # 比较预测测试标签和真实标签,并输出准确率
  84. print(classification_report(predictions,y_test))
  85. #利用混淆矩阵实现模型评估,矩阵行数据相加是真实值类别数,列数据相加是分类后的类别数
  86. print(confusion_matrix(predictions,y_test))
复制代码
最佳答案
2020-6-14 08:37:21

报错那写了

  1. L1 = sigmoid(np.dot(x,V))
复制代码


这里58行未对奇,可能是长度不一样,导致的吧,你检查检查你的代码
QQ图片20200614020636.png
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2020-6-14 08:37:21 | 显示全部楼层    本楼为最佳答案   

报错那写了

  1. L1 = sigmoid(np.dot(x,V))
复制代码


这里58行未对奇,可能是长度不一样,导致的吧,你检查检查你的代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-6-21 23:03

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表