|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
输出结果是 0.37010975031949667 9,不应该啊。
# Learning curve for a random forest: mean cross-validated accuracy as a
# function of n_estimators (1 through 10), printed and plotted.
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split

# Load the dataset. It is read with header=None, so the column names are
# assigned by hand right below.
url = r"C:\Users\zmj佳佳佳\Desktop\第六步离散化测试.csv"
df = pd.read_csv(url, header=None, low_memory=False)
df.columns = [
    "grade", "dti", "delinq_2yrs", "earliest_cr_line", "fico_range_low",
    "inq_last_6mths", "mths_since_last_delinq", "pub_rec", "revol_bal",
    "revol_util", "mths_since_last_major_derog", "tot_cur_bal", "open_acc_6m",
    "open_il_12m", "open_il_24m", "mths_since_rcnt_il", "open_rv_12m",
    "open_rv_24m", "max_bal_bc", "all_util", "inq_last_12m",
    "acc_open_past_24mths", "avg_cur_bal", "bc_open_to_buy",
    "mo_sin_old_il_acct", "mo_sin_old_rev_tl_op", "mo_sin_rcnt_rev_tl_op",
    "mo_sin_rcnt_tl", "mort_acc", "mths_since_recent_bc_dlq",
    "mths_since_recent_inq", "mths_since_recent_revol_delinq",
    "num_accts_ever_120_pd", "num_actv_bc_tl", "num_actv_rev_tl",
    "num_bc_sats", "num_bc_tl", "num_rev_accts", "num_rev_tl_bal_gt_0",
    "num_tl_90g_dpd_24m", "num_tl_op_past_12m", "pct_tl_nvr_dlq",
    "pub_rec_bankruptcies",
]

# Column 0 ("grade") is the target; every remaining column is a feature.
x, y = df.iloc[:, 1:].values, df.iloc[:, 0].values

# Split the dataset into a training set and a test set (70% / 30%).
# NOTE: the split is only used by the disabled grid search below; the
# learning curve cross-validates on the full x, y.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, random_state=0
)
feat_labels = df.columns[1:]

# Grid search with cross-validation (disabled).
# param = {"n_estimators": [10, 20], "max_depth": [5, 8]}
# gc = GridSearchCV(forest, param_grid=param, cv=3)
# gc.fit(x_train, y_train)
# print("Accuracy:", gc.score(x_test, y_test))
# print("Selected model parameters:", gc.best_params_)

# Learning curve over n_estimators.
superpa = []  # superpa[k] is the mean CV accuracy for n_estimators == k + 1
for n_trees in range(1, 11):
    # A fixed random_state makes the curve reproducible between runs.
    forest = RandomForestClassifier(
        n_estimators=n_trees, n_jobs=-1, random_state=0
    )
    rfc_s = cross_val_score(forest, x, y, cv=2).mean()
    superpa.append(rfc_s)

# Report the best score together with the n_estimators value that produced
# it. The list index is 0-based while n_estimators starts at 1, hence the +1.
best_score = max(superpa)
best_n_estimators = superpa.index(best_score) + 1
print(best_score, best_n_estimators)

plt.figure(figsize=[20, 5])
plt.plot(range(1, 11), superpa)
plt.show()
复制代码 |
|