我真的不想秃头 发表于 2022-4-30 11:25:12

svm--李航统计学习方法

错误提示:

Traceback (most recent call last):
File "D:/code/pycharm/dingwei.py", line 400, in <module>
    svm.train()
File "D:/code/pycharm/dingwei.py", line 241, in train
    if self.isSatisfyKKT(i) is not True:
File "D:/code/pycharm/dingwei.py", line 112, in isSatisfyKKT
    if (math.fabs(self.alpha) < self.toler) and (yi * gxi >= 1):
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

错误相关代码:
def isSatisfyKKT(self, i):
      
      gxi =self.calc_gxi(i)
      yi = self.trainLabelMat

       if (math.fabs(self.alpha) < self.toler) and (yi * gxi >= 1):
            return True
      
      elif (math.fabs(self.alpha - self.C) < self.toler) and (yi * gxi <= 1):
            return True
      
      elif (self.alpha > -self.toler) and (self.alpha < (self.C + self.toler)) \
                and (math.fabs(yi * gxi - 1) < self.toler):
            return True
      else:
            return False


while (iterStep < iter) and (parameterChanged > 0):
            #打印当前迭代轮数
            print('iter:%d:%d'%( iterStep, iter))
            #迭代步数加1
            iterStep += 1
            #新的一轮将参数改变标志位重新置0
            parameterChanged = 0

            #大循环遍历所有样本,用于找SMO中第一个变量
            for i in range(self.m):
                #查看第一个遍历是否满足KKT条件,如果不满足则作为SMO中第一个变量从而进行优化

                   if self.isSatisfyKKT(i) is not True:
                  #如果下标为i的α不满足KKT条件,则进行优化

                  #第一个变量α的下标i已经确定,接下来按照“7.4.2 变量的选择方法”第二步
                  #选择变量2。由于变量2的选择中涉及到|E1 - E2|,因此先计算E1
                  E1 = self.calcEi(i)

                  #选择第2个变量
                  E2, j = self.getAlphaJ(E1, i)



if __name__ == '__main__':
    start = time.time()

    # 获取训练集及标签
    print('start read transSet')
    trainDataList, trainLabelList = loadData('train.txt')

    # 获取测试集及标签
    print('start read testSet')
    testDataList, testLabelList = loadData('test.txt')

    #初始化SVM类
    print('start init SVM')
    svm = SVM(trainDataList[:1000], trainLabelList[:1000], 10, 200, 0.001)

    # 开始训练
    print('start to train')
    svm.train()

    # 开始测试
    print('start to test')
    accuracy = svm.test(testDataList[:100, testLabelList[:100])
    print('the accuracy is:%d'%(accuracy * 100), '%')

    # 打印时间
    print('time span:', time.time() - start)



请问一下该怎么改

Twilight6 发表于 2022-4-30 11:39:55



报错意思是你将一个数值和多给数值一次性进行比较了,检查检查比较的是不是数组吧

amazed 发表于 2022-5-1 00:42:48

666666666666666666

kerln888 发表于 2022-5-1 09:06:25

来学习一下
页: [1]
查看完整版本: svm--李航统计学习方法