FishC Forum


Error when training a GBDT model

Posted on 2024-10-22 00:31:18

The code is as follows:

    #Gradient Boost Decision Tree GBDT
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn.metrics import accuracy_score

    gbc = GradientBoostingClassifier()
    param_test = {'n_estimators': [50, 100],
                  'min_samples_split': [3, 4, 5, 6, 7],
                  'max_depth': [3, 4, 5, 6]}
    gbc_grid = GridSearchCV(gbc, param_test, cv=5, refit=True, verbose=1)
    gbc_grid.fit(X_train,y_train)
    train_score = gbc.score(X_train,y_train)
    gbc_pre = gbc.predict(X_valid)
    valid_score = accuracy_score(y_valid,gbc_pre)
    print(valid_score)


The error is as follows:

    Fitting 5 folds for each of 40 candidates, totalling 200 fits
    /opt/conda/lib/python3.10/site-packages/sklearn/base.py:432: UserWarning: X has feature names, but GradientBoostingClassifier was fitted without feature names
      warnings.warn(
    ---------------------------------------------------------------------------
    NotFittedError                            Traceback (most recent call last)
    Cell In[25], line 12
         10 gbc_grid = GridSearchCV(gbc, param_test, cv=5, refit=True, verbose=1)
         11 gbc_grid.fit(X_train,y_train)
    ---> 12 train_score = gbc.score(X_train,y_train)
         13 gbc_pre = gbc.predict(X_valid)
         14 valid_score = accuracy_score(y_valid,gbc_pre)

    File /opt/conda/lib/python3.10/site-packages/sklearn/base.py:668, in ClassifierMixin.score(self, X, y, sample_weight)
        643 """
        644 Return the mean accuracy on the given test data and labels.
        645
       (...)
        664     Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
        665 """
        666 from .metrics import accuracy_score
    --> 668 return accuracy_score(y, self.predict(X), sample_weight=sample_weight)

    File /opt/conda/lib/python3.10/site-packages/sklearn/ensemble/_gb.py:1308, in GradientBoostingClassifier.predict(self, X)
       1293 def predict(self, X):
       1294     """Predict class for X.
       1295
       1296     Parameters
       (...)
       1306         The predicted values.
       1307     """
    -> 1308     raw_predictions = self.decision_function(X)
       1309     encoded_labels = self._loss._raw_prediction_to_decision(raw_predictions)
       1310     return self.classes_.take(encoded_labels, axis=0)

    File /opt/conda/lib/python3.10/site-packages/sklearn/ensemble/_gb.py:1264, in GradientBoostingClassifier.decision_function(self, X)
       1243 """Compute the decision function of ``X``.
       1244
       1245 Parameters
       (...)
       1259     array of shape (n_samples,).
       1260 """
       1261 X = self._validate_data(
       1262     X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
       1263 )
    -> 1264 raw_predictions = self._raw_predict(X)
       1265 if raw_predictions.shape[1] == 1:
       1266     return raw_predictions.ravel()

    File /opt/conda/lib/python3.10/site-packages/sklearn/ensemble/_gb.py:687, in BaseGradientBoosting._raw_predict(self, X)
        685 def _raw_predict(self, X):
        686     """Return the sum of the trees raw predictions (+ init estimator)."""
    --> 687     raw_predictions = self._raw_predict_init(X)
        688     predict_stages(self.estimators_, X, self.learning_rate, raw_predictions)
        689     return raw_predictions

    File /opt/conda/lib/python3.10/site-packages/sklearn/ensemble/_gb.py:673, in BaseGradientBoosting._raw_predict_init(self, X)
        671 def _raw_predict_init(self, X):
        672     """Check input and compute raw predictions of the init estimator."""
    --> 673     self._check_initialized()
        674     X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)
        675     if self.init_ == "zero":

    File /opt/conda/lib/python3.10/site-packages/sklearn/ensemble/_gb.py:382, in BaseGradientBoosting._check_initialized(self)
        380 def _check_initialized(self):
        381     """Check that the estimator is initialized, raising an error if not."""
    --> 382     check_is_fitted(self)

    File /opt/conda/lib/python3.10/site-packages/sklearn/utils/validation.py:1390, in check_is_fitted(estimator, attributes, msg, all_or_any)
       1385     fitted = [
       1386         v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
       1387     ]
       1389 if not fitted:
    -> 1390     raise NotFittedError(msg % {"name": type(estimator).__name__})

    NotFittedError: This GradientBoostingClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

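Posted on 2024-10-22 00:31:26
Judging from the error message, the `GradientBoostingClassifier` instance `gbc` has never been fitted, so it cannot score or predict. `GridSearchCV` fits clones of `gbc` rather than `gbc` itself; with `refit=True`, the refitted best model is stored on `gbc_grid`, while the original `gbc` stays unfitted. Your code fitted `gbc_grid` but then called `score` and `predict` on `gbc`.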
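To see this behaviour in isolation, here is a minimal sketch. It assumes a small synthetic dataset from `make_classification` in place of your `X_train` and `y_train`, and a tiny grid so it runs quickly; `is_fitted` is a hypothetical helper written for this example, not part of scikit-learn.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_is_fitted

# Synthetic stand-in for the poster's training data (an assumption).
X_demo, y_demo = make_classification(n_samples=120, n_features=6, random_state=0)

gbc = GradientBoostingClassifier(random_state=0)
gbc_grid = GridSearchCV(gbc, {"n_estimators": [5, 10]}, cv=3, refit=True)
gbc_grid.fit(X_demo, y_demo)


def is_fitted(estimator):
    """Hypothetical helper: True if scikit-learn considers the estimator fitted."""
    try:
        check_is_fitted(estimator)
        return True
    except NotFittedError:
        return False


original_fitted = is_fitted(gbc)                   # the object passed to GridSearchCV
best_fitted = is_fitted(gbc_grid.best_estimator_)  # the refitted clone
same_object = gbc_grid.best_estimator_ is gbc

print(original_fitted, best_fitted, same_object)
```

The original `gbc` reports as unfitted while `gbc_grid.best_estimator_` is a separate, fitted object, which is why calling `gbc.score` raises `NotFittedError`.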

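Here are the changes you need to make:

1. Use `gbc_grid.best_estimator_` to get the fitted best model instance; with `refit=True`, `gbc_grid.score` and `gbc_grid.predict` delegate to it.

2. Change `gbc.score(X_train, y_train)` to `gbc_grid.score(X_train, y_train)`.

3. Change `gbc.predict(X_valid)` to `gbc_grid.predict(X_valid)`.

The modified code should look like this: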

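    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import GridSearchCV

    gbc = GradientBoostingClassifier()
    # Example grid; substitute the values you want to search
    param_test = {'n_estimators': [100, 200],
                  'min_samples_split': [2, 5],
                  'max_depth': [3, 4]}
    gbc_grid = GridSearchCV(gbc, param_test, cv=5, refit=True, verbose=1)

    # Train the model using X_train and y_train
    gbc_grid.fit(X_train, y_train)

    # Score and predict through gbc_grid, which delegates to the refitted best estimator
    train_score = gbc_grid.score(X_train, y_train)
    gbc_pre = gbc_grid.predict(X_valid)
    valid_score = accuracy_score(y_valid, gbc_pre)

    print(valid_score)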
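If you prefer to work with the best model directly, the following sketch shows one way to pull it out of the search. It assumes synthetic data from `make_classification` and a deliberately small grid; `best_gbc`, `cv_score`, and the two `valid_score_*` names are chosen only for this illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split

# Synthetic stand-in for the poster's data (an assumption).
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.25, random_state=0
)

gbc_grid = GridSearchCV(
    GradientBoostingClassifier(random_state=0),
    {"n_estimators": [10, 20], "max_depth": [2, 3]},
    cv=3,
    refit=True,
)
gbc_grid.fit(X_train, y_train)

best_gbc = gbc_grid.best_estimator_   # fitted model refit on all of X_train
best_params = gbc_grid.best_params_   # winning parameter combination
cv_score = gbc_grid.best_score_       # mean cross-validated accuracy of the winner

# Both routes use the same fitted model, so the accuracies should agree.
valid_score_direct = accuracy_score(y_valid, best_gbc.predict(X_valid))
valid_score_via_grid = gbc_grid.score(X_valid, y_valid)

print(best_params, cv_score, valid_score_direct, valid_score_via_grid)
```

Note that `best_estimator_` is only available when `refit` is left enabled, as it is in your code.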


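Note that each entry in `param_test` must be a list of concrete values for `GridSearchCV` to try; the grid above is only an example and differs from the one in your post. Also make sure `X_train` and `X_valid` have the same form as the data the model was fitted on (for example, both pandas DataFrames with the same columns). The `UserWarning` about feature names in your output most likely has the same cause as the error: the unfitted `gbc` has no recorded feature names, so the warning should disappear once you call `gbc_grid` instead.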

This reply comes from the FishC AI assistant. If it did not answer your question correctly, please follow up.