Posted by 糖逗 on 2020-11-16 12:58:42

Python Implementation of CART


Reference: 《机器学习实战》 (Machine Learning in Action)

1. CART regression (pre-pruning, controlled by the tolS and tolN parameters)
import numpy as np

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split the data set into two subsets on the given feature and value
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

def regErr(dataSet):
    # total squared error = variance of the targets * number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    tolS = ops[0]  # minimum reduction in total squared error required for a split
    tolN = ops[1]  # minimum number of samples allowed in each child node
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement too small: return a leaf
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
                     

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    # recursively grow the tree
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val  # no further split: return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex00.txt")
    res = createTree(myDat)
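The dictionary returned by createTree can be used for prediction by walking from the root down to a leaf. Below is a minimal sketch of such a traversal (not part of the original post); regTreeForecast is a hypothetical helper name, and inDat is assumed to be a 1 x (n-1) np.mat row holding the feature columns of one sample.

def regTreeForecast(tree, inDat):
    # descend the tree until a leaf (a plain number) is reached
    if not isinstance(tree, dict):
        return float(tree)
    if inDat[0, tree['spInd']] > tree['spVal']:
        return regTreeForecast(tree['left'], inDat)
    else:
        return regTreeForecast(tree['right'], inDat)

# e.g. regTreeForecast(res, myDat[0, 0:-1])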
   




2. CART regression (post-pruning)
import numpy as np

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split the data set into two subsets on the given feature and value
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

def regErr(dataSet):
    # total squared error = variance of the targets * number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    tolS = ops[0]  # minimum reduction in total squared error required for a split
    tolN = ops[1]  # minimum number of samples allowed in each child node
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement too small: return a leaf
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue



def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    # recursively grow the tree
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val  # no further split: return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def isTree(obj):
    return (type(obj).__name__=='dict')

def getMean(tree):  # recursively collapse a subtree into the mean of its leaf values
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right']) / 2.0
   
def prune(tree, testData):
    if np.shape(testData)[0] == 0: return getMean(tree)  # no test data: collapse the subtree
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = np.sum(np.power(lSet[:, -1] - tree['left'], 2)) + \
                       np.sum(np.power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:  # merging the two leaves reduces test error
            print("merging")
            return treeMean
        else:
            return tree
    return tree

if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex2.txt")
    myDatTest = loadDataSet(r"C:\...\ex2test.txt")
    myTree = createTree(myDat, ops = (0, 1))
    res = prune(myTree, myDatTest)
# First grow the full tree on the training set, then post-prune it with the test set (this plays the role of tuning the tree's complexity).
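To see how much post-pruning simplifies the tree, a quick check is to count the leaves before and after pruning. The sketch below is not from the original post; getNumLeaves is a hypothetical helper, and myDat / myDatTest are assumed to be loaded as above.

def getNumLeaves(tree):
    # a leaf is anything that is not a dict node
    if not isTree(tree):
        return 1
    return getNumLeaves(tree['left']) + getNumLeaves(tree['right'])

fullTree = createTree(myDat, ops = (0, 1))
leavesBefore = getNumLeaves(fullTree)
prunedTree = prune(fullTree, myDatTest)  # prune may collapse subtrees in place; use the returned tree
print(leavesBefore, getNumLeaves(prunedTree))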






3. Model tree (a linear regression model in each leaf)
import numpy as np

def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))
        dataMat.append(fltLine)
    return np.mat(dataMat)

def binSplitDataSet(dataSet, feature, value):
    # split the data set into two subsets on the given feature and value
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def regLeaf(dataSet):
    return np.mean(dataSet[:, -1])  # leaf value: mean of the targets

def regErr(dataSet):
    # total squared error = variance of the targets * number of samples
    return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    tolS = ops[0]  # minimum reduction in total squared error required for a split
    tolN = ops[1]  # minimum number of samples allowed in each child node
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # all targets identical: return a leaf
        return None, leafType(dataSet)
    m, n = np.shape(dataSet)
    S = errType(dataSet)
    bestS = np.inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:  # improvement too small: return a leaf
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue



def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
    # recursively grow the tree
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None: return val  # no further split: return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def linearSolve(dataSet):
    m, n = np.shape(dataSet)
    X = np.mat(np.ones((m, n))); Y = np.mat(np.ones((m, 1)))
    X[:, 1:n] = dataSet[:, 0:n-1]; Y = dataSet[:, -1]  # first column of X is the bias term
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)  # normal equation: ws = (X^T X)^{-1} X^T Y
    return ws, X, Y

def modelLeaf(dataSet):
    # leaf value: the linear regression weights fitted on this node's data
    ws, X, Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    # node error: squared error of the linear fit on this node's data
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return np.sum(np.power(Y - yHat, 2))

if __name__ == "__main__":
    myDat = loadDataSet(r"C:\...\ex2.txt")
    myTree = createTree(myDat, modelLeaf, modelErr, (1, 10))
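To predict with a model tree, each leaf's weight vector has to be evaluated against the input features (with a bias term prepended) instead of simply returning the leaf value as in the plain regression tree. The sketch below illustrates this and is not part of the original post; modelTreeForecast is a hypothetical helper name, and inDat is assumed to be a 1 x (n-1) np.mat row holding the feature columns of one sample.

def modelTreeForecast(tree, inDat):
    # descend to a leaf, then evaluate its linear model on [1, features]
    if not isinstance(tree, dict):
        n = np.shape(inDat)[1]
        X = np.mat(np.ones((1, n + 1)))
        X[:, 1:n + 1] = inDat
        return float(X * tree)
    if inDat[0, tree['spInd']] > tree['spVal']:
        return modelTreeForecast(tree['left'], inDat)
    else:
        return modelTreeForecast(tree['right'], inDat)

# e.g. modelTreeForecast(myTree, myDat[0, 0:-1])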
