鱼C论坛

 找回密码
 立即注册
查看: 2617|回复: 0

[其他分类] 初识机器学习——鸢尾花分类

[复制链接]
发表于 2019-8-15 13:57:42 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 程序员的救赎 于 2019-8-23 13:49 编辑

一、查看数据集


        简单说明一下数据集的构成及其作用:
                训练集——学习知识——>课本;
                验证集——检验学习能力——>课后习题;  
                测试集——检验泛化能力——>考试;   
               一般来说,验证集从训练集中划分,有时候并不是必须的。

        鸢尾花数据集是机器学习和统计学中一个经典的数据集,包含在scikit-learn的datasets模块中
  1. from sklearn.datasets import load_iris
  2. iris_dataset = load_iris()
复制代码


        load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值
  1. iris_dataset.keys()
复制代码
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])

        DESCR键对应的值是数据集的简要说明,这里只截取部分
  1. print(iris_dataset['DESCR'][:193])
复制代码
.. _iris_dataset:

        Iris plants dataset
        --------------------

        **Data Set Characteristics:**

           :Number of Instances: 150 (50 in each of three classes)
           :Number of Attributes: 4 numeric, pre

        target_names 键对应的值是一个字符串数组,里面包含要预测的花的花种
  1. iris_dataset['target_names']
复制代码
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

        feature_names 键对应的值是一个字符串列表,对每一个特征进行了说明
  1. iris_dataset['feature_names']
复制代码
['sepal length (cm)',
         'sepal width (cm)',
        'petal length (cm)',
         'petal width (cm)']

        数据包含在target和data中
  1. iris_dataset['data'][:5]  # 打印前五行
复制代码
array([[5.1, 3.5, 1.4, 0.2],
              [4.9, 3. , 1.4, 0.2],
              [4.7, 3.2, 1.3, 0.2],
              [4.6, 3.1, 1.5, 0.2],
              [5. , 3.6, 1.4, 0.2]])

  1. iris_dataset['data'].shape  # 查看数组形状
复制代码
(150, 4)
        可以看出,数组形状为150x4。每一行为一朵花的数据(sample),每一列为花的不同属性(feature)

  1. iris_dataset['target']
复制代码
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
               1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
               1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
               2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
               2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
        target为一维数组,0, 1,2分别和上面的target_names下标对应的类别相对应,一个数字标记了一朵花的类别 (label)



二、划分训练数据和测试数据

  1. from sklearn.model_selection import train_test_split
复制代码

        划分数据集, 可以设置测试集的比例(默认为0.25), 指定随机数生成的种子(确保每次得到的输出是固定的)
  1. X_train, X_test, y_train, y_test = train_test_split(
  2. iris_dataset['data'], iris_dataset['target'], random_state=0)
复制代码

        数据集用X表示, 标签用y表示是受函数y = f(x)启发,使用大写的X代表二维数组,而小写的y代表一维标量。 打乱数据集后再进行划分的目的是防止因为标签是有序的而导致测试集或者训练集中的数据类别不完整,进而影响训练结果

  1. X_train.shape
复制代码
(112, 4)
  1. X_test.shape
复制代码
(38, 4)



三、观察数据

        在构建学习模型之前,通常要检查一下数据。一是做模型选择,二是检测数据是否异常。模型选择要根据数据的特定进行判断,而检测数据异常最佳方法之一是将其可视化
  1. import pandas as pd
  2. import mglearn
  3. iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
复制代码


        绘制散点图
  1. grr = pd.plotting.scatter_matrix(iris_dataframe,
  2.                                  c=y_train,  # 颜色参数,使用的训练集的label
  3.                                  figsize=(15, 15), # 图像尺寸
  4.                                  marker='o',  # 标记类型
  5.                                  hist_kwds={'bins': 20},
  6.                                  s=60,
  7.                                  alpha=.8, # 图像透明度
  8.                                  cmap=mglearn.cm3)
复制代码

下载.png

        从图中可以知道,从花瓣和花萼的测量标签数据基本可以将三个类别区分开,说明机器学习很可能学会区分它们



四、训练模型——k近邻
  1. from sklearn.neighbors import KNeighborsClassifier
  2. knn = KNeighborsClassifier(n_neighbors=1)  # 设置邻居数目
  3. knn.fit(X_train, y_train)
复制代码
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=1, p=2,
           weights='uniform')

        模型训练好之后,需要进行预测
  1. import numpy as np
  2. X_new = np.array([[5, 2.9, 1, 0.2]])
  3. X_new.shape
复制代码
(1, 4)
  1. predicion = knn.predict(X_new)
  2. predicion
复制代码
array([0])

        预测结果表明这个花属于类别0, 即setosa



五、模型评估

  1. y_pred = knn.predict(X_test)
  2. y_pred
复制代码
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
       0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])

        计算预测精度
  1. np.mean(y_pred == y_test)
复制代码
0.9736842105263158
        预测精度代表了我们模型在面对新的数据时的处理能力(泛化能力),精度越高,说明模型越好。

        至此,第一个机器学习任务结束


机器学习的一般步骤:
      获取数据——数据清洗和处理——划分数据集——模型选择——参数设置——训练模型——模型评估——调整模型——预测结果

      由于这里数据集是标准数据集,所以不需要清洗数据。


本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-4-20 03:10

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表