鱼C论坛

 找回密码
 立即注册
查看: 2688|回复: 0

关于sklearn分类器里的fit的用法

[复制链接]
发表于 2022-7-23 16:17:20 | 显示全部楼层 |阅读模式
20鱼币
这次我需要完成一个自己书写集成类学习的方法,但是由于bagging的方法需要生存n个分类器,所以需要训练n次麻。clf.fit(x_train,y_train) 我需要每次都生成一个新的分类器么? dt = DecisionTreeClassifier(**best_para_DT )
还是每次我只要在函数初始化的时候输入一个clf 然后只要不停的调用fit(x,y) 而不需要每次新生成 dt = DecisionTreeClassifier(**best_para_DT ) 就可以了


  1. class Bagging_DT(object):
  2.     '''
  3.     算法步骤:
  4.     1. 随机k次抽取m百分比数据
  5.     2. 每次生成一个分类器,最终生成k个分类器
  6.     3. 结果由k个分类器取平均值产生
  7.     '''
  8.     def __init__(self , k ,  best_para ):
  9.         self.k = k #int
  10.         self.best_para = best_para # model 的实例对象,

  11.     def fit(self, x, y):
  12.         x = np.array(x)
  13.         y = np.array(y)
  14.         leg =  len(x)
  15.         self.clf_set = []
  16.         for i in range(self.k):
  17.             dt = DecisionTreeClassifier(**best_para_DT ) #将最好的参数输入
  18.             rand_series = np.random.randint(leg, size = leg) #有放回的生成随机数
  19.             dt.fit(x[rand_series], y[rand_series]) #训练svm
  20.             self.clf_set.append(dt) #训练好的svm模型放入

  21.         return self.clf_set

  22.     def predict(self,x):
  23.         results = []
  24.         x = np.array(x)
  25.         for clf in self.clf_set:
  26.             clf_result = clf.predict(x)
  27.             results.append(clf_result)

  28.         #结果返回的是平均值
  29.         res_score = np.mean( np.array(results) , axis= 0)
  30.         res_clf = res_score.copy() #返回的是平均值
  31.         res_score[res_score >= 0.5 ]= 1
  32.         res_score[res_score < 0.5 ]= 0

  33.         return res_score ,res_clf


  34. # DT bagging 的实现
  35. bag_DT = Bagging_DT(10000,  best_para_DT)
  36. model_DT = bag_DT.fit(x_train_DT , y_train_DT )
  37. pred_DT , pred_score_DT =  bag_DT.predict(x_val_DT)
  38. scc_DT = accuracy_score(y_val_DT , pred_DT)
  39. print(f"决策森林对应的正确率为{scc_DT:.3f}")
  40. auc_DT = roc_auc_score (y_val_DT , pred_score_DT)
  41. print(f"决策森林对应的auc为{auc_DT:.3f}")
复制代码

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-27 18:36

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表