鱼C论坛

 找回密码
 立即注册
查看: 921|回复: 0

C++实现SVM

[复制链接]
发表于 2019-7-31 22:47:59 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
在机器学习中,SVM相对来说是比较难的了,所以我对SVM也是琢磨了很久,所以利用C++来实现SVM,来加深自己对SVM的理解。
  1. #include <iostream>
  2. #include <cstdio>
  3. #include <algorithm>
  4. #include <cmath>
  5. #include<string>
  6. #include<fstream>
  7. #include<sstream>

  8. using std::sort;
  9. using std::fabs;
  10. using namespace std;

  11. const int MAX_DIMENSION = 3;
  12. const int MAX_SAMPLES = 306;
  13. double x[MAX_SAMPLES][MAX_DIMENSION];
  14. double y[MAX_SAMPLES];
  15. double alpha[MAX_SAMPLES];
  16. double w[MAX_DIMENSION];
  17. double b;
  18. double c;
  19. double eps = 1e-6;


  20. int num_samples = 306;
  21. int num_dimension = 3;

  22. struct _E {
  23.         double val;
  24.         int index;
  25. }E[MAX_SAMPLES];

  26. bool cmp(const _E & a, const _E & b)
  27. {
  28.         return a.val < b.val;
  29. }



  30. double max(double a, double b)
  31. {
  32.         return a > b ? a : b;
  33. }

  34. double min(double a, double b)
  35. {
  36.         return a > b ? b : a;
  37. }

  38. double kernal(double x1[], double x2[], double dimension)
  39. {
  40.         double ans = 0;
  41.         for (int i = 0; i < dimension; i++)
  42.         {
  43.                 ans += x1[i] * x2[i];
  44.         }
  45.         return ans;
  46. }

  47. double target_function()
  48. {
  49.         double ans = 0;
  50.         for (int i = 0; i < num_samples; i++)
  51.         {
  52.                 for (int j = 0; j < num_samples; j++)
  53.                 {
  54.                         ans += alpha[i] * alpha[j] * y[i] * y[j] * kernal(x[i], x[j], num_dimension);
  55.                 }
  56.         }

  57.         for (int i = 0; i < num_samples; i++)
  58.         {
  59.                 ans -= alpha[i];
  60.         }

  61.         return ans;
  62. }


  63. double g(double _x[], int dimension)
  64. {
  65.         double ans = b;

  66.         for (int i = 0; i < num_samples; i++)
  67.         {
  68.                 ans += alpha[i] * y[i] * kernal(x[i], _x, dimension);
  69.         }

  70.         return ans;
  71. }

  72. bool satisfy_constrains(int i, int dimension)
  73. {
  74.         if (alpha[i] == 0)
  75.         {
  76.                 if (y[i] * g(x[i], dimension) >= 1)
  77.                         return true;
  78.                 else
  79.                         return false;
  80.         }
  81.         else if (alpha[i] > 0 && alpha[i] < c)
  82.         {
  83.                 if (y[i] * g(x[i], dimension) == 1)
  84.                         return true;
  85.                 else
  86.                         return false;
  87.         }
  88.         else
  89.         {
  90.                 if (y[i] * g(x[i], dimension) <= 1)
  91.                         return true;
  92.                 else
  93.                         return false;
  94.         }
  95. }


  96. double calE(int i, int dimension)
  97. {
  98.         return g(x[i], dimension) - y[i];
  99. }

  100. void calW()
  101. {
  102.         for (int i = 0; i < num_dimension; i++)
  103.         {
  104.                 w[i] = 0;
  105.                 for (int j = 0; j < num_samples; j++)
  106.                 {
  107.                         w[i] += alpha[j] * y[j] * x[j][i];
  108.                 }
  109.         }
  110.         return;
  111. }

  112. void calB()
  113. {
  114.         double ans = y[0];
  115.         for (int i = 0; i < num_samples; i++)
  116.         {
  117.                 ans -= y[i] * alpha[i] * kernal(x[i], x[0], num_dimension);
  118.         }
  119.         b = ans;
  120.         return;
  121. }


  122. void recalB(int alpha1index, int alpha2index, int dimension, double alpha1old, double alpha2old)
  123. {
  124.         double alpha1new = alpha[alpha1index];
  125.         double alpha2new = alpha[alpha2index];

  126.         alpha[alpha1index] = alpha1old;
  127.         alpha[alpha2index] = alpha2old;

  128.         double e1 = calE(alpha1index, num_dimension);
  129.         double e2 = calE(alpha2index, num_dimension);

  130.         alpha[alpha1index] = alpha1new;
  131.         alpha[alpha2index] = alpha2new;

  132.         double b1new = -e1 - y[alpha1index] * kernal(x[alpha1index], x[alpha1index], dimension)*(alpha1new - alpha1old);
  133.         b1new -= y[alpha2index] * kernal(x[alpha2index], x[alpha1index], dimension)*(alpha2new - alpha2old) + b;

  134.         double b2new = -e2 - y[alpha1index] * kernal(x[alpha1index], x[alpha2index], dimension)*(alpha1new - alpha1old);
  135.         b1new -= y[alpha2index] * kernal(x[alpha2index], x[alpha2index], dimension)*(alpha2new - alpha2old) + b;

  136.         b = (b1new + b2new) / 2;
  137. }

  138. bool optimizehelp(int alpha1index, int alpha2index)
  139. {
  140.         double alpha1new = alpha[alpha1index];
  141.         double alpha2new = alpha[alpha2index];

  142.         double alpha1old = alpha[alpha1index];
  143.         double alpha2old = alpha[alpha2index];

  144.         double H, L;

  145.         if (fabs(y[alpha1index] - y[alpha2index]) > eps)
  146.         {
  147.                 L = max(0, alpha2old - alpha1old);
  148.                 H = min(c, c + alpha2old - alpha1old);
  149.         }
  150.         else
  151.         {
  152.                 L = max(0, alpha2old + alpha1old - c);
  153.                 H = min(c, alpha2old + alpha1old);
  154.         }

  155.         //cal new
  156.         double lena = kernal(x[alpha1index], x[alpha1index], num_dimension) + kernal(x[alpha2index], x[alpha2index], num_dimension) - 2 * kernal(x[alpha1index], x[alpha2index], num_dimension);
  157.         alpha2new = alpha2old + y[alpha2index] * (calE(alpha1index, num_dimension) - calE(alpha2index, num_dimension)) / lena;

  158.         if (alpha2new > H)
  159.         {
  160.                 alpha2new = H;
  161.         }
  162.         else if (alpha2new < L)
  163.         {
  164.                 alpha2new = L;
  165.         }

  166.         alpha1new = alpha1old + y[alpha1index] * y[alpha2index] * (alpha2old - alpha2new);

  167.         double energyold = target_function();

  168.         alpha[alpha1index] = alpha1new;
  169.         alpha[alpha2index] = alpha2new;

  170.         double gap = 0.001;

  171.         recalB(alpha1index, alpha2index, num_dimension, alpha1old, alpha2old);
  172.         return true;
  173. }

  174. bool optimize()
  175. {
  176.         int alpha1index = -1;
  177.         int alpha2index = -1;
  178.         double alpha2new = 0;
  179.         double alpha1new = 0;

  180.         //cal E[]
  181.         for (int i = 0; i < num_samples; i++)
  182.         {
  183.                 E[i].val = calE(i, num_dimension);
  184.                 E[i].index = i;
  185.         }

  186.         //traverse the alpha1index with 0 < && < c
  187.         for (int i = 0; i < num_samples; i++)
  188.         {
  189.                 alpha1new = alpha[i];

  190.                 if (alpha1new > 0 && alpha1new < c)
  191.                 {

  192.                         if (satisfy_constrains(i, num_dimension))
  193.                                 continue;

  194.                         sort(E, E + num_samples, cmp);

  195.                         //simply find the maximum or minimun;
  196.                         if (alpha1new > 0)
  197.                         {
  198.                                 if (E[0].index == i)
  199.                                 {
  200.                                         ;
  201.                                 }
  202.                                 else
  203.                                 {
  204.                                         alpha1index = i;
  205.                                         alpha2index = E[0].index;
  206.                                         if (optimizehelp(alpha1index, alpha2index))
  207.                                         {
  208.                                                 return true;
  209.                                         }
  210.                                 }
  211.                         }
  212.                         else
  213.                         {
  214.                                 if (E[num_samples - 1].index == i)
  215.                                 {
  216.                                         ;
  217.                                 }
  218.                                 else
  219.                                 {
  220.                                         alpha1index = i;
  221.                                         alpha2index = E[num_samples - 1].index;
  222.                                         if (optimizehelp(alpha1index, alpha2index))
  223.                                         {
  224.                                                 return true;
  225.                                         }
  226.                                 }
  227.                         }


  228.                         //find the alpha2 > 0 && < c
  229.                         for (int j = 0; j < num_samples; j++)
  230.                         {
  231.                                 alpha2new = alpha[j];

  232.                                 if (alpha2new > 0 && alpha2new < c)
  233.                                 {
  234.                                         alpha1index = i;
  235.                                         alpha2index = j;
  236.                                         if (optimizehelp(alpha1index, alpha2index))
  237.                                         {
  238.                                                 return true;
  239.                                         }
  240.                                 }
  241.                         }

  242.                         //find other alpha2
  243.                         for (int j = 0; j < num_samples; j++)
  244.                         {
  245.                                 alpha2new = alpha[j];

  246.                                 if (!(alpha2new > 0 && alpha2new < c))
  247.                                 {
  248.                                         alpha1index = i;
  249.                                         alpha2index = j;
  250.                                         if (optimizehelp(alpha1index, alpha2index))
  251.                                         {
  252.                                                 return true;
  253.                                         }
  254.                                 }
  255.                         }
  256.                 }
  257.         }

  258.         //find all alpha1
  259.         for (int i = 0; i < num_samples; i++)
  260.         {
  261.                 alpha1new = alpha[i];

  262.                 if (!(alpha1new > 0 && alpha1new < c))
  263.                 {
  264.                         if (satisfy_constrains(i, num_dimension))
  265.                                 continue;

  266.                         sort(E, E + num_samples, cmp);

  267.                         //simply find the maximum or minimun;
  268.                         if (alpha1new > 0)
  269.                         {
  270.                                 if (E[0].index == i)
  271.                                 {
  272.                                         ;
  273.                                 }
  274.                                 else
  275.                                 {
  276.                                         alpha1index = i;
  277.                                         alpha2index = E[0].index;
  278.                                         if (optimizehelp(alpha1index, alpha2index))
  279.                                         {
  280.                                                 return true;
  281.                                         }
  282.                                 }
  283.                         }
  284.                         else
  285.                         {
  286.                                 if (E[num_samples - 1].index == i)
  287.                                 {
  288.                                         ;
  289.                                 }
  290.                                 else
  291.                                 {
  292.                                         alpha1index = i;
  293.                                         alpha2index = E[num_samples - 1].index;
  294.                                         if (optimizehelp(alpha1index, alpha2index))
  295.                                         {
  296.                                                 return true;
  297.                                         }
  298.                                 }
  299.                         }


  300.                         //find the alpha2 > 0 && < c
  301.                         for (int j = 0; j < num_samples; j++)
  302.                         {
  303.                                 alpha2new = alpha[j];

  304.                                 if (alpha2new > 0 && alpha2new < c)
  305.                                 {
  306.                                         alpha1index = i;
  307.                                         alpha2index = j;
  308.                                         if (optimizehelp(alpha1index, alpha2index))
  309.                                         {
  310.                                                 return true;
  311.                                         }
  312.                                 }
  313.                         }

  314.                         //find other alpha2
  315.                         for (int j = 0; j < num_samples; j++)
  316.                         {
  317.                                 alpha2new = alpha[j];

  318.                                 if (!(alpha2new > 0 && alpha2new < c))
  319.                                 {
  320.                                         alpha1index = i;
  321.                                         alpha2index = j;
  322.                                         if (optimizehelp(alpha1index, alpha2index))
  323.                                         {
  324.                                                 return true;
  325.                                         }
  326.                                 }
  327.                         }
  328.                 }
  329.         }

  330.         //for(int i = 0 ; i < num_samples; i++)
  331.         //{
  332.         //    alpha1new = alpha[i];

  333.         //    for(int j = 0 ; j < num_samples; j++)
  334.         //    {
  335.         //        if(1)
  336.         //        {
  337.         //            alpha1index = i;
  338.         //            alpha2index = j;
  339.         //            if(optimizehelp(alpha1index , alpha2index))
  340.         //            {
  341.         //                return true;
  342.         //            }
  343.         //        }
  344.         //    }
  345.         //}
  346.         return false;
  347. }

  348. bool check()
  349. {
  350.         double sum = 0;
  351.         for (int i = 0; i < num_samples; i++)
  352.         {
  353.                 sum += alpha[i] * y[i];
  354.                 if (!(0 <= alpha[i] && alpha[i] <= c))
  355.                 {
  356.                         printf("alpha[%d]: %lf wrong\n", i, alpha[i]);
  357.                         return false;
  358.                 }
  359.                 if (!satisfy_constrains(i, num_dimension))
  360.                 {
  361.                         printf("alpha[%d] not satisfy constrains\n", i);
  362.                         return false;
  363.                 }
  364.         }

  365.         if (fabs(sum) > eps)
  366.         {
  367.                 printf("Sum = %lf\n", sum);
  368.                 return false;
  369.         }
  370.         return true;
  371. }


  372. int toNum(string str)//Enclave无法接受string类型数据
  373. {
  374.         int ans = 0;
  375.         for (int i = 0; i < str.length(); i++)
  376.         {
  377.                 ans = ans * 10 + (str[i] - '0');
  378.         }
  379.         return ans;
  380. }

  381. void loaddata(string path)
  382. {
  383.         ifstream Filein;
  384.         try { Filein.open(path); }
  385.         catch (exception e)
  386.         {
  387.                 cout << "File open failed!";
  388.         }

  389.         string line;
  390.         int data_num = 0;
  391.         while (getline(Filein, line)) {
  392.                 int before = 0;
  393.                 int cnt = 0;
  394.                 data_num++;
  395.                 //cout << data_num << endl;
  396.                 for (unsigned int i = 0; i < line.length(); i++) {
  397.                         if (line[i] == ',' || line[i] == '\n') {
  398.                                 string sub = line.substr(before, i - before);
  399.                                 before = i + 1;
  400.                                 x[data_num - 1][cnt] = toNum(sub);
  401.                                 cnt++;
  402.                         }
  403.                 }
  404.                 //Data[data_num - 1][cnt] = toNum(line.substr(before, line.length()));
  405.                 y[data_num - 1] = toNum(line.substr(before, line.length()));
  406.         }
  407.         cout << "data loading done.\nthe amount of data is: " << data_num << endl;
  408. }

  409. int main()
  410. {
  411.         /*scanf_s("%d%d", &num_samples, &num_dimension);

  412.         for (int i = 0; i < num_samples; i++)
  413.         {
  414.                 for (int j = 0; j < num_dimension; j++)
  415.                 {
  416.                         scanf_s("%lf", &x[i][j]);
  417.                 }
  418.                 scanf_s("%lf", &y[i]);
  419.         }*/

  420.         loaddata("C:\\Users\\YY\\Desktop\\haberman1.txt");//获取数据集,并存于Data数组

  421.         c = 1;

  422.         //初值附为0;
  423.         for (int i = 0; i < num_samples; i++)
  424.         {
  425.                 alpha[i] = 0;
  426.         }

  427.         int count = 0;
  428.         while (optimize()) {
  429.                 calB();
  430.                 count++;
  431.         }
  432.         printf("%d ", count);

  433.         calW();
  434.         calB();

  435.         printf("y = ");

  436.         for (int i = 0; i < num_dimension; i++)
  437.         {
  438.                 printf("%lf * x[%d] + ", w[i], i);
  439.         }

  440.         //计算精度
  441.         int countt = 0;
  442.         for (int i = 0; i < num_samples; i++)
  443.         {
  444.                 double Y = 0;
  445.                 for (int j = 0; j < num_dimension; j++)
  446.                 {
  447.                         Y += x[i][j] * w[j];
  448.                 }
  449.                 if (Y + b == y[i])
  450.                 {
  451.                         countt++;
  452.                 }
  453.         }
  454.         printf("%lf\n", b);
  455.         double pro = double(countt) / double(num_samples);
  456.         printf("%f\n ", pro);

  457.         if (!check())
  458.                 printf("Not satisfy KKT.\n");
  459.         else
  460.                 printf("Satisfy KKT\n");

  461.         system("pause");
  462.         return 0;
  463. }
复制代码


haberman1.zip

984 Bytes, 下载次数: 0

售价: 50 鱼币  [记录]  [购买]

数据集

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-23 19:53

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表