Julia999 发表于 2019-7-31 22:47:59

C++实现SVM

在机器学习中,SVM相对来说是比较难的了,所以我对SVM也是琢磨了很久,所以利用C++来实现SVM,来加深自己对SVM的理解。#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include<string>
#include<fstream>
#include<sstream>

using std::sort;
using std::fabs;
using namespace std;

const int MAX_DIMENSION = 3;
const int MAX_SAMPLES = 306;
double x;
double y;
double alpha;
double w;
double b;
double c;
double eps = 1e-6;


int num_samples = 306;
int num_dimension = 3;

struct _E {
      double val;
      int index;
}E;

bool cmp(const _E & a, const _E & b)
{
      return a.val < b.val;
}



double max(double a, double b)
{
      return a > b ? a : b;
}

double min(double a, double b)
{
      return a > b ? b : a;
}

double kernal(double x1[], double x2[], double dimension)
{
      double ans = 0;
      for (int i = 0; i < dimension; i++)
      {
                ans += x1 * x2;
      }
      return ans;
}

double target_function()
{
      double ans = 0;
      for (int i = 0; i < num_samples; i++)
      {
                for (int j = 0; j < num_samples; j++)
                {
                        ans += alpha * alpha * y * y * kernal(x, x, num_dimension);
                }
      }

      for (int i = 0; i < num_samples; i++)
      {
                ans -= alpha;
      }

      return ans;
}


double g(double _x[], int dimension)
{
      double ans = b;

      for (int i = 0; i < num_samples; i++)
      {
                ans += alpha * y * kernal(x, _x, dimension);
      }

      return ans;
}

bool satisfy_constrains(int i, int dimension)
{
      if (alpha == 0)
      {
                if (y * g(x, dimension) >= 1)
                        return true;
                else
                        return false;
      }
      else if (alpha > 0 && alpha < c)
      {
                if (y * g(x, dimension) == 1)
                        return true;
                else
                        return false;
      }
      else
      {
                if (y * g(x, dimension) <= 1)
                        return true;
                else
                        return false;
      }
}


double calE(int i, int dimension)
{
      return g(x, dimension) - y;
}

void calW()
{
      for (int i = 0; i < num_dimension; i++)
      {
                w = 0;
                for (int j = 0; j < num_samples; j++)
                {
                        w += alpha * y * x;
                }
      }
      return;
}

void calB()
{
      double ans = y;
      for (int i = 0; i < num_samples; i++)
      {
                ans -= y * alpha * kernal(x, x, num_dimension);
      }
      b = ans;
      return;
}


void recalB(int alpha1index, int alpha2index, int dimension, double alpha1old, double alpha2old)
{
      double alpha1new = alpha;
      double alpha2new = alpha;

      alpha = alpha1old;
      alpha = alpha2old;

      double e1 = calE(alpha1index, num_dimension);
      double e2 = calE(alpha2index, num_dimension);

      alpha = alpha1new;
      alpha = alpha2new;

      double b1new = -e1 - y * kernal(x, x, dimension)*(alpha1new - alpha1old);
      b1new -= y * kernal(x, x, dimension)*(alpha2new - alpha2old) + b;

      double b2new = -e2 - y * kernal(x, x, dimension)*(alpha1new - alpha1old);
      b1new -= y * kernal(x, x, dimension)*(alpha2new - alpha2old) + b;

      b = (b1new + b2new) / 2;
}

bool optimizehelp(int alpha1index, int alpha2index)
{
      double alpha1new = alpha;
      double alpha2new = alpha;

      double alpha1old = alpha;
      double alpha2old = alpha;

      double H, L;

      if (fabs(y - y) > eps)
      {
                L = max(0, alpha2old - alpha1old);
                H = min(c, c + alpha2old - alpha1old);
      }
      else
      {
                L = max(0, alpha2old + alpha1old - c);
                H = min(c, alpha2old + alpha1old);
      }

      //cal new
      double lena = kernal(x, x, num_dimension) + kernal(x, x, num_dimension) - 2 * kernal(x, x, num_dimension);
      alpha2new = alpha2old + y * (calE(alpha1index, num_dimension) - calE(alpha2index, num_dimension)) / lena;

      if (alpha2new > H)
      {
                alpha2new = H;
      }
      else if (alpha2new < L)
      {
                alpha2new = L;
      }

      alpha1new = alpha1old + y * y * (alpha2old - alpha2new);

      double energyold = target_function();

      alpha = alpha1new;
      alpha = alpha2new;

      double gap = 0.001;

      recalB(alpha1index, alpha2index, num_dimension, alpha1old, alpha2old);
      return true;
}

bool optimize()
{
      int alpha1index = -1;
      int alpha2index = -1;
      double alpha2new = 0;
      double alpha1new = 0;

      //cal E[]
      for (int i = 0; i < num_samples; i++)
      {
                E.val = calE(i, num_dimension);
                E.index = i;
      }

      //traverse the alpha1index with 0 < && < c
      for (int i = 0; i < num_samples; i++)
      {
                alpha1new = alpha;

                if (alpha1new > 0 && alpha1new < c)
                {

                        if (satisfy_constrains(i, num_dimension))
                              continue;

                        sort(E, E + num_samples, cmp);

                        //simply find the maximum or minimun;
                        if (alpha1new > 0)
                        {
                              if (E.index == i)
                              {
                                        ;
                              }
                              else
                              {
                                        alpha1index = i;
                                        alpha2index = E.index;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }
                        else
                        {
                              if (E.index == i)
                              {
                                        ;
                              }
                              else
                              {
                                        alpha1index = i;
                                        alpha2index = E.index;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }


                        //find the alpha2 > 0 && < c
                        for (int j = 0; j < num_samples; j++)
                        {
                              alpha2new = alpha;

                              if (alpha2new > 0 && alpha2new < c)
                              {
                                        alpha1index = i;
                                        alpha2index = j;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }

                        //find other alpha2
                        for (int j = 0; j < num_samples; j++)
                        {
                              alpha2new = alpha;

                              if (!(alpha2new > 0 && alpha2new < c))
                              {
                                        alpha1index = i;
                                        alpha2index = j;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }
                }
      }

      //find all alpha1
      for (int i = 0; i < num_samples; i++)
      {
                alpha1new = alpha;

                if (!(alpha1new > 0 && alpha1new < c))
                {
                        if (satisfy_constrains(i, num_dimension))
                              continue;

                        sort(E, E + num_samples, cmp);

                        //simply find the maximum or minimun;
                        if (alpha1new > 0)
                        {
                              if (E.index == i)
                              {
                                        ;
                              }
                              else
                              {
                                        alpha1index = i;
                                        alpha2index = E.index;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }
                        else
                        {
                              if (E.index == i)
                              {
                                        ;
                              }
                              else
                              {
                                        alpha1index = i;
                                        alpha2index = E.index;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }


                        //find the alpha2 > 0 && < c
                        for (int j = 0; j < num_samples; j++)
                        {
                              alpha2new = alpha;

                              if (alpha2new > 0 && alpha2new < c)
                              {
                                        alpha1index = i;
                                        alpha2index = j;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }

                        //find other alpha2
                        for (int j = 0; j < num_samples; j++)
                        {
                              alpha2new = alpha;

                              if (!(alpha2new > 0 && alpha2new < c))
                              {
                                        alpha1index = i;
                                        alpha2index = j;
                                        if (optimizehelp(alpha1index, alpha2index))
                                        {
                                                return true;
                                        }
                              }
                        }
                }
      }

      //for(int i = 0 ; i < num_samples; i++)
      //{
      //    alpha1new = alpha;

      //    for(int j = 0 ; j < num_samples; j++)
      //    {
      //      if(1)
      //      {
      //            alpha1index = i;
      //            alpha2index = j;
      //            if(optimizehelp(alpha1index , alpha2index))
      //            {
      //                return true;
      //            }
      //      }
      //    }
      //}
      return false;
}

bool check()
{
      double sum = 0;
      for (int i = 0; i < num_samples; i++)
      {
                sum += alpha * y;
                if (!(0 <= alpha && alpha <= c))
                {
                        printf("alpha[%d]: %lf wrong\n", i, alpha);
                        return false;
                }
                if (!satisfy_constrains(i, num_dimension))
                {
                        printf("alpha[%d] not satisfy constrains\n", i);
                        return false;
                }
      }

      if (fabs(sum) > eps)
      {
                printf("Sum = %lf\n", sum);
                return false;
      }
      return true;
}


int toNum(string str)//Enclave无法接受string类型数据
{
      int ans = 0;
      for (int i = 0; i < str.length(); i++)
      {
                ans = ans * 10 + (str - '0');
      }
      return ans;
}

void loaddata(string path)
{
      ifstream Filein;
      try { Filein.open(path); }
      catch (exception e)
      {
                cout << "File open failed!";
      }

      string line;
      int data_num = 0;
      while (getline(Filein, line)) {
                int before = 0;
                int cnt = 0;
                data_num++;
                //cout << data_num << endl;
                for (unsigned int i = 0; i < line.length(); i++) {
                        if (line == ',' || line == '\n') {
                              string sub = line.substr(before, i - before);
                              before = i + 1;
                              x = toNum(sub);
                              cnt++;
                        }
                }
                //Data = toNum(line.substr(before, line.length()));
                y = toNum(line.substr(before, line.length()));
      }
      cout << "data loading done.\nthe amount of data is: " << data_num << endl;
}

int main()
{
      /*scanf_s("%d%d", &num_samples, &num_dimension);

      for (int i = 0; i < num_samples; i++)
      {
                for (int j = 0; j < num_dimension; j++)
                {
                        scanf_s("%lf", &x);
                }
                scanf_s("%lf", &y);
      }*/

      loaddata("C:\\Users\\YY\\Desktop\\haberman1.txt");//获取数据集,并存于Data数组

      c = 1;

      //初值附为0;
      for (int i = 0; i < num_samples; i++)
      {
                alpha = 0;
      }

      int count = 0;
      while (optimize()) {
                calB();
                count++;
      }
      printf("%d ", count);

      calW();
      calB();

      printf("y = ");

      for (int i = 0; i < num_dimension; i++)
      {
                printf("%lf * x[%d] + ", w, i);
      }

      //计算精度
      int countt = 0;
      for (int i = 0; i < num_samples; i++)
      {
                double Y = 0;
                for (int j = 0; j < num_dimension; j++)
                {
                        Y += x * w;
                }
                if (Y + b == y)
                {
                        countt++;
                }
      }
      printf("%lf\n", b);
      double pro = double(countt) / double(num_samples);
      printf("%f\n ", pro);

      if (!check())
                printf("Not satisfy KKT.\n");
      else
                printf("Satisfy KKT\n");

      system("pause");
      return 0;
}

页: [1]
查看完整版本: C++实现SVM