鱼C论坛

 找回密码
 立即注册
查看: 1863|回复: 2

[技术交流] Python实现kNN

[复制链接]
发表于 2020-11-3 21:40:46 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2020-11-4 12:54 编辑

参考书籍《机器学习实战》
本地测试环境:python3.7
import numpy as np
import operator

class kNN():
    def __init__(self, inX, dataSet, labels, k):
        self.inX = inX
        self.dataSet = dataSet
        self.labels = labels
        self.k = k
        
    def classify0(self):
        dataSetSize = self.dataSet.shape[0]
        #https://blog.csdn.net/laobai1015/article/details/85719724
        diffMat = np.tile(self.inX, (dataSetSize, 1)) - dataSet
        sqDiffMat = diffMat ** 2
        sqDistances = sqDiffMat.sum(axis = 1)
        distances = sqDistances ** 0.5
        sortedDistIndicies = distances.argsort()
        classCount = {}
        for i in range(self.k):
            voteIlabel = self.labels[sortedDistIndicies[i]]
            #https://blog.csdn.net/weixin_45683963/article/details/103898093
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        #https://www.runoob.com/python3/python3-func-sorted.html
        sortedClassCount = sorted(classCount.items(), 
                                  ##https://www.cnblogs.com/zhoufankui/p/6274172.html
                                  key = operator.itemgetter(1),
                                  reverse = True)#reverse = True为降序
        print(sortedClassCount[0][0])
    
if __name__ == '__main__':
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]) #特征
    labels = ['A', 'A', 'B', 'B']#分类
    kNN([0, 0], group, labels, 3).classify0()#[0,0]是待分了的特征


import numpy as np
import operator

class kNN():
    def __init__(self, filename, k):
        self.filename = filename
        self.k = k
    
    #读取数据
    def file2matrix(filename):
        fr = open(filename)
        arrayOLines = fr.readlines()
        numberOfLines = len(arrayOLines)
        returnMat = np.zeros((numberOfLines, 3))
        classLabelVector = []
        index = 0
        for line in arrayOLines:
            line = line.strip()
            listFromLine = line.split('\t')
            returnMat[index, :] = listFromLine[0:3]
            classLabelVector.append(int(listFromLine[-1]))
            index += 1
        return returnMat, classLabelVector
    
    #归一化
    def autoNorm(dataSet):
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        normDataSet = np.zeros(np.shape(dataSet))
        m = dataSet.shape[0]
        normDataSet = dataSet - np.tile(minVals, (m, 1))
        normDataSet = normDataSet / np.tile(ranges, (m, 1))
        return normDataSet, ranges, minVals
    
    #kNN
    def classify0(self, inX, dataSet, labels):
            dataSetSize = dataSet.shape[0]
            diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis = 1)
            distances = sqDistances ** 0.5
            sortedDistIndicies = distances.argsort()
            classCount = {}
            for i in range(self.k):
                voteIlabel = labels[sortedDistIndicies[i]]
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            sortedClassCount = sorted(classCount.items(), 
                                      key = operator.itemgetter(1),
                                      reverse = True)#reverse = True为降序
            return sortedClassCount[0][0]
    
    #测试准确率
    def datingClassTest(self, hoRatio):
        #只取了数据量的10%作为样本
        #hoRatio = 0.10
        datingDataMat, datingLabels = kNN.file2matrix(self.filename)
        normMat, ranges, minVals = kNN.autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        errorCount = 0.0
        for i in range(numTestVecs):
            classifierResult = kNN.classify0(self, normMat[i, :], normMat[numTestVecs:m, :],
                                         datingLabels[numTestVecs:m])
            print("the classifier came back with: %d, the real answer is %d"
                  %(classifierResult, datingLabels[i]))
            if(classifierResult != datingLabels[i]):
                errorCount += 1
        print("the total error rate is: %f" %(errorCount / float(numTestVecs)))
        
    #预测    
    def classifyPerson(self):
        resultList = ["not at all", "in small doses", "in large doses"]
        percentTats = float(input("percentage of time spent playing video games?"))#用来获取控制台的输入
        ffMiles = float(input("frequent filer miles earned per year?"))
        iceCream = float(input("liters of ice cream consumed per year?"))
        datingDataMat, datingLabels = kNN.file2matrix(self.filename)
        normMat, ranges, minVals = kNN.autoNorm(datingDataMat)
        inArr = np.array([ffMiles, percentTats, iceCream])
        classifierResult = kNN.classify0(self, (inArr - minVals)/ranges, normMat, datingLabels)
        print("You will probably like this person:", resultList[classifierResult - 1])#因为实际类别是1,2,3,下标识从0开始的,所以要-1

if __name__ == '__main__':
    path = 'datingTestSet2.txt'
    test1 = kNN(path, 3)
    test1.datingClassTest(0.1)
    test1.classifyPerson()
import numpy as np
import operator
import os

class kNN():
    def __init__(self, path):
        self.path = path
        
    #将矩阵打平,编程一维向量
    def img2vector(filname):
        returnVect = np.zeros((1, 1024))
        fr = open(filname)
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect


    def classify0(inX, dataSet, labels, k):
            dataSetSize = dataSet.shape[0]
            diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis = 1)
            distances = sqDistances ** 0.5
            sortedDistIndicies = distances.argsort()
            classCount = {}
            for i in range(k):
                voteIlabel = labels[sortedDistIndicies[i]]
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            sortedClassCount = sorted(classCount.items(), 
                                      key = operator.itemgetter(1),
                                      reverse = True)#reverse = True为降序
            return sortedClassCount[0][0]

    #在全部的数据上训练,然后在全部的数据上进行检验
    def handwritingClassTest(self):
        hwLabels = []
        #返回文件下的所有文件名
        trainingFileList = os.listdir(self.path + "testDigits")
        m = len(trainingFileList)
        trainingMat = np.zeros((m, 1024))
        for i in range(m):
            fileNameStr = trainingFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            trainingMat[i, :] = kNN.img2vector(self.path + "testDigits/%s" % fileNameStr)
        testFileList = os.listdir(self.path + "testDigits")
        errorCount = 0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = kNN.img2vector(self.path + "testDigits\%s" % fileNameStr)
            classifierResult = kNN.classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            print("the classifier came back with:%d, the real answer if :%d" %(classifierResult, classNumStr))
            if(classifierResult != classNumStr) :
                errorCount += 1
        print("\nthe total numner of errors is: %d" % errorCount)
        print("\nthe total error rate is %f" %(errorCount / float(mTest)))

if __name__ == '__main__':
    path = "C:/.../machinelearninginaction/Ch02/digits/"
    kNN(path).handwritingClassTest()

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2020-11-4 10:55:46 | 显示全部楼层
每个框都是独立的个体,可以独立运行
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2020-11-4 12:55:50 | 显示全部楼层
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-1-17 23:16

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表