鱼C论坛

 找回密码
 立即注册
查看: 1322|回复: 3

python tensorflow 维度报错

[复制链接]
发表于 2020-4-27 15:26:56 | 显示全部楼层 |阅读模式
30鱼币
核心代码和教学视频一样的  不知道哪错了  弄了一天了  奔溃  希望高手们发表一下见解,帮助一下我!!!
我用的是tensorflow2.0  数据加载进来也没问题  下面是报错截图:

报错图

报错图


维度报错我也验证了  没啥毛病:
1587912984483.jpg 1587913019294.jpg
再下面是代码:

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. mnist=tf.keras.datasets.mnist
  5. (train_x,train_y),(text_x,text_y)=mnist.load_data()

  6. look_photo_x=text_x[:]   #备用验证
  7. look_photo_y=text_y[:]

  8. #定义训练参数
  9. lun_shu=20  #训练轮数
  10. yang_benshu=50  #单次训练样本数
  11. xue_xilv=0.001   #学习率
  12. cunt=0 #记录学习轮数
  13. loss_train_save=[] #保存训练集损失函数值
  14. acc_train_save=[]  #保存训练集Acc值

  15. #图像拉直.....变为一维数组

  16. print(train_x[0].shape)  #数据验证

  17. train_x=train_x.reshape(-1,784)
  18. text_x=text_x.reshape(-1,784)

  19. print(train_x[0].shape)#数据验证

  20. #特征数据归一化代码块
  21. train_x=tf.cast(train_x/255.0,tf.float32)
  22. text_x=tf.cast(text_x/255.0,tf.float32)

  23. #标签数据变成独热码形式
  24. train_y=tf.one_hot(train_y,depth=10)
  25. text_y=tf.one_hot(text_y,depth=10)

  26. #模型变量的定义
  27. Input_Dim=784
  28. H1_NN=64
  29. W1=tf.Variable(tf.random.normal( [Input_Dim , H1_NN], mean=0.0,stddev=1.0,dtype=tf.float32))
  30. B1=tf.Variable(tf.zeros([H1_NN]), dtype=tf.float32)
  31.                
  32. Output_Dim=10
  33. W2=tf.Variable(tf.random.normal( [H1_NN , Output_Dim], mean=0.0,stddev=1.0,dtype=tf.float32))
  34. B2=tf.Variable(tf.zeros( [Output_Dim] ) , dtype=tf.float32)
  35.                
  36. W=[W1,W2]
  37. B=[B1,B2]

  38. #模型的构建
  39. def model(x,w,b):
  40.     x=tf.matmul(x,w[0]) + b[0]
  41.     x=tf.nn.relu(x)
  42.     x=tf.matmul(x,w[1]) + b[1]
  43.     pred=tf.nn.softmax(x)
  44.     return pred

  45. #定义损失函数
  46. def loss(x,y,w,b):
  47.     pred=model(x,w,b)  #计算预测值和标签值的差异
  48.     loss_=tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)
  49.     return tf.reduce_mean(loss_)  #求均值,得出均方差

  50. #梯度计算函数
  51. def grad(x,y,w,b):
  52.     with tf.GradientTape() as tape:
  53.         loss_=loss(x,y,w,b)
  54.     return tape.gradient(loss_,[w,b])  #返回梯度向量

  55. #Adam优化器
  56. optimizer=tf.keras.optimizers.Adam(learning_rate=xue_xilv)

  57. #定义准确率
  58. def accuracy(x,y,w,b):
  59.     pred=model(x,w,b)
  60.     one_or_zero=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
  61.     #准确率,将布尔型转换为浮点型并计算均值
  62.     return tf.reduce_mean(tf.cast(one_or_zero,tf.float32))

  63. #直线图封装函数
  64. def zhi_xian_tu(x_,y_):
  65.     plt.rcParams['font.sans-serif']='SimHei' #设置字体为黑体
  66.    
  67.     plt.plot(x_,label='学习率',color='r')   #创建直线图
  68.     plt.plot(y_,label='损失率',color='b')
  69.    
  70.     plt.ylim(0,1)      #设置纵坐标长度为1
  71.    
  72.     plt.xlabel('百分比',fontsize=12)     #显示图例
  73.     plt.xlabel('学习次数',fontsize=12)
  74.    
  75.     plt.title('手写数字识别学习率与损失率直观图',fontsize=16)  #设置标题
  76.    
  77.     plt.legend()  #显示图例
  78.     plt.show()  #显示完整直线图

  79. total_step=int(len(train_x)/yang_benshu)

  80. for i in range(lun_shu):
  81.     cunt+=1
  82.     for m in range(total_step):
  83.         xs=train_x[m*yang_benshu:(m+1)*yang_benshu]
  84.         ys=train_y[m*yang_benshu:(m+1)*yang_benshu]
  85.         
  86.         grads=grad(xs,ys,W,B)#计算梯度
  87.         optimizer.apply_gradients(zip(grads,W+B)) #优化器根据梯度自动调整梯度
  88.         
  89.     loss_train=loss(train_x,train_y,W,B).numpy()  #计算当前轮训练损失率
  90.     acc_train=accuracy(train_x,train_y,W,B).numpy()  #计算当前识别率
  91.    
  92.     loss_train_save.append(loss_train)
  93.     acc_train_save.append(acc_train)
  94.    
  95.     print('正在学习第:%d轮  损失率:%4f  识别正确率:%4f' % (cunt,loss_train,acc_train))
  96.    
  97. zhi_xian_tu(acc_train_save,loss_train_save)    #打印直线图
  98.    
复制代码

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-4-27 15:39:30 | 显示全部楼层
{WZI1`B%1PFBP9RK0R]1]~2.png   如果把  优化器根据梯度自动调整梯度  注释掉就可以运行 但是就学习率就上不去了
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2020-4-27 16:59:09 | 显示全部楼层
这种问题还是去CDSN问,那边专业的多一些。
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-4-27 17:09:58 | 显示全部楼层
都在几处发帖了  但是还没在CDSN发过
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-26 12:40

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表