鱼C论坛

 找回密码
 立即注册
查看: 1759|回复: 0

[技术交流] 机器学习系列------决策函数

[复制链接]
发表于 2018-6-14 09:51:09 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        新启一个文件,首先把以前的代码复制过来:
from sklearn.datasets import fetch_mldata



mnist=fetch_mldata('MNIST original',data_home='.\datasets')


X,y=mnist["data"],mnist["target"]
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

X_train,X_test,y_train,y_test=X[:60000],X[60000:],y[:60000],y[60000:]

import numpy as np
shuffle_index=np.random.permutation(60000)
X_train,y_train=X_train[shuffle_index],y_train[shuffle_index]
some_digit=X[36000]
y_train_5=(y_train==5)
y_test_5=(y_test==5)
from sklearn.linear_model import SGDClassifier

sgd_clf=SGDClassifier(random_state=42)
sgd_clf.fit(X_train,y_train_5)
        建立决策函数,决策函数是通过调试阀值帮助我们预测精准度的一个工具:
y_scores=sgd_clf.decision_function([some_digit])
y_scores
        会显示:array([ 66994.58438748]),这个数是预测准确的一个分数,但是必须要有一个阀值来判断边界:
threshold=0
y_some_digit_pred=(y_scores>threshold)
y_some_digit_pred
        以上我们把阀值设定为0,他会显示:array([ True], dtype=bool)。因为我们这个分数已经超过了阀值,就说明预测结果是对的,但是阀值又通过什么来确定呢,我们把阀值设定高点试试:
threshold=200000
y_some_digit_pred=(y_scores>threshold)
y_some_digit_pred
        显示:array([False], dtype=bool)。就说明阀值一定要取的正合适才行,就要通过准确率召回曲线确定阀值:
from sklearn.model_selection import cross_val_predict
y_scores=cross_val_predict(sgd_clf,X_train,y_train_5,cv=3,method="decision_function")
from sklearn.metrics import precision_recall_curve
precisions,recalls,thresholds=precision_recall_curve(y_train_5,y_scores[:,1])
        然后我们写个画图功能的函数:
def plot_precision_recall_vs_threshold(precisions,recalls,thresholds):
    plt.plot(thresholds,precisions[0:-1],"b--",label="Precision")
    plt.plot(thresholds,recalls[:-1],"g-",label="Recall")
    plt.xlabel("Threshold")
    plt.legend(loc="upper left")
    plt.ylim([0,1])
        调用画图函数:
plot_precision_recall_vs_threshold(precisions,recalls,thresholds)
plt.show()
        输出的图像为:
dsfsdfd.png
        从上图可以得出准确度追高的时候在50000左右,我们可以看一下决策函数大于50000时的评分:
from sklearn.metrics import precision_score,recall_score
y_train_pred_90=(y_scores>50000)
precision_score(y_train_5,np.argmax(y_train_pred_90, axis=1))
        显示结果为:0.90712074303405577
recall_score(y_train_5,np.argmax(y_train_pred_90, axis=1))
        0.54049068437557646

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-11-22 08:55

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表