鱼C论坛

 找回密码
 立即注册
查看: 1517|回复: 0

[技术交流] 机器学习系列------多元分类

[复制链接]
发表于 2018-6-16 08:28:54 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        以前我们只是把5从10个数里分辨出来,如果想可以分辨出任何数,就要用到多元分类器。还是把昨天代码复制过来:
from sklearn.datasets import fetch_mldata



mnist=fetch_mldata('MNIST original',data_home='.\datasets')


X,y=mnist["data"],mnist["target"]
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

X_train,X_test,y_train,y_test=X[:60000],X[60000:],y[:60000],y[60000:]

import numpy as np
shuffle_index=np.random.permutation(60000)
X_train,y_train=X_train[shuffle_index],y_train[shuffle_index]
some_digit=X[36000]

from sklearn.linear_model import SGDClassifier

sgd_clf=SGDClassifier(random_state=42)
        下面只需要2句代码就可以实现了:
sgd_clf.fit(X_train,y_train)
sgd_clf.predict([some_digit])
        他的原理其实就是分10组二元分类,综合到一起就是多元的了,输出结果为:array([ 5.])。下面看一下决策函数的得分:
some_digit_scores=sgd_clf.decision_function([some_digit])
some_digit_scores
        输出为:array([[-118859.94408043, -358705.50982291, -522104.71162894,
         -23094.88221178, -365053.0820119 ,  256769.03159862,
        -599218.80497511, -306327.89660489, -545664.1740147 ,
        -689004.37831298]])
        这就是0-1十个数的得分,中间那个就是5,分数最高,假设我们用0作为阀值,高过0的只会选出256769.03159862来。也可以使用方法得出最高分数:
np.argmax(some_digit_scores)
        输出为:5。我们再看看他里面的结构:
sgd_clf.classes_
        输出为:array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.])。这跟上面那10个数是一一对应的。看看他的第五个数:
sgd_clf.classes_[5]
        输出为:5.0 。下面还有一种多元分类的方法是把梯度下降方法放进一对一分类器里:
from sklearn.multiclass import OneVsOneClassifier

ovo_clf=OneVsOneClassifier(SGDClassifier(random_state=42))
ovo_clf.fit(X_train,y_train)
ovo_clf.predict([some_digit])
        输出为:array([ 5.])。用随机森林分类器也可以做到:
from sklearn.ensemble import RandomForestClassifier

forest_clf=RandomForestClassifier(random_state=42)
forest_clf.fit(X_train,y_train)
forest_clf.predict([some_digit])
        输出为:array([ 5.])。然后预测下分数:
forest_clf.predict_proba([some_digit])
        输出为:array([[ 0. ,  0. ,  0.2,  0. ,  0. ,  0.7,  0.1,  0. ,  0. ,  0. ]])。可以看到里面最高分0.7也是位于5那个位置。下面用我们以前学过的给评分:
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf,X_train,y_train,cv=3,scoring="accuracy")
        输出为:array([ 0.84588082,  0.83749187,  0.87423113])。标准化的方法:
from sklearn.preprocessing import StandardScaler

scaler=StandardScaler()
X_train_scaler=scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf,X_train,y_train,cv=3,scoring="accuracy")
        输出为:array([ 0.84588082,  0.83749187,  0.87423113])

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-11-22 09:44

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表