鱼C论坛

 找回密码
 立即注册
查看: 1954|回复: 1

[技术交流] python实现SVM【软间隔】【SMO算法】

[复制链接]
发表于 2020-11-10 16:34:21 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2020-11-10 16:39 编辑

参考书籍:《机器学习实战》

  1. import numpy as np
  2. from numpy import random

  3. def loadDataSet(fileName):
  4.     dataMat = []
  5.     labelMat = []
  6.     fr = open(fileName)
  7.     for line in fr.readlines():
  8.         lineArr = line.strip().split('\t')
  9.         dataMat.append([float(lineArr[0]), float(lineArr[1])])
  10.         labelMat.append(float(lineArr[2]))
  11.     return dataMat, labelMat

  12. def selectJrand(i, m):
  13.     j = i
  14.     while(j == i):
  15.         j = int(random.uniform(0, m))
  16.     return j

  17. def clipAlpha(aj, H, L):
  18.     if aj > H:
  19.         aj = H
  20.     if L > aj:
  21.         aj = L
  22.     return aj


  23. def smoSimple(dataMatIn, classLabels, C, toler, maxIter):
  24.     dataMatrix = np.mat(dataMatIn)
  25.     labelMat = np.mat(classLabels).transpose()
  26.     b = 0; m,n = np.shape(dataMatrix)
  27.     alphas = np.mat(np.zeros((m,1)))
  28.     iter = 0
  29.     while (iter < maxIter):
  30.         alphaPairsChanged = 0
  31.         for i in range(m):
  32.             fXi = float(np.multiply(alphas,labelMat).T*(dataMatrix*dataMatrix[i,:].T)) + b
  33.             Ei = fXi - float(labelMat[i])
  34.             if ((labelMat[i]*Ei < -toler) and (alphas[i] < C)) or ((labelMat[i]*Ei > toler) and (alphas[i] > 0)):
  35.                 j = selectJrand(i,m)
  36.                 fXj = float(np.multiply(alphas,labelMat).T*(dataMatrix*dataMatrix[j,:].T)) + b
  37.                 Ej = fXj - float(labelMat[j])
  38.                 alphaIold = alphas[i].copy(); alphaJold = alphas[j].copy();
  39.                 if (labelMat[i] != labelMat[j]):
  40.                     L = max(0, alphas[j] - alphas[i])
  41.                     H = min(C, C + alphas[j] - alphas[i])
  42.                 else:
  43.                     L = max(0, alphas[j] + alphas[i] - C)
  44.                     H = min(C, alphas[j] + alphas[i])
  45.                 if L==H: print("L==H"); continue
  46.                 eta = 2.0 * dataMatrix[i,:]*dataMatrix[j,:].T - dataMatrix[i,:]*dataMatrix[i,:].T - dataMatrix[j,:]*dataMatrix[j,:].T
  47.                 if eta >= 0:
  48.                     print("eta>=0"); continue
  49.                 alphas[j] -= labelMat[j]*(Ei - Ej)/eta
  50.                 alphas[j] = clipAlpha(alphas[j],H,L)
  51.                 if (abs(alphas[j] - alphaJold) < 0.00001):
  52.                     print("j not moving enough"); continue
  53.                 alphas[i] += labelMat[j]*labelMat[i]*(alphaJold - alphas[j])
  54.                 b1 = b - Ei- labelMat[i]*(alphas[i]-alphaIold)*dataMatrix[i,:]*dataMatrix[i,:].T - labelMat[j]*(alphas[j]-alphaJold)*dataMatrix[i,:]*dataMatrix[j,:].T
  55.                 b2 = b - Ej- labelMat[i]*(alphas[i]-alphaIold)*dataMatrix[i,:]*dataMatrix[j,:].T - labelMat[j]*(alphas[j]-alphaJold)*dataMatrix[j,:]*dataMatrix[j,:].T
  56.                 if (0 < alphas[i]) and (C > alphas[i]):
  57.                     b = b1
  58.                 elif (0 < alphas[j]) and (C > alphas[j]):
  59.                     b = b2
  60.                 else:
  61.                     b = (b1 + b2)/2.0
  62.                 alphaPairsChanged += 1
  63.                 print("iter: %d i:%d, pairs changed %d" % (iter,i,alphaPairsChanged))
  64.         if(alphaPairsChanged == 0):
  65.             iter += 1
  66.         else:
  67.             iter = 0
  68.         print("iteration number: %d" % iter)
  69.     return b,alphas


  70. def calcWs(alphas,dataArr,classLabels):
  71.     X = np.mat(dataArr); labelMat = np.mat(classLabels).transpose()
  72.     m, n = np.shape(X)
  73.     w = np.zeros((n,1))
  74.     for i in range(m):
  75.         w += np.multiply(alphas[i]*labelMat[i],X[i,:].T)
  76.     return w


  77. dataArr, labelArr = loadDataSet(r"C:\...\testSet.txt")         
  78. b, alphas = smoSimple(dataArr, labelArr, 0.6, 0.001, 40)   
  79. ws = calcWs(alphas, dataArr, labelArr)   



  80. import seaborn as sns
  81. import pandas as pd
  82. import matplotlib.pyplot as plt
  83. temp = pd.DataFrame(dataArr)
  84. temp.columns = ["1", "2"]
  85. temp["label"] = pd.array(labelArr)
  86. temp["label"] = np.array(temp["label"]).astype(np.int)
  87. xx = np.linspace(0, 10, 20)
  88. yy = (-b - xx * ws[0]) / ws[1]
  89. temp1 = pd.DataFrame()
  90. temp1["xx"] = np.array(xx)
  91. temp1["yy"] = np.array(yy.T)
  92. sns.scatterplot(data = temp, x = "1", y = "2", hue = "label")
  93. plt.plot(temp1['xx'], temp1['yy'])               
  94.                
复制代码

本帖被以下淘专辑推荐:

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-11-10 16:35:53 | 显示全部楼层
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-5-20 07:24

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表