FishC Forum


[Solved] scipy.optimize problem: matrix dimensions not aligned

Posted 2019-3-13 21:29:19

import numpy as np
import pandas as pd
import scipy.optimize as opt

data=pd.read_csv('ex2data1.txt',header=None,names=['Exam1','Exam2','Class'])
positive=data[ data.Class.isin([1]) ]   # Class is read as integers, so match 1/0 rather than '1'/'0'
negative=data[ data.Class.isin([0]) ]
data.insert(0,'Ones',1)
X=data.iloc[:,0:3]
y=data.iloc[:,3]
X=np.matrix(X.values)  # X has shape (100,3); len(X)=100
y=np.matrix(y.values).reshape(100,1)   # y has shape (100,1)
theta=np.matrix(np.zeros(3).reshape(1,3),dtype=int) # matrix([[0, 0, 0]]); theta has shape (1,3)

def sigmoid(theta,X):
    z=X*theta.T
    return 1/(1+np.exp(-z))

def cost(theta,X,y):
    hypothesis=sigmoid(theta,X)
    a=np.multiply(y,np.log(hypothesis)) + np.multiply((1-y),np.log(1-hypothesis))   # (100,1) matrix; at the initial theta every element of a is -0.6931
    b=np.sum(a)   # at the initial theta, b is -69.31
    return (-1./len(X))*b

print 'original cost:',cost(theta,X,y) # 0.69314718056

def gradient(theta,X,y):
    error=sigmoid(theta,X)-y # (100,1)
    return (1.0/len(X))*error.T*X # (1,3)

print gradient(theta,X,y)   # [[ -0.1  -12.00921659  -11.26284221]]

res=opt.minimize(fun=cost,x0=theta,args=(X,y),method='TNC',jac=gradient)   # this is the call that raises the ValueError
print res


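My questions: (1) Both print 'original cost:',cost(theta,X,y) and print gradient(theta,X,y) clearly ran successfully (the line numbers in the screenshot are off; I left out a lot of messy comments when posting), which means my matrix dimensions are fine. So why does res=opt.minimize(fun=cost,x0=theta,args=(X,y),method='TNC',jac=gradient) raise a ValueError saying my matrix shapes are not aligned?
(2) Also, is either ndarray or matrix the better data type to use with scipy's optimization routines? I don't really understand the data side of things.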
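OP | Posted 2019-3-13 21:35:05
OP here. Below are the contents of the txt file and the output from running the code above 👆. For some reason my image and data file uploads stay stuck on "uploading" and never go through....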
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0
70.66150955499435,92.92713789364831,1
76.97878372747498,47.57596364975532,1
67.37202754570876,42.83843832029179,0
89.67677575072079,65.79936592745237,1
50.534788289883,48.85581152764205,0
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1


original cost: 0.69314718056
[[ -0.1        -12.00921659 -11.26284221]]
Traceback (most recent call last):
  File "LogisReg.py", line 85, in <module>
    res=opt.minimize(fun=cost,x0=theta,args=(X,y),method='TNC',jac=gradient)
  File "/anaconda2/lib/python2.7/site-packages/scipy/optimize/_minimize.py", line 490, in minimize
    **options)
  File "/anaconda2/lib/python2.7/site-packages/scipy/optimize/tnc.py", line 409, in _minimize_tnc
    xtol, pgtol, rescale, callback)
  File "/anaconda2/lib/python2.7/site-packages/scipy/optimize/tnc.py", line 371, in func_and_grad
    f = fun(x, *args)
  File "LogisReg.py", line 44, in cost
    hypothesis=sigmoid(theta,X)
  File "LogisReg.py", line 37, in sigmoid
    z=X*theta.T
  File "/anaconda2/lib/python2.7/site-packages/numpy/matrixlib/defmatrix.py", line 309, in __mul__
    return N.dot(self, asmatrix(other))
ValueError: shapes (100,3) and (1,3) not aligned: 3 (dim 1) != 1 (dim 0)

Thanks for your help, everyone! I'm about to cry.....

Posted 2019-3-13 23:12:09
Your theta ends up as a matrix of shape (3,1),
so multiplying X by its transpose gives mismatched dimensions.
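
A minimal sketch of the mismatch, assuming the optimizer hands cost a plain 1-D array of shape (3,) rather than the (1,3) matrix built in the script. The all-ones X below is a stand-in, not the real data. Transposing a 1-D array leaves it unchanged, and np.matrix multiplication then promotes it to a (1,3) row, which matches the second shape reported in the ValueError:

```python
import numpy as np

X = np.asmatrix(np.ones((100, 3)))        # stand-in for the (100, 3) design matrix
theta_2d = np.asmatrix(np.zeros((1, 3)))  # what the script builds by hand
theta_1d = np.zeros(3)                    # assumed shape of what the optimizer passes in

print((X * theta_2d.T).shape)             # (100, 1): the direct call works
print(theta_1d.T.shape)                   # (3,): .T does nothing to a 1-D array
print(np.asmatrix(theta_1d.T).shape)      # (1, 3): how np.matrix sees it when multiplying

raised = False
try:
    X * theta_1d.T                        # (100, 3) times (1, 3)
except ValueError as exc:
    raised = True
    print("ValueError:", exc)
```

Either way, X*theta.T only lines up when theta really is a (1,3) matrix at the moment sigmoid runs.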

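OP | Posted 2019-3-13 23:26:16
塔利班 wrote on 2019-3-13 23:12:
Your theta ends up as a matrix of shape (3,1),
so multiplying X by its transpose gives mismatched dimensions.

Do you mean my theta should have shape (3,1) rather than (1,3)? My matrix multiplications should already line up; otherwise the cost and gradient functions would not have run successfully.

If you're interested, take a look at https://blog.csdn.net/Cowry5/art ... 247569#commentsedit (he uses ndarray for all of his data instead of matrix, and doesn't hit this problem when running optimize). That is the part I don't understand.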
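
One way to check what the optimizer really passes in is to record the shape from inside the objective. The sketch below uses a hypothetical stand-in objective, quadratic, rather than the cost function from the thread, and a 1-D x0, since newer SciPy releases may refuse a 2-D x0 altogether (that last point is an assumption about versions, not something shown in this thread):

```python
import numpy as np
import scipy.optimize as opt

seen_shapes = []  # shapes of every theta the optimizer hands to the objective

def quadratic(theta):
    # Hypothetical objective with its minimum at theta == [1, 1, 1].
    seen_shapes.append(np.shape(theta))
    return float(np.sum((np.asarray(theta) - 1.0) ** 2))

res = opt.minimize(quadratic, x0=np.zeros(3), method='TNC')
print(set(seen_shapes))
```

If the only shape recorded is 1-D, then a cost function that worked when called by hand with a (1,3) matrix is being called with something different during optimization, which would explain why the direct calls succeed and minimize does not.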

Posted 2019-3-14 12:43:08 | Best answer
def minimize(fun, x0, args=(), method=None, jac=None, hess=None,
             hessp=None, bounds=None, constraints=(), tol=None,
             callback=None, options=None):
    """Minimization of scalar function of one or more variables.

    Parameters
    ----------
    fun : callable
        The objective function to be minimized.

            ``fun(x, *args) -> float``

        where x is an 1-D array with shape (n,) and `args`
        is a tuple of the fixed parameters needed to completely
        specify the function.
    x0 : ndarray, shape (n,)           <-- this is where your (1,3) gets turned into (3,1)
        Initial guess. Array of real elements of size (n,),
        where 'n' is the number of independent variables.
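
Given the shape (n,) in the docstring above, one possible rewrite keeps everything as plain ndarrays so that theta can stay 1-D throughout. This is a sketch, not the poster's code: it uses only the first twelve rows of the data posted earlier (rounded to two decimals), swaps in scipy.special.expit for the hand-written sigmoid, and clips the hypothesis to avoid log(0); all of those choices are assumptions made for illustration.

```python
import numpy as np
import scipy.optimize as opt
from scipy.special import expit  # numerically stable sigmoid

# First twelve rows of the posted data, rounded (Exam1, Exam2, Class).
rows = np.array([
    [34.62, 78.02, 0], [30.29, 43.89, 0], [35.85, 72.90, 0],
    [60.18, 86.31, 1], [79.03, 75.34, 1], [45.08, 56.32, 0],
    [61.11, 96.51, 1], [75.02, 46.55, 1], [76.10, 87.42, 1],
    [84.43, 43.53, 1], [95.86, 38.23, 0], [75.01, 30.60, 0],
])
X = np.column_stack([np.ones(len(rows)), rows[:, :2]])  # shape (12, 3)
y = rows[:, 2]                                          # shape (12,)

def cost(theta, X, y):
    # theta has shape (3,); X @ theta has shape (12,)
    h = np.clip(expit(X @ theta), 1e-12, 1 - 1e-12)
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

def gradient(theta, X, y):
    # Returns shape (3,), the same shape as theta.
    return X.T @ (expit(X @ theta) - y) / len(y)

theta0 = np.zeros(3)
res = opt.minimize(fun=cost, x0=theta0, args=(X, y), method='TNC', jac=gradient)
print(res.x.shape, res.fun)
```

With 1-D arrays there is no transpose to get wrong: X @ theta and X.T @ error have the same shapes whether the functions are called by hand or by the optimizer.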

OP | Posted 2019-3-15 00:13:29
塔利班 wrote on 2019-3-14 12:43:
def minimize(fun, x0, args=(), method=None, jac=None, hess=None,
             hessp=None, bounds=No ...

Thanks 🙏. So if I want to use scipy's optimizers, I need to set everything up as ndarray from the start. I'll have to rewrite it from scratch.
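
A full rewrite may not be strictly necessary. As a sketch under the assumption that the optimizer passes theta as a 1-D array, the matrix-based functions can normalize theta on entry and return the gradient as a flat ndarray. The four-row X and y below are rounded rows from the posted data, used only as a stand-in. On question (2): NumPy's own documentation no longer recommends np.matrix, so plain ndarrays are generally the safer choice.

```python
import numpy as np

def sigmoid(theta, X):
    z = X * theta.T                              # (m, n) * (n, 1) -> (m, 1)
    return 1.0 / (1.0 + np.exp(-z))

def cost(theta, X, y):
    theta = np.asmatrix(theta).reshape(1, -1)    # accept (n,) or (1, n)
    h = sigmoid(theta, X)
    a = np.multiply(y, np.log(h)) + np.multiply(1 - y, np.log(1 - h))
    return float(-np.sum(a) / len(X))

def gradient(theta, X, y):
    theta = np.asmatrix(theta).reshape(1, -1)    # accept (n,) or (1, n)
    error = sigmoid(theta, X) - y                # (m, 1)
    grad = error.T * X / len(X)                  # (1, n) matrix
    return np.asarray(grad).ravel()              # plain (n,) array for the optimizer

X = np.asmatrix([[1, 34.62, 78.02],
                 [1, 60.18, 86.31],
                 [1, 79.03, 75.34],
                 [1, 45.08, 56.32]])
y = np.asmatrix([[0], [1], [1], [0]])

print(cost(np.zeros(3), X, y))                   # same value for 1-D and (1, 3) input
print(gradient(np.zeros(3), X, y).shape)         # (3,)
```

The reshape(1, -1) call is what makes both calling styles agree; the rest of the matrix arithmetic is unchanged from the original post.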