|
楼主 |
发表于 2024-10-20 15:50:43
|
显示全部楼层
获取第k折的训练集和验证集函数def get_k_fold_data(k, i, X, y):
assert k > 1
fold_size = X.shape[0] // k
X_train, y_train = None, None
for j in range(k):
idx = slice(j * fold_size, (j + 1) * fold_size)
X_part, y_part = X[idx, :], y[idx]
if j == i:
X_valid, y_valid = X_part, y_part
elif X_train is None:
X_train, y_train = X_part, y_part
else:
X_train = torch.cat([X_train, X_part], 0)
y_train = torch.cat([y_train, y_part], 0)
return X_train, y_train, X_valid, y_valid
按k的顺序迭代for i in range(k):
X_train, X_valid, y_train, y_valid = get_k_fold_data(k, i, X, y) #获取第k折的训练集和验证集
print(f'FOLD {i}')
print('--------------------------------')
# 将数据转换为NumPy数组,然后再转换为PyTorch张量
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.float32)
X_valid = torch.tensor(X_valid.values, dtype=torch.float32)
y_valid = torch.tensor(y_valid.values, dtype=torch.float32)
# 获取一个数据迭代器
train_dataset = data.TensorDataset(X_train, y_train)
valid_dataset = data.TensorDataset(X_valid, y_valid)
train_iter = data.DataLoader(train_dataset, batch_size, shuffle=True)
valid_iter = data.DataLoader(valid_dataset, batch_size, shuffle=True)
报错如下---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:3805, in Index.get_loc(self, key)
3804 try:
-> 3805 return self._engine.get_loc(casted_key)
3806 except KeyError as err:
File index.pyx:167, in pandas._libs.index.IndexEngine.get_loc()
File index.pyx:173, in pandas._libs.index.IndexEngine.get_loc()
TypeError: '(slice(0, 142, None), slice(None, None, None))' is an invalid key
During handling of the above exception, another exception occurred:
InvalidIndexError Traceback (most recent call last)
Cell In[116], line 62
59 nn_model = SimpleNN()
61 for i in range(k):
---> 62 X_train, X_valid, y_train, y_valid = get_k_fold_data(k, i, X, y) #获取第k折的训练集和验证集
63 print(f'FOLD {i}')
64 print('--------------------------------')
Cell In[116], line 43, in get_k_fold_data(k, i, X, y)
41 for j in range(k):
42 idx = slice(j * fold_size, (j + 1) * fold_size)
---> 43 X_part, y_part = X[idx, :], y[idx]
44 if j == i:
45 X_valid, y_valid = X_part, y_part
File /opt/conda/lib/python3.10/site-packages/pandas/core/frame.py:4102, in DataFrame.__getitem__(self, key)
4100 if self.columns.nlevels > 1:
4101 return self._getitem_multilevel(key)
-> 4102 indexer = self.columns.get_loc(key)
4103 if is_integer(indexer):
4104 indexer = [indexer]
File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:3817, in Index.get_loc(self, key)
3812 raise KeyError(key) from err
3813 except TypeError:
3814 # If we have a listlike key, _check_indexing_error will raise
3815 # InvalidIndexError. Otherwise we fall through and re-raise
3816 # the TypeError.
-> 3817 self._check_indexing_error(key)
3818 raise
File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:6059, in Index._check_indexing_error(self, key)
6055 def _check_indexing_error(self, key):
6056 if not is_scalar(key):
6057 # if key is not a scalar, directly raise an error (the code below
6058 # would convert to numpy arrays and raise later any way) - GH29926
-> 6059 raise InvalidIndexError(key)
InvalidIndexError: (slice(0, 142, None), slice(None, None, None))
|
|