|

楼主 |
发表于 2024-10-20 15:50:43
|
显示全部楼层
获取第k折的训练集和验证集函数
- def get_k_fold_data(k, i, X, y):
- assert k > 1
- fold_size = X.shape[0] // k
- X_train, y_train = None, None
- for j in range(k):
- idx = slice(j * fold_size, (j + 1) * fold_size)
- X_part, y_part = X[idx, :], y[idx]
- if j == i:
- X_valid, y_valid = X_part, y_part
- elif X_train is None:
- X_train, y_train = X_part, y_part
- else:
- X_train = torch.cat([X_train, X_part], 0)
- y_train = torch.cat([y_train, y_part], 0)
- return X_train, y_train, X_valid, y_valid
复制代码
按k的顺序迭代
- for i in range(k):
- X_train, X_valid, y_train, y_valid = get_k_fold_data(k, i, X, y) #获取第k折的训练集和验证集
- print(f'FOLD {i}')
- print('--------------------------------')
- # 将数据转换为NumPy数组,然后再转换为PyTorch张量
- X_train = torch.tensor(X_train.values, dtype=torch.float32)
- y_train = torch.tensor(y_train.values, dtype=torch.float32)
- X_valid = torch.tensor(X_valid.values, dtype=torch.float32)
- y_valid = torch.tensor(y_valid.values, dtype=torch.float32)
- # 获取一个数据迭代器
- train_dataset = data.TensorDataset(X_train, y_train)
- valid_dataset = data.TensorDataset(X_valid, y_valid)
- train_iter = data.DataLoader(train_dataset, batch_size, shuffle=True)
- valid_iter = data.DataLoader(valid_dataset, batch_size, shuffle=True)
复制代码
报错如下- ---------------------------------------------------------------------------
- TypeError Traceback (most recent call last)
- File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:3805, in Index.get_loc(self, key)
- 3804 try:
- -> 3805 return self._engine.get_loc(casted_key)
- 3806 except KeyError as err:
- File index.pyx:167, in pandas._libs.index.IndexEngine.get_loc()
- File index.pyx:173, in pandas._libs.index.IndexEngine.get_loc()
- TypeError: '(slice(0, 142, None), slice(None, None, None))' is an invalid key
- During handling of the above exception, another exception occurred:
- InvalidIndexError Traceback (most recent call last)
- Cell In[116], line 62
- 59 nn_model = SimpleNN()
- 61 for i in range(k):
- ---> 62 X_train, X_valid, y_train, y_valid = get_k_fold_data(k, i, X, y) #获取第k折的训练集和验证集
- 63 print(f'FOLD {i}')
- 64 print('--------------------------------')
- Cell In[116], line 43, in get_k_fold_data(k, i, X, y)
- 41 for j in range(k):
- 42 idx = slice(j * fold_size, (j + 1) * fold_size)
- ---> 43 X_part, y_part = X[idx, :], y[idx]
- 44 if j == i:
- 45 X_valid, y_valid = X_part, y_part
- File /opt/conda/lib/python3.10/site-packages/pandas/core/frame.py:4102, in DataFrame.__getitem__(self, key)
- 4100 if self.columns.nlevels > 1:
- 4101 return self._getitem_multilevel(key)
- -> 4102 indexer = self.columns.get_loc(key)
- 4103 if is_integer(indexer):
- 4104 indexer = [indexer]
- File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:3817, in Index.get_loc(self, key)
- 3812 raise KeyError(key) from err
- 3813 except TypeError:
- 3814 # If we have a listlike key, _check_indexing_error will raise
- 3815 # InvalidIndexError. Otherwise we fall through and re-raise
- 3816 # the TypeError.
- -> 3817 self._check_indexing_error(key)
- 3818 raise
- File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/base.py:6059, in Index._check_indexing_error(self, key)
- 6055 def _check_indexing_error(self, key):
- 6056 if not is_scalar(key):
- 6057 # if key is not a scalar, directly raise an error (the code below
- 6058 # would convert to numpy arrays and raise later any way) - GH29926
- -> 6059 raise InvalidIndexError(key)
- InvalidIndexError: (slice(0, 142, None), slice(None, None, None))
复制代码 |
|