|
马上注册，结交更多好友，享用更多功能^_^
您需要登录才可以下载或查看，没有账号？立即注册
x
代码如下：
from numpy import *
import math
import copy
import pickle as pp
class ID3DTree(object):
    """ID3 decision tree over categorical features.

    ``dataSet`` is a list of rows.  The last element of each row is the
    class value and the preceding elements are feature values.
    ``labels`` names the feature columns in order.  After ``train()``,
    ``tree`` holds one of two things:

    * a bare class value, when all training rows share one class; or
    * a nested dict ``{featureLabel: {featureValue: subtree}}`` whose
      leaves are class values.
    """

    def __init__(self):
        self.tree = {}      # trained model, set by train()
        self.dataSet = []   # rows; the last column is the class value
        self.labels = []    # names of the feature columns, in order

    def loadDataSet(self, path, labels, encoding=None):
        """Read a tab-separated data file into ``dataSet``; store ``labels``.

        Each non-blank line of the file becomes one row: the list of its
        tab-separated fields, as strings.  Whitespace-only lines are
        skipped.

        :param path: path of the text file to read.
        :param labels: feature-column names, stored as ``self.labels``.
        :param encoding: text encoding of the file.  ``None`` (the default)
            uses the platform default encoding.
        """
        import io  # io.open accepts ``encoding`` on Python 2 and Python 3

        # Read text (not bytes) so that splitting on the str "\t" also works
        # on Python 3.  ``with`` closes the file even if reading fails.
        with io.open(path, "r", encoding=encoding) as fp:
            content = fp.read()
        self.dataSet = [row.split("\t")
                        for row in content.splitlines() if row.strip()]
        self.labels = labels

    def train(self):
        """Build ``self.tree`` from ``self.dataSet`` and ``self.labels``."""
        # buildTree never modifies the labels it receives, so self.labels
        # can be passed without copying.
        self.tree = self.buildTree(self.dataSet, self.labels)

    def buildTree(self, dataSet, labels):
        """Recursively build an ID3 tree for ``dataSet``.

        :param dataSet: non-empty list of rows; the last column is the class.
        :param labels: names of the feature columns of ``dataSet``.  This
            sequence is not modified.
        :returns: a class value (leaf) or ``{featureLabel: {value: subtree}}``.
        """
        catelist = [row[-1] for row in dataSet]
        # Every row has the same class: nothing left to split.
        if catelist.count(catelist[0]) == len(catelist):
            return catelist[0]
        # Only the class column remains: fall back to a majority vote.
        if len(dataSet[0]) == 1:
            return self.maxCate(catelist)
        bestFeat = self.getBestFeat(dataSet)
        bestFeatLabel = labels[bestFeat]
        # Build the remaining labels as a new sequence instead of deleting
        # from ``labels`` in place, so the caller's list is left untouched.
        subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
        branches = {}
        for value in {row[bestFeat] for row in dataSet}:
            subset = self.splitDataSet(dataSet, bestFeat, value)
            branches[value] = self.buildTree(subset, subLabels)
        return {bestFeatLabel: branches}

    def maxCate(self, catelist):
        """Return the most frequent value in ``catelist``.

        Ties are broken in favour of the tied value that occurs last in
        ``catelist``.  Raises ``ValueError`` if ``catelist`` is empty.
        """
        from collections import Counter

        counts = Counter(catelist)
        # max() keeps the first maximal item it sees.  Scanning in reverse
        # therefore selects the last-occurring value among equally frequent
        # ones.
        return max(reversed(catelist), key=counts.get)

    def getBestFeat(self, dataSet):
        """Return the index of the feature with the highest information gain.

        Ties go to the lowest index.  If no feature yields a positive gain
        (e.g. XOR-like data, or rows that differ only in their class), 0 is
        returned.  The caller therefore always receives a real feature
        column, never the class column (index -1).  ``dataSet`` is expected
        to have at least one feature column besides the class.
        """
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = self.computeEntropy(dataSet)
        total = float(len(dataSet))
        bestInfoGain = 0.0
        bestFeature = 0
        for i in range(numFeatures):
            newEntropy = 0.0
            for value in {row[i] for row in dataSet}:
                subDataSet = self.splitDataSet(dataSet, i, value)
                newEntropy += (len(subDataSet) / total
                               * self.computeEntropy(subDataSet))
            infoGain = baseEntropy - newEntropy
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

    def computeEntropy(self, dataSet):
        """Return the Shannon entropy, in bits, of the class column.

        The class column is the last element of each row in ``dataSet``.
        An empty ``dataSet`` has entropy 0.0.
        """
        from collections import Counter

        total = float(len(dataSet))
        counts = Counter(row[-1] for row in dataSet)
        infoEntropy = 0.0
        for count in counts.values():
            prob = count / total
            infoEntropy -= prob * math.log(prob, 2)
        return infoEntropy

    def splitDataSet(self, dataSet, axis, value):
        """Return rows whose column ``axis`` equals ``value``, minus that column.

        The returned rows are new sequences; ``dataSet`` is not modified.
        """
        return [row[:axis] + row[axis + 1:]
                for row in dataSet if row[axis] == value]
|
-
|