The spectral data has 855 samples and 890 variables. Feeding it into MiniRocket raises an error:
X.shape=(855,890),Y.shape=(855,1)
Exception has occurred: ValueError
n_timepoints must be >= 9, but found 1; zero pad shorter series so that n_timepoints == 9
File "D:\0000可见光2\程序\MiniRocket\Test1.py", line 20, in <module>
model.fit(X_train, y_train)
ValueError: n_timepoints must be >= 9, but found 1; zero pad shorter series so that n_timepoints == 9
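The message comes from the array layout. As far as I know, tsai's MiniRocketClassifier wraps sktime's multivariate MiniRocket, which reads a 3D array as (n_samples, n_vars, n_timepoints) and needs at least 9 points along the last axis. Adding the new axis at position 2 turns the (855, 890) spectra into 890 variables of length 1, hence "found 1". To treat each spectrum as a single variable whose 890 wavelengths form the sequence, add the axis at position 1 instead. Below is a minimal numpy sketch that uses random data as a stand-in for train.csv. The helper name `to_single_channel` is hypothetical.

```python
import numpy as np

def to_single_channel(X2d):
    """Turn a (n_samples, n_wavelengths) matrix into (n_samples, 1, n_wavelengths),
    so each spectrum is one variable whose points run along the last axis."""
    X2d = np.asarray(X2d, dtype=float)
    if X2d.ndim != 2:
        raise ValueError(f"expected a 2D array, got shape {X2d.shape}")
    return X2d[:, np.newaxis, :]

# random stand-in for the 855 x 890 spectral matrix (not the real data)
rng = np.random.default_rng(0)
spectra = rng.random((855, 890))

wrong = np.expand_dims(spectra, axis=2)  # 890 variables, 1 timepoint each
right = to_single_channel(spectra)       # 1 variable, 890 timepoints

print(wrong.shape, right.shape)  # → (855, 890, 1) (855, 1, 890)
# the series length MiniRocket checks is the size of the last axis
print(wrong.shape[-1] >= 9, right.shape[-1] >= 9)  # → False True
```

The same transform would also have to be applied to X_test so that training and test arrays share the (n_samples, 1, 890) layout. Zero-padding, as the message suggests, is meant for genuinely short series. Padding a single value up to 9 points would not recover the spectral shape.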
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tsai.models.MINIROCKET import MiniRocketClassifier

# train.csv: column 0 = label, columns 1-890 = the 890 spectral variables
df1 = np.array(pd.read_csv("train.csv"))
# adds a trailing axis: X.shape == (855, 890, 1), i.e. 890 variables with 1 timepoint each
X = np.expand_dims(df1[:, 1:891].astype(float), axis=2)
Y = df1[:, 0]
print(X.shape, Y.shape)

# test.csv is read with the same column layout as train.csv
df2 = np.array(pd.read_csv("test.csv"))
X_test = np.expand_dims(df2[:, 1:891].astype(float), axis=2)  # same layout as X

# note: test_size=0.8 leaves only 20% of the samples for training
X_train, X_valid, y_train, y_valid = train_test_split(X, Y, test_size=0.8, random_state=42)
print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape)

model = MiniRocketClassifier()
model.fit(X_train, y_train)  # raises the ValueError shown above
score = model.score(X_valid, y_valid)  # accuracy on the validation split
print(score)
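A second point, separate from the shape error: with a float test_size, scikit-learn rounds the test count up and gives the remainder to training. This is my reading of its split-size logic, so treat it as an assumption. For 855 samples, test_size=0.8 therefore trains the classifier on only 171 spectra, while 0.2 would be the more usual proportion. `split_sizes` below is a hypothetical helper, not part of scikit-learn.

```python
import math

def split_sizes(n_samples, test_size):
    """Train/test counts for a float test_size, assuming scikit-learn's rule:
    the test count is rounded up and the rest goes to training."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a float strictly between 0 and 1")
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

print(split_sizes(855, 0.8))  # → (171, 684)
print(split_sizes(855, 0.2))  # → (684, 171)
```

Because train_test_split indexes along the first axis only, the same counts apply whether X is 2D or 3D.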