鱼C论坛

 找回密码
 立即注册
查看: 2936|回复: 2

minirocket 数据维度

[复制链接]
发表于 2023-10-11 16:35:24 | 显示全部楼层
这个错误是因为 MiniRocket 要求输入的时间序列长度至少为 9,而您的数据长度只有 1,因此需要使用零填充将其长度增加到 9。您可以使用 tsai 库中的 pad_min_len 函数来实现这一点。以下是修改后的代码:

  1. from tsai.all import MiniRocket, pad_min_len
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.model_selection import train_test_split
  5. from tsai.basics import *
  6. from tsai.models.MINIROCKET import MiniRocketClassifier

  7. df1 = pd.read_csv("train.csv")
  8. df1 = np.array(df1)
  9. X = df1[:, 1:891]
  10. X = np.expand_dims(df1[:, 1:891].astype(float), axis=2)
  11. Y = df1[:, 0]
  12. print(X.shape, Y.shape)
  13. df2 = pd.read_csv("test.csv")
  14. df2 = np.array(df1)
  15. X_test = df2[:, 1:891]

  16. # 使用 pad_min_len 函数将 X 的长度增加到 9
  17. X = pad_min_len(X, 9)

  18. X_train, X_valid, y_train, y_valid = train_test_split(X, Y, test_size=0.8, random_state=42)
  19. print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape)
  20. model = MiniRocketClassifier()
  21. model.fit(X_train, y_train)

  22. cls = MiniRocket()
  23. y = model.score(X_valid, y_valid)
  24. print(y)
复制代码


希望能够帮到您!
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-10-28 07:03

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表