鱼C论坛

 找回密码
 立即注册
查看: 335|回复: 1

预测数据时报错是什么原因

[复制链接]
发表于 2024-11-10 00:06:33 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
报错如下
  1. ---------------------------------------------------------------------------
  2. ValueError                                Traceback (most recent call last)
  3. File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/range.py:413, in RangeIndex.get_loc(self, key)
  4.     412 try:
  5. --> 413     return self._range.index(new_key)
  6.     414 except ValueError as err:

  7. ValueError: 0 is not in range

  8. The above exception was the direct cause of the following exception:

  9. KeyError                                  Traceback (most recent call last)
  10. Cell In[46], line 1
  11. ----> 1 pre_valid = inference(model,valid_eval_dataloader)#测试集预测值

  12. Cell In[42], line 6, in inference(model, test_dataloader)
  13.       3 model.eval()
  14.       5 with torch.no_grad():
  15. ----> 6     for batch in tqdm(test_dataloader):
  16.       7         input_ids = batch['input_ids'].to(device)
  17.       8         attention_mask = batch['attention_mask'].to(device)

  18. File /opt/conda/lib/python3.10/site-packages/tqdm/notebook.py:250, in tqdm_notebook.__iter__(self)
  19.     248 try:
  20.     249     it = super().__iter__()
  21. --> 250     for obj in it:
  22.     251         # return super(tqdm...) will not catch exception
  23.     252         yield obj
  24.     253 # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt

  25. File /opt/conda/lib/python3.10/site-packages/tqdm/std.py:1181, in tqdm.__iter__(self)
  26.    1178 time = self._time
  27.    1180 try:
  28. -> 1181     for obj in iterable:
  29.    1182         yield obj
  30.    1183         # Update and possibly print the progressbar.
  31.    1184         # Note: does not call self.update(1) for speed optimisation.

  32. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
  33.     627 if self._sampler_iter is None:
  34.     628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
  35.     629     self._reset()  # type: ignore[call-arg]
  36. --> 630 data = self._next_data()
  37.     631 self._num_yielded += 1
  38.     632 if self._dataset_kind == _DatasetKind.Iterable and \
  39.     633         self._IterableDataset_len_called is not None and \
  40.     634         self._num_yielded > self._IterableDataset_len_called:

  41. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
  42.     671 def _next_data(self):
  43.     672     index = self._next_index()  # may raise StopIteration
  44. --> 673     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  45.     674     if self._pin_memory:
  46.     675         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

  47. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:52, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
  48.      50         data = self.dataset.__getitems__(possibly_batched_index)
  49.      51     else:
  50. ---> 52         data = [self.dataset[idx] for idx in possibly_batched_index]
  51.      53 else:
  52.      54     data = self.dataset[possibly_batched_index]

  53. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:52, in <listcomp>(.0)
  54.      50         data = self.dataset.__getitems__(possibly_batched_index)
  55.      51     else:
  56. ---> 52         data = [self.dataset[idx] for idx in possibly_batched_index]
  57.      53 else:
  58.      54     data = self.dataset[possibly_batched_index]

  59. Cell In[26], line 15, in LLMDataset.__getitem__(self, idx)
  60.      14 def __getitem__(self,idx):
  61. ---> 15     text = self.df.loc[idx,'cleaned'] # extracting text from each row
  62.      17     encoded_dict = self.tokenizer.encode_plus(
  63.      18         text,
  64.      19         add_special_tokens=True,#自动在每个文本前后添加特殊标记(如CLS和SEP)
  65.    (...)
  66.      24         return_attention_mask=True, # We should put it into the model,计算注意力(attention)时忽略那些paddle值
  67.      25     )
  68.      27     if self.is_grad:#训练集

  69. File /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1183, in _LocationIndexer.__getitem__(self, key)
  70.    1181     key = tuple(com.apply_if_callable(x, self.obj) for x in key)
  71.    1182     if self._is_scalar_access(key):
  72. -> 1183         return self.obj._get_value(*key, takeable=self._takeable)
  73.    1184     return self._getitem_tuple(key)
  74.    1185 else:
  75.    1186     # we by definition only have the 0th axis

  76. File /opt/conda/lib/python3.10/site-packages/pandas/core/frame.py:4221, in DataFrame._get_value(self, index, col, takeable)
  77.    4215 engine = self.index._engine
  78.    4217 if not isinstance(self.index, MultiIndex):
  79.    4218     # CategoricalIndex: Trying to use the engine fastpath may give incorrect
  80.    4219     #  results if our categories are integers that dont match our codes
  81.    4220     # IntervalIndex: IntervalTree has no get_loc
  82. -> 4221     row = self.index.get_loc(index)
  83.    4222     return series._values[row]
  84.    4224 # For MultiIndex going through engine effectively restricts us to
  85.    4225 #  same-length tuples; see test_get_set_value_no_partial_indexing

  86. File /opt/conda/lib/python3.10/site-packages/pandas/core/indexes/range.py:415, in RangeIndex.get_loc(self, key)
  87.     413         return self._range.index(new_key)
  88.     414     except ValueError as err:
  89. --> 415         raise KeyError(key) from err
  90.     416 if isinstance(key, Hashable):
  91.     417     raise KeyError(key)

  92. KeyError: 0
复制代码


相关代码如下

  1. from torch.utils.data import Dataset
  2. import torch

  3. #定义数据集
  4. class LLMDataset(Dataset):
  5.     def __init__(self,df,is_grad,tokenizer):
  6.         self.df = df # Pandas.DataFrame
  7.         self.is_grad = is_grad # True: train,valid / False: test
  8.         self.tokenizer = tokenizer

  9.     def __len__(self):
  10.         return len(self.df) # number of samples

  11.     def __getitem__(self,idx):
  12.         text = self.df.loc[idx,'cleaned'] # extracting text from each row
  13.         
  14.         encoded_dict = self.tokenizer.encode_plus(
  15.             text,
  16.             add_special_tokens=True,#自动在每个文本前后添加特殊标记(如CLS和SEP)
  17.             padding='max_length',#补0
  18.             truncation=True,#句子长度大于max_length时截断
  19.             max_length=512, # given to the max_length of tokenized text
  20.             return_tensors='pt', # PyTorch
  21.             return_attention_mask=True, # We should put it into the model,计算注意力(attention)时忽略那些paddle值
  22.         )

  23.         if self.is_grad:#训练集
  24.             labels = self.df.loc[idx]['label']
  25.             # [batch,1,max_len(84)] -> [batch,max_len]#使用squeeze降维
  26.             return {'input_ids':encoded_dict['input_ids'].squeeze(),
  27.                     'attention_mask':encoded_dict['attention_mask'].squeeze(),
  28.                     'token_type_ids':encoded_dict['token_type_ids'].squeeze(),
  29.                    # Our loss_fn wants it to be a "float" type
  30.                     'labels':torch.tensor(labels,dtype=torch.float).unsqueeze(dim=0)}
  31.         else:#测试集
  32.             # [batch,1,max_len(84)] -> [batch,max_len]
  33.             return {'input_ids':encoded_dict['input_ids'].squeeze(),
  34.                     'attention_mask':encoded_dict['attention_mask'].squeeze(),
  35.                    'token_type_ids':encoded_dict['token_type_ids'].squeeze()}model = LLMModel(configs['model_name']).to(device)
  36. def inference(model,test_dataloader):
  37.     all_preds = []
  38.     model.eval()

  39.     with torch.no_grad():
  40.         for batch in tqdm(test_dataloader):
  41.             input_ids = batch['input_ids'].to(device)
  42.             attention_mask = batch['attention_mask'].to(device)
  43.             
  44.             logits = model(input_ids,attention_mask)
  45.             logits = logits.detach().cpu().numpy()
  46.             all_preds.append(logits)
  47.    
  48.     return all_preds
  49. valid_dataset = LLMDataset(train_df[train_size:].drop('label',axis=1),False,tokenizer)
  50. valid_eval_dataloader = DataLoader(valid_dataset,batch_size=8,shuffle=False,pin_memory=True)
  51. pre_valid = inference(model,valid_eval_dataloader)#测试集预测值
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-11-10 00:06:44 | 显示全部楼层
在分析错误信息和代码后,我发现您的问题可能在于:

1. **`.loc` 用法错误**:从错误信息来看,出现了 `KeyError: 0`。在 `__getitem__` 方法中,您尝试直接使用 `self.df.loc` 来提取文本,但 `loc` 通常需要指定键或索引。正确用法应该是 `self.df.loc[idx]`,根据数据框的索引来获取行数据。

2. **缺少列名称**:稍微修改一下获取列数据的方法。如果您的数据框包含某个列名(比如 "text"),则需要指定该列名,比如 `self.df.loc[idx, 'text']` 来提取文本。

3. **Tensor 类型问题**:`encoded_dict` 的键值在 `encode_plus` 返回的是字典,您要确保对齐数据类型,比如 `encoded_dict['input_ids'].squeeze()`。

以下是更新后的 `__getitem__` 方法:

  1. def __getitem__(self, idx):
  2.     text = self.df.loc[idx, 'text']  # 假设 'text' 是列名
  3.    
  4.     encoded_dict = self.tokenizer.encode_plus(
  5.         text,
  6.         add_special_tokens=True,
  7.         padding='max_length',
  8.         truncation=True,
  9.         max_length=512,
  10.         return_tensors='pt',
  11.         return_attention_mask=True
  12.     )

  13.     # 查找相对列名是重要的, 下方假设列名为 'label'
  14.     if self.is_grad:
  15.         labels = self.df.loc[idx, 'label']
  16.         return {
  17.             'input_ids': encoded_dict['input_ids'].squeeze(),
  18.             'attention_mask': encoded_dict['attention_mask'].squeeze(),
  19.             'token_type_ids': encoded_dict['token_type_ids'].squeeze(),
  20.             'labels': torch.tensor(labels, dtype=torch.float).unsqueeze(dim=0)
  21.         }
  22.     else:
  23.         return {
  24.             'input_ids': encoded_dict['input_ids'].squeeze(),
  25.             'attention_mask': encoded_dict['attention_mask'].squeeze(),
  26.             'token_type_ids': encoded_dict['token_type_ids'].squeeze()
  27.         }
复制代码


通过这样的修改,基本上大多数涉及错误 `KeyError` 的原因就能解决。根据具体的列名稍微调整代码,并处理可能的Pandas DataFrame的索引问题。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-5-8 04:14

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表