鱼C论坛

 找回密码
 立即注册
查看: 3331|回复: 1

[技术交流] Python实现ID3【决策树】

[复制链接]
发表于 2020-11-5 16:57:50 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2020-11-5 16:59 编辑

参考书籍:《机器学习实战》

  1. import numpy as np
  2. import operator
  3. import matplotlib.pyplot as plt
  4. from math import log

  5. def calcShannonEnt(dataSet):
  6.     numEntries = len(dataSet)
  7.     labelCounts = {}
  8.     for featVec in dataSet:
  9.         currentLabel = featVec[-1]
  10.         if currentLabel not in labelCounts.keys():
  11.             labelCounts[currentLabel] = 0
  12.         labelCounts[currentLabel] += 1
  13.     shannonEnt = 0
  14.     for key in labelCounts:
  15.         prob = float(labelCounts[key]) / numEntries
  16.         shannonEnt -= prob * log(prob, 2)#以2为底数
  17.     return shannonEnt

  18. #根据特征划分数据集
  19. def splitDataSet(dataSet, axis, value):
  20.     retDataSet = []
  21.     for featVec in dataSet:
  22.         if featVec[axis] == value:
  23.             reduceFeatVec = featVec[:axis]#不包括axis
  24.             reduceFeatVec.extend(featVec[axis + 1 :])
  25.             retDataSet.append(reduceFeatVec)
  26.     return retDataSet

  27. #选择最好的数据集划分方式
  28. def chooseBestFeatureToSplit(dataSet):
  29.     numFeatures = len(dataSet[0]) - 1
  30.     baseEntropy = calcShannonEnt(dataSet)
  31.     bestInfoGain = 0
  32.     bestFeature = -1
  33.     for i in range(numFeatures):#每个特征单独计算
  34.         featList = [example[i] for example in dataSet]
  35.         uniqueVals = set(featList)
  36.         newEntropy = 0
  37.         for value in uniqueVals:
  38.             subDataSet = splitDataSet(dataSet, i, value)
  39.             prob = len(subDataSet) / float(len(dataSet))
  40.             newEntropy += prob * calcShannonEnt(subDataSet)
  41.         infoGain = baseEntropy - newEntropy
  42.         if(infoGain > bestInfoGain):
  43.             bestInfoGain = infoGain
  44.             bestFeature = i
  45.     return bestFeature

  46. #数据集已经处理了所有特征,但类标签依然不是唯一的,采用多数表决的方法确定返回的类
  47. def majorityCnt(classList):
  48.     classCount = {}
  49.     for vote in classList:
  50.         if vote not in classCount.keys():
  51.             classCount[vote] = 0
  52.         classCount[vote] += 1
  53.     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1),
  54.                               reverse = True)
  55.     return sortedClassCount[0][0]

  56. #创建树的代码(递归)
  57. def createTree(dataSet, labels):
  58.     classList = [example[-1] for example in dataSet]
  59.     if classList.count(classList[0]) == len(classList):#count()方法用于统计某个元素在列表中出现的次数。
  60.         return classList[0]
  61.     if len(dataSet[0]) == 1:
  62.         return majorityCnt(classList)
  63.     bestFeat = chooseBestFeatureToSplit(dataSet)
  64.     bestFeatLabel = labels[bestFeat]
  65.     myTree = {bestFeatLabel:{}}
  66.     del(labels[bestFeat])
  67.     featValues = [example[bestFeat] for example in dataSet]
  68.     uniqueVals = set(featValues)
  69.     for value in uniqueVals:
  70.         subLabels = labels[:]
  71.         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),
  72.               subLabels)
  73.     return myTree

  74. def classify(inputTree, featLabels, testVec):
  75.     firstStr = list(inputTree.keys())[0]
  76.     secondDict = inputTree[firstStr]
  77.     featIndex = featLabels.index(firstStr)
  78.     for key in secondDict.keys():
  79.         if testVec[featIndex] == key:
  80.             if type(secondDict[key]).__name__ == "dict":#如果是字典的话,接着向下找
  81.                 classLabel = classify(secondDict[key], featLabels, testVec)
  82.             else:
  83.                 classLabel = secondDict[key]
  84.     return classLabel

  85. if __name__ == '__main__':
  86.     dataSet = [[1, 1, "yes"],     
  87.                [1, 1, "yes"],
  88.                [1, 0, "no"],
  89.                [0, 1, "no"],
  90.                [0, 1, "no"]]
  91.     labels = ["no surfacing", "flippers"]
  92.     tree = createTree(dataSet, labels)
  93.     labels = ["no surfacing", "flippers"]#因为createTree阶段会删除labels中的值
  94.     res = classify(tree, labels, [1, 1])
  95.     print(res)
复制代码

本帖被以下淘专辑推荐:

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-11-5 17:00:00 | 显示全部楼层
递归
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-5-20 05:47

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表