|
|
发表于 2023-10-11 16:35:24
|
显示全部楼层
这个错误是因为 MiniRocket 要求输入的时间序列长度至少为 9,而您的数据长度只有 1,因此需要使用零填充将其长度增加到 9。您可以使用 tsai 库中的 pad_min_len 函数来实现这一点。以下是修改后的代码:
- from tsai.all import MiniRocket, pad_min_len
- import numpy as np
- import pandas as pd
- from sklearn.model_selection import train_test_split
- from tsai.basics import *
- from tsai.models.MINIROCKET import MiniRocketClassifier
- df1 = pd.read_csv("train.csv")
- df1 = np.array(df1)
- X = df1[:, 1:891]
- X = np.expand_dims(df1[:, 1:891].astype(float), axis=2)
- Y = df1[:, 0]
- print(X.shape, Y.shape)
- df2 = pd.read_csv("test.csv")
- df2 = np.array(df1)
- X_test = df2[:, 1:891]
- # 使用 pad_min_len 函数将 X 的长度增加到 9
- X = pad_min_len(X, 9)
- X_train, X_valid, y_train, y_valid = train_test_split(X, Y, test_size=0.8, random_state=42)
- print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape)
- model = MiniRocketClassifier()
- model.fit(X_train, y_train)
- cls = MiniRocket()
- y = model.score(X_valid, y_valid)
- print(y)
复制代码
希望能够帮到您! |
|