import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
#分类指标
class cat_plot():
    """Threshold-based binary-classification metrics and curve plotting.

    The wrapped DataFrame must contain a ``"二分类"`` column holding the true
    labels (1 = positive, 0 = negative). ``Precision``, ``Recall``, ``FPR``
    and ``plot`` read the predicted score from the DataFrame's second column
    (position 1); ``AUC`` reads it from the ``"预测概率"`` column.

    None of the methods modify the wrapped DataFrame.
    """

    def __init__(self, data, cat):
        """Store the data and the curve type.

        Args:
            data: pandas DataFrame laid out as described on the class.
            cat: Curve drawn by ``plot``: ``"P-R"`` or ``"ROC"``.
        """
        self.data = data
        self.cat = cat

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or NaN if the denominator is 0."""
        if denominator == 0:
            return float("nan")
        return numerator / denominator

    def _confusion_counts(self, threshold):
        """Return ``(TP, FP, FN, TN)`` for the given decision threshold.

        A row is predicted positive when its score is strictly greater than
        *threshold*. Rows whose label is neither 0 nor 1 are not counted.
        """
        predicted_pos = (self.data.iloc[:, 1] > threshold).to_numpy()
        predicted_neg = ~predicted_pos
        labels = self.data["二分类"].to_numpy()
        actual_pos = labels == 1
        actual_neg = labels == 0
        tp = int(np.sum(predicted_pos & actual_pos))
        fp = int(np.sum(predicted_pos & actual_neg))
        fn = int(np.sum(predicted_neg & actual_pos))
        tn = int(np.sum(predicted_neg & actual_neg))
        return tp, fp, fn, tn

    def Precision(self, threshold):
        """Return TP / (TP + FP); NaN when nothing is predicted positive."""
        tp, fp, _, _ = self._confusion_counts(threshold)
        return self._ratio(tp, tp + fp)

    def Recall(self, threshold):
        """Return TP / (TP + FN); NaN when there are no actual positives."""
        tp, _, fn, _ = self._confusion_counts(threshold)
        return self._ratio(tp, tp + fn)

    def FPR(self, threshold):
        """Return FP / (FP + TN); NaN when there are no actual negatives."""
        _, fp, _, tn = self._confusion_counts(threshold)
        return self._ratio(fp, fp + tn)

    def plot(self):
        """Draw the curve selected by ``self.cat`` and show it.

        Every score in the data is used once as a threshold. ``"P-R"`` plots
        precision against recall, ``"ROC"`` plots recall (TPR) against FPR.
        Any other value of ``self.cat`` draws nothing.
        """
        if self.cat == "P-R":
            x_metric, y_metric = self.Recall, self.Precision
        elif self.cat == "ROC":
            x_metric, y_metric = self.FPR, self.Recall
        else:
            # Unknown curve type: keep the historical silent no-op.
            return
        thresholds = self.data.iloc[:, 1].tolist()
        point_x = [x_metric(t) for t in thresholds]
        point_y = [y_metric(t) for t in thresholds]
        ax = sns.lineplot(x=point_x, y=point_y)
        ax.set_title(self.cat)
        plt.show()

    def AUC(self):
        """Print and return the ROC AUC computed with the rank-sum formula.

        AUC = (sum of positive ranks - M(M+1)/2) / (M * N), where M and N are
        the numbers of positive and negative rows and tied scores share their
        average rank.

        Returns:
            The AUC as a float, or NaN when the data has no positive or no
            negative rows.
        """
        labels = self.data["二分类"]
        # rank() does not need sorted input, so the data is left as it is.
        ranks = self.data["预测概率"].rank(method="average")
        is_positive = labels == 1
        n_pos = int(is_positive.sum())
        n_neg = len(self.data) - n_pos
        positive_rank_sum = float(ranks[is_positive].sum())
        auc = self._ratio(
            positive_rank_sum - (n_pos + 1) * n_pos / 2, n_pos * n_neg
        )
        print(auc)
        return auc
if __name__ == '__main__':
    # Demo data: true binary labels and the matching predicted probabilities.
    data = pd.DataFrame({
        "二分类": [1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0,
                1, 0, 1],
        "预测概率": [0.99, 0.12, 0.29, 0.45, 0.93, 0.78, 0.89, 0.34, 0.78,
                 0.67, 0.99, 0.17, 0.88, 0.29, 0.72, 0.11, 0.89, 0.49, 0.34,
                 0.78, 0.11, 0.99],
    })

    # One plotter per curve type, both backed by the same DataFrame.
    pr_view = cat_plot(data, "P-R")
    roc_view = cat_plot(data, "ROC")

    # Draw the P-R curve first, then the ROC curve.
    for view in (pr_view, roc_view):
        view.plot()

    # Finally report the AUC.
    pr_view.AUC()