[技术交流] Implementing CART in Python


Reference: 《机器学习实战》 (Machine Learning in Action)

1. CART regression with pre-pruning, controlled by the tolS and tolN parameters (a short usage sketch follows the listing)
import numpy as np

def loadDataSet(fileName):
    # load a tab-delimited file in which every field is a float
    dataMat = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split on the given feature: rows with feature > value go left, the rest go right
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    # leaf value: mean of the target column
    return np.mean(dataSet[:, -1])

def regErr(dataSet):
    # total squared error: variance of the target times the number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]  # minimum decrease in total squared error required to accept a split
    tolN = ops[1]  # minimum number of samples allowed in each child after a split
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf node
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement below tolS: return a leaf node
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # build the tree recursively
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex00.txt")
    res = createTree(myDat)
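As a quick check on how the two pre-pruning parameters behave, here is a minimal sketch (assuming the myDat loaded in the listing above; the numLeaves helper and the ops values are illustrative additions, not part of the original code). A larger tolS or tolN should yield a smaller tree:

def numLeaves(tree):
    # count leaf nodes of a tree built by createTree (helper added for illustration)
    if not isinstance(tree, dict):
        return 1
    return numLeaves(tree['left']) + numLeaves(tree['right'])

looseTree = createTree(myDat, ops=(0, 1))     # almost no pre-pruning: many tiny leaves
strictTree = createTree(myDat, ops=(10, 20))  # larger tolS/tolN: far fewer splits
print(numLeaves(looseTree), numLeaves(strictTree))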





2. CART regression with post-pruning (a short usage sketch follows the listing)
import numpy as np

def loadDataSet(fileName):
    # load a tab-delimited file in which every field is a float
    dataMat = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split on the given feature: rows with feature > value go left, the rest go right
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    # leaf value: mean of the target column
    return np.mean(dataSet[:, -1])

def regErr(dataSet):
    # total squared error: variance of the target times the number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]  # minimum decrease in total squared error required to accept a split
    tolN = ops[1]  # minimum number of samples allowed in each child after a split
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf node
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement below tolS: return a leaf node
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # build the tree recursively
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

def isTree(obj):
    # a subtree is stored as a dict; a leaf is a plain value
    return type(obj).__name__ == 'dict'

def getMean(tree):
    # recursively collapse a subtree into the mean of its leaves
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    # post-prune: merge two children whenever merging lowers the error on the test data
    if np.shape(testData)[0] == 0: return getMean(tree)
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(np.power(lSet[:, -1] - tree['left'], 2)) + \
            sum(np.power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(np.power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        else:
            return tree
    return tree

if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex2.txt")
    myDatTest = loadDataSet(r"C:\...\ex2test.txt")
    myTree = createTree(myDat, ops=(0, 1))
    res = prune(myTree, myDatTest)
    # First grow the full tree on the training set, then post-prune it with a
    # validation set (the pruning step plays the role of hyperparameter tuning).
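To see how much the post-pruning step simplifies the tree, here is a minimal sketch (assuming the myDat and myDatTest loaded in the listing above, and reusing isTree from it; numLeaves is a helper added purely for illustration):

def numLeaves(tree):
    # count leaf nodes of a tree built by createTree (helper added for illustration)
    if not isTree(tree):
        return 1
    return numLeaves(tree['left']) + numLeaves(tree['right'])

fullTree = createTree(myDat, ops=(0, 1))   # grow the largest possible tree on the training set
print("leaves before pruning:", numLeaves(fullTree))
prunedTree = prune(fullTree, myDatTest)    # merge children whenever merging lowers validation error
print("leaves after pruning:", numLeaves(prunedTree))

Note that prune modifies the tree in place as it recurses, so fullTree should not be reused after the call.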






3. Model tree (a linear model at each leaf; a prediction sketch follows the listing)
import numpy as np

def loadDataSet(fileName):
    # load a tab-delimited file in which every field is a float
    dataMat = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curLine = line.strip().split('\t')
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split on the given feature: rows with feature > value go left, the rest go right
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    # leaf value: mean of the targets in this node
    return np.mean(dataSet[:, -1])

def regErr(dataSet):
    # total squared error: variance of the target times the number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]  # minimum decrease in total squared error required to accept a split
    tolN = ops[1]  # minimum number of samples allowed in each child after a split
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf node
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement below tolS: return a leaf node
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    # build the tree recursively
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

def linearSolve(dataSet):
    # fit an ordinary least-squares model to this node's data
    m, n = np.shape(dataSet)
    X = np.mat(np.ones((m, n))); Y = np.mat(np.ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse, '
                        'try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)  # normal equation: ws = (X^T X)^(-1) X^T Y
    return ws, X, Y

def modelLeaf(dataSet):
    # leaf value: the weight vector of the node's linear model
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    # error: sum of squared residuals of the node's linear model
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(np.power(Y - yHat, 2))

if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex2.txt")
    myTree = createTree(myDat, modelLeaf, modelErr, (1, 10))
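The listing above only builds the model tree; to predict with it, walk the tree for each input row and evaluate the linear model stored at the leaf it lands in. Below is a minimal prediction sketch that follows the same spInd/spVal/left/right conventions; the names treeForeCast and modelTreeEval are illustrative additions, not part of the post's code:

def modelTreeEval(model, inDat):
    # evaluate a leaf's linear model on a single 1 x n_features input row
    n = np.shape(inDat)[1]
    X = np.mat(np.ones((1, n + 1)))
    X[:, 1:n + 1] = inDat  # prepend the intercept column, matching linearSolve
    return (X * model)[0, 0]

def treeForeCast(tree, inDat, modelEval=modelTreeEval):
    # descend the tree with one input row and evaluate the leaf it reaches
    if not isinstance(tree, dict):
        return modelEval(tree, inDat)
    if inDat[0, tree['spInd']] > tree['spVal']:
        return treeForeCast(tree['left'], inDat, modelEval)
    else:
        return treeForeCast(tree['right'], inDat, modelEval)

# e.g. predict the target of the first training sample (features only; the last column is the label)
print(treeForeCast(myTree, myDat[0, :-1]))

For the plain regression trees in sections 1 and 2 the leaf value itself is already the prediction, so modelEval would simply return the leaf.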
