鱼C论坛

 找回密码
 立即注册
查看: 338|回复: 8

如何在生成的inputs字典中正确加入labels这个键值对

[复制链接]
发表于 2024-11-3 13:44:52 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. def prepare_input(cfg, text):
  2.     inputs = cfg.tokenizer.encode_plus(
  3.         text,
  4.         return_tensors=None,
  5.         add_special_tokens=True,#自动在每个文本前后添加特殊标记(如CLS和SEP)
  6.         return_attention_mask=True # 计算注意力(attention)时忽略那些paddle值
  7.     )
  8.    
  9.     if len(inputs['input_ids']) > CFG.max_len:#句子长度大于max_length时截断
  10.         inputs['input_ids'] = inputs['input_ids'][:CFG.max_len]
  11.         inputs['attention_mask'] = inputs['attention_mask'][:CFG.max_len]
  12.         inputs['token_type_ids'] = inputs['token_type_ids'][:CFG.max_len]

  13.     for k, v in inputs.items():
  14.         inputs[k] = torch.tensor(v, dtype=torch.long)
  15.     return inputs


  16. class LLMDataset(Dataset):
  17.     def __init__(self, cfg, df,is_grad):
  18.         self.cfg = cfg
  19.         self.df = df
  20.         self.texts = df["cleaned"].values
  21.         self.is_grad = is_grad

  22.     def __len__(self):
  23.         return len(self.texts)

  24.     def __getitem__(self,idx):
  25.         inputs = prepare_input(self.cfg, self.texts[idx])#inputs是一个字典
  26.         if self.is_grad:#训练集
  27.             labels = self.df.loc[idx]['label']#标签数据形状为
  28.             
  29.             return {'input_ids':inputs['input_ids'],
  30.                     'attention_mask':inputs['attention_mask'],
  31.                     "token_type_ids": inputs['token_type_ids'],
  32.                     # Our loss_fn wants it to be a "float" type
  33.                     'labels':torch.tensor(labels,dtype=torch.float).unsqueeze(dim=0)}
复制代码


如何正确把labels加入字典inputs中
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-11-3 13:51:04 | 显示全部楼层
要在生成的 `inputs` 字典中正确加入 `labels` 键值对,你需要确保在 `prepare_input` 函数中返回的 `inputs` 是一个字典,并且在 `__getitem__` 方法中正确地将 `labels` 添加到这个字典中。以下是修改后的代码:

  1. def prepare_input(cfg, text):
  2.     inputs = cfg.tokenizer.encode_plus(
  3.         text,
  4.         return_tensors=None,
  5.         add_special_tokens=True,
  6.         return_attention_mask=True
  7.     )
  8.    
  9.     if len(inputs['input_ids']) > cfg.max_len:
  10.         inputs['input_ids'] = inputs['input_ids'][:cfg.max_len]
  11.         inputs['attention_mask'] = inputs['attention_mask'][:cfg.max_len]
  12.         if 'token_type_ids' in inputs:
  13.             inputs['token_type_ids'] = inputs['token_type_ids'][:cfg.max_len]

  14.     for k, v in inputs.items():
  15.         inputs[k] = torch.tensor(v, dtype=torch.long)
  16.     return inputs

  17. class LLMDataset(Dataset):
  18.     def __init__(self, cfg, df, is_grad):
  19.         self.cfg = cfg
  20.         self.df = df
  21.         self.texts = df['text'].values  # 假设文本列名为 'text'
  22.         self.is_grad = is_grad

  23.     def __len__(self):
  24.         return len(self.texts)

  25.     def __getitem__(self, idx):
  26.         inputs = prepare_input(self.cfg, self.texts[idx])
  27.         if self.is_grad:
  28.             labels = self.df.loc[idx, 'label']  # 假设标签列名为 'label'
  29.             inputs['labels'] = torch.tensor(labels, dtype=torch.float).unsqueeze(dim=0)
  30.         
  31.         return inputs
复制代码


在这个修改后的代码中:

1. `prepare_input` 函数返回一个包含 `input_ids`、`attention_mask` 和可能的 `token_type_ids` 的字典。
2. 在 `__getitem__` 方法中,使用 `self.df.loc[idx, 'label']` 获取标签值,并将其添加到 `inputs` 字典中。
3. 确保在 `prepare_input` 中对 `inputs` 的每个值都转换为 `torch.tensor`。

请根据你的数据框架的实际列名调整代码中的 `'text'` 和 `'label'

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-11-3 16:07:10 | 显示全部楼层
FishC 发表于 2024-11-3 13:51
要在生成的 `inputs` 字典中正确加入 `labels` 键值对,你需要确保在 `prepare_input` 函数中返回的 `input ...
  1. ---------------------------------------------------------------------------
  2. KeyError                                  Traceback (most recent call last)
  3. Cell In[41], line 4
  4.       1 print('Training Start!')
  5.       2 print('=' * 100)
  6. ----> 4 train(model,
  7.       5     device,
  8.       6     train_dataloader,
  9.       7     valid_dataloader,
  10.       8     CFG.epochs,
  11.       9     loss_fn,
  12.      10     optimizer,
  13.      11     metric)
  14.      13 del model,train_dataloader, valid_dataloader
  15.      14 gc.collect()

  16. Cell In[40], line 25, in train(model, device, train_dataloader, valid_dataloader, epochs, loss_fn, optimizer, metric)
  17.      23 train_attention_mask = batch['attention_mask'].to(device)
  18.      24 train_token_type_ids = batch['token_type_ids'].to(device)
  19. ---> 25 train_labels = batch['labels'].squeeze().to(device).long()#label真实值long()转化成一维张量
  20.      27 # You can refer to the class "TweetsModel" for understand
  21.      28 # what would be logits
  22.      29 logits = model(train_input_ids, train_attention_mask,train_token_type_ids).to(device)

  23. KeyError: 'labels'
复制代码


报错信息如上,代码如下
  1. print('Training Start!')
  2. print('=' * 100)

  3. train(model,
  4.     device,
  5.     train_dataloader,
  6.     valid_dataloader,
  7.     CFG.epochs,
  8.     loss_fn,
  9.     optimizer,
  10.     metric)

  11. del model,train_dataloader, valid_dataloader
  12. gc.collect()
复制代码
  1. import gc,os
  2. from tqdm.auto import tqdm # visualizing tool for progress

  3. # They will be used to pick the best model.pt given to the valid loss
  4. best_model_epoch, valid_loss_values = [],[]
  5. valid_loss_min = [1] # arbitrary loss I set here
  6. def train(model,device,train_dataloader,valid_dataloader,epochs,loss_fn,optimizer,metric):

  7.     for epoch in range(epochs):
  8.         gc.collect() # memory cleaning垃圾回收机制,减少占用内存
  9.         model.train()

  10.         train_loss = 0
  11.         train_step = 0
  12.         pbar = tqdm(train_dataloader, total=len(train_dataloader))#tqdm参数是一个iterable

  13.         for batch in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  14.             optimizer.zero_grad() # initialize
  15.             train_step += 1
  16.             

  17.             train_input_ids = batch['input_ids'].to(device)#batch是一个字典
  18.             train_attention_mask = batch['attention_mask'].to(device)
  19.             train_token_type_ids = batch['token_type_ids'].to(device)
  20.             train_labels = batch['labels'].squeeze().to(device).long()#label真实值long()转化成一维张量
  21.             
  22.             # You can refer to the class "TweetsModel" for understand
  23.             # what would be logits
  24.             logits = model(train_input_ids, train_attention_mask,train_token_type_ids).to(device)
  25.             predictions = torch.argmax(logits, dim=1) # get an index from larger one
  26.             detached_predictions = predictions.detach().cpu().numpy()
  27.             
  28.             loss = loss_fn(logits, train_labels)
  29.             loss.backward()
  30.             optimizer.step()
  31.             model.zero_grad()

  32.             train_loss += loss.detach().cpu().numpy().item()

  33.             pbar.set_postfix({'train_loss':train_loss/train_step})#设置进度条显示信息
  34.         pbar.close()

  35.         with torch.no_grad():
  36.             model.eval()

  37.             valid_loss = 0
  38.             valid_step = 0
  39.             total_valid_score = 0

  40.             y_pred = [] # for getting f1_score that is a metric of the competition
  41.             y_true = []

  42.             pbar = tqdm(valid_dataloader)
  43.             for batch,labels in pbar:
  44.                 valid_step += 1

  45.                 valid_input_ids = batch['input_ids'].to(device)
  46.                 valid_attention_mask = batch['attention_mask'].to(device)
  47.                 valid_token_type_ids = batch['token_type_ids'].to(device)
  48.                 valid_labels = batch['labels'].squeeze().to(device).long()

  49.                 logits = model(valid_input_ids, valid_attention_mask).to(device)
  50.                 predictions = torch.argmax(logits, dim=1)
  51.                 detached_predictions = predictions.detach().cpu().numpy()
  52.                
  53.                 loss = loss_fn(logits, valid_labels)
  54.                 valid_loss += loss.detach().cpu().numpy().item()

  55.                 y_pred.extend(predictions.cpu().numpy())
  56.                 y_true.extend(valid_labels.cpu().numpy())

  57.             valid_loss /= valid_step
  58.             f1 = f1_score(y_true,y_pred)

  59.             print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
  60.             print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')

  61.             if valid_loss < min(valid_loss_min):
  62.                 print('model improved!')
  63.             else:
  64.                 print('model not improved')
  65.    
  66.             torch.save(model.state_dict(), f'epoch:{epoch+1}_model.pt')#state_dict 是一个字典对象,包含了模型的所有可学习参数(如权重和偏置)及其当前值
  67.             print('save checkpoint!')
  68.             valid_loss_min.append(valid_loss)
  69.             print(f'valid_loss_min:{min(valid_loss_min)}')

  70.         best_model_epoch.append(f'/kaggle/working/epoch:{epoch+1}_model.pt')
  71.         valid_loss_values.append(valid_loss)
  72.         print('='*100)

  73.     select_best_model() # refer to below function
  74.     print('Train/Valid Completed!!')
  75.     del train_dataloader, valid_dataloader # memory cleaning
  76.     gc.collect()

  77. def select_best_model():
  78.     best_model = best_model_epoch[np.array(valid_loss_values).argmin()]
  79.     os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')#重命名文件
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-11-3 16:13:25 | 显示全部楼层
Terence888 发表于 2024-11-03 16:07
-------------------------------------------------- ...


从错误信息来看,`KeyError: 'labels'` 表示在 `batch` 字典中找不到 `labels` 键。要解决这个问题,你需要确保在数据加载器中正确地将 `labels` 包含在 `batch` 中。

以下是一些可能的解决方案:

1. **检查数据集的 `__getitem__` 方法**:确保在 `LLMDataset` 的 `__getitem__` 方法中,`inputs` 字典中包含 `labels` 键。

2. **检查数据加载器**:确保在创建数据加载器时,`collate_fn` 函数正确地将 `labels` 包含在 `batch` 中。

3. **修改 `train` 函数**:在 `train` 函数中,确保从 `batch` 中正确地提取 `labels`。

以下是一些可能的代码修改:

### 修改 `LLMDataset` 的 `__getitem__` 方法

确保 `inputs` 字典中包含 `labels`:

  1. def __getitem__(self, idx):
  2.     inputs = prepare_input(self.cfg, self.texts[idx])
  3.     if self.is_grad:
  4.         labels = self.df.loc[idx, 'label']  # 假设标签列名为 'label'
  5.         inputs['labels'] = torch.tensor(labels, dtype=torch.float).unsqueeze(dim=0)
  6.     return inputs
复制代码


### 修改 `train` 函数

确保从 `batch` 中正确地提取 `labels`:

  1. for batch in pbar:
  2.     optimizer.zero_grad()
  3.     train_step += 1

  4.     train_input_ids = batch['input_ids'].to(device)
  5.     train_attention_mask = batch['attention_mask'].to(device)
  6.     train_token_type_ids = batch['token_type_ids'].to(device)
  7.     train_labels = batch['labels'].squeeze().to(device).long()

  8.     logits = model(train_input_ids, train_attention_mask, train_token_type_ids).to(device)
  9.     predictions = torch.argmax(logits, dim=1)
  10.     detached_predictions = predictions.detach().cpu().numpy()

  11.     loss = loss_fn(logits, train_labels)
  12.     loss.backward()
  13.     optimizer.step()
  14.     model.zero_grad()

  15.     train_loss += loss.detach().cpu().numpy().item()

  16.     pbar.set_postfix({'train_loss': train_loss / train_step})
复制代码


确保在数据集和数据加载器中正确地处理 `labels`,这样在训练过程中就不会出现 `KeyError

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-11-3 16:13:30 | 显示全部楼层
Terence888 发表于 2024-11-03 16:07
-------------------------------------------------- ...


从错误信息来看,`KeyError: 'labels'` 表示在数据加载器中没有找到 `labels` 这个键。要解决这个问题,你需要确保在数据集的 `__getitem__` 方法中正确地将 `labels` 键值对添加到返回的字典中。

以下是一些可能的解决方案:

1. **检查数据集的 `__getitem__` 方法**:确保 `inputs` 字典中包含 `labels` 键。

2. **确保数据加载器返回的批次包含 `labels`**:在训练循环中,`batch` 应该是一个包含 `labels` 的字典。

以下是如何修改 `LLMDataset` 类的 `__getitem__` 方法以确保 `labels` 被正确添加:

  1. def __getitem__(self, idx):
  2.     inputs = prepare_input(self.cfg, self.texts[idx])
  3.     if self.is_grad:
  4.         labels = self.df.iloc[idx]['label']  # 假设标签列名为 'label'
  5.         inputs['labels'] = torch.tensor(labels, dtype=torch.float).unsqueeze(dim=0)
  6.     return inputs
复制代码


3. **检查训练循环**:在训练循环中,确保你正确地从 `batch` 中提取 `labels`:

  1. for batch in pbar:
  2.     train_input_ids = batch['input_ids'].to(device)
  3.     train_attention_mask = batch['attention_mask'].to(device)
  4.     train_token_type_ids = batch['token_type_ids'].to(device)
  5.     train_labels = batch['labels'].squeeze().to(device).long()
复制代码


确保数据集和数据加载器的实现与模型训练代码一致,这样可以避免 `KeyError`。如果 `labels` 是从数据集中提取的,确保数据集的每个样本都包含 `labels

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-11-3 16:19:11 | 显示全部楼层
FishC 发表于 2024-11-3 16:13
从错误信息来看,`KeyError: 'labels'` 表示在 `batch` 字典中找不到 `labels` 键。要解决这个问题,你 ...
  1. class CollateCls:
  2.     def __init__(self, cfg):
  3.         self.tokenizer = cfg.tokenizer
  4.         self.cfg = cfg
  5.         
  6.     def __call__(self, batch):
  7.         output = dict()
  8.         output["input_ids"] = [sample["input_ids"] for sample in batch]
  9.         output["attention_mask"] = [sample["attention_mask"] for sample in batch]
  10.         output["token_type_ids"] = [sample["token_type_ids"] for sample in batch]
  11.         # calculate max token length of this batch
  12.         batch_max = max([len(ids) for ids in output["input_ids"]])

  13.         # add padding
  14.         if self.tokenizer.padding_side == "right":
  15.             output["input_ids"] = [
  16.                 list(s) + (batch_max - len(s)) * [self.tokenizer.pad_token_id]
  17.                 for s in output["input_ids"]
  18.             ]
  19.             output["attention_mask"] = [
  20.                 list(s) + (batch_max - len(s)) * [0] for s in output["attention_mask"]
  21.             ]
  22.             output["token_type_ids"] = [list(s) + (batch_max - len(s)) * [0] for s in output["token_type_ids"]]

  23.         else:
  24.             output["input_ids"] = [
  25.                 (batch_max - len(s)) * [self.tokenizer.pad_token_id] + list(s)
  26.                 for s in output["input_ids"]
  27.             ]
  28.             output["attention_mask"] = [
  29.                 (batch_max - len(s)) * [0] + list(s) for s in output["attention_mask"]
  30.             ]
  31.             output["token_type_ids"] = [(batch_max - len(s)) * [0] + list(s) for s in output["token_type_ids"]]
  32.         
  33.         
  34.         # convert to tensors
  35.         output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
  36.         output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
  37.         output["token_type_ids"] = torch.tensor(output["token_type_ids"], dtype=torch.long)
  38.         
  39.         return output
复制代码


如何确保在创建数据加载器时,`collate_fn` 函数正确地将 `labels` 包含在 `batch` 中
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-11-3 16:19:22 | 显示全部楼层
Terence888 发表于 2024-11-03 16:19
class CollateCls:
    def __init__(self, cfg):
  ...


要确保在创建数据加载器时,`collate_fn` 函数正确地将 `labels` 包含在 `batch` 中,你需要在 `CollateCls` 类的 `__call__` 方法中处理 `labels`。假设每个样本是一个字典,包含 `input_ids` 和 `labels`,你可以这样修改 `CollateCls`:

  1. class CollateCls:
  2.     def __init__(self, cfg):
  3.         self.tokenizer = cfg.tokenizer
  4.         self.cfg = cfg
  5.         
  6.     def __call__(self, batch):
  7.         input_ids = [sample['input_ids'] for sample in batch]
  8.         labels = [sample['labels'] for sample in batch]

  9.         # calculate max token length of this batch
  10.         batch_max = max([len(ids) for ids in input_ids])

  11.         # add padding
  12.         if self.tokenizer.padding_side == "right":
  13.             input_ids = [
  14.                 list(ids) + (batch_max - len(ids)) * [self.tokenizer.pad_token_id]
  15.                 for ids in input_ids
  16.             ]
  17.         else:
  18.             input_ids = [
  19.                 (batch_max - len(ids)) * [self.tokenizer.pad_token_id] + list(ids)
  20.                 for ids in input_ids
  21.             ]
  22.         
  23.         # convert to tensors
  24.         input_ids = torch.tensor(input_ids, dtype=torch.long)
  25.         labels = torch.tensor(labels, dtype=torch.float)  # or torch.long, depending on your use case
  26.         
  27.         return {'input_ids': input_ids, 'labels': labels}
复制代码


在这个修改后的代码中:

1. **提取 `input_ids` 和 `labels`**:从 `batch` 中分别提取 `input_ids` 和 `labels`。
2. **计算最大长度并填充**:根据 `input_ids` 的最大长度进行填充。
3. **转换为张量**:将 `input_ids` 和 `labels` 转换为 PyTorch 张量。
4. **返回字典**:返回一个包含 `input_ids` 和 `labels` 的字典。

这样,`collate_fn` 函数就能正确地将 `labels` 包含在 `batch` 中。确保在创建数据加载器时使用这个 `collate_fn

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-11-3 16:27:04 | 显示全部楼层
FishC 发表于 2024-11-3 16:19
要确保在创建数据加载器时,`collate_fn` 函数正确地将 `labels` 包含在 `batch` 中,你需要在 `Collat ...
  1. ---------------------------------------------------------------------------
  2. TypeError                                 Traceback (most recent call last)
  3. Cell In[50], line 4
  4.       1 print('Training Start!')
  5.       2 print('=' * 100)
  6. ----> 4 train(model,
  7.       5     device,
  8.       6     train_dataloader,
  9.       7     valid_dataloader,
  10.       8     CFG.epochs,
  11.       9     loss_fn,
  12.      10     optimizer,
  13.      11     metric)
  14.      13 del model,train_dataloader, valid_dataloader
  15.      14 gc.collect()

  16. Cell In[40], line 17, in train(model, device, train_dataloader, valid_dataloader, epochs, loss_fn, optimizer, metric)
  17.      14 train_step = 0
  18.      15 pbar = tqdm(train_dataloader, total=len(train_dataloader))#tqdm参数是一个iterable
  19. ---> 17 for batch in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  20.      18     optimizer.zero_grad() # initialize
  21.      19     train_step += 1

  22. File /opt/conda/lib/python3.10/site-packages/tqdm/notebook.py:250, in tqdm_notebook.__iter__(self)
  23.     248 try:
  24.     249     it = super().__iter__()
  25. --> 250     for obj in it:
  26.     251         # return super(tqdm...) will not catch exception
  27.     252         yield obj
  28.     253 # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt

  29. File /opt/conda/lib/python3.10/site-packages/tqdm/std.py:1181, in tqdm.__iter__(self)
  30.    1178 time = self._time
  31.    1180 try:
  32. -> 1181     for obj in iterable:
  33.    1182         yield obj
  34.    1183         # Update and possibly print the progressbar.
  35.    1184         # Note: does not call self.update(1) for speed optimisation.

  36. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
  37.     627 if self._sampler_iter is None:
  38.     628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
  39.     629     self._reset()  # type: ignore[call-arg]
  40. --> 630 data = self._next_data()
  41.     631 self._num_yielded += 1
  42.     632 if self._dataset_kind == _DatasetKind.Iterable and \
  43.     633         self._IterableDataset_len_called is not None and \
  44.     634         self._num_yielded > self._IterableDataset_len_called:

  45. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1344, in _MultiProcessingDataLoaderIter._next_data(self)
  46.    1342 else:
  47.    1343     del self._task_info[idx]
  48. -> 1344     return self._process_data(data)

  49. File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1370, in _MultiProcessingDataLoaderIter._process_data(self, data)
  50.    1368 self._try_put_index()
  51.    1369 if isinstance(data, ExceptionWrapper):
  52. -> 1370     data.reraise()
  53.    1371 return data

  54. File /opt/conda/lib/python3.10/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
  55.     702 except TypeError:
  56.     703     # If the exception takes multiple arguments, don't try to
  57.     704     # instantiate since we don't know how to
  58.     705     raise RuntimeError(msg) from None
  59. --> 706 raise exception

  60. TypeError: Caught TypeError in DataLoader worker process 0.
  61. Original Traceback (most recent call last):
  62.   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
  63.     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  64.   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
  65.     return self.collate_fn(data)
  66.   File "/tmp/ipykernel_30/3299407265.py", line 89, in __call__
  67.     output['labels'] = torch.tensor(output['labels'], dtype=torch.long)
  68. TypeError: only integer tensors of a single element can be converted to an index
复制代码

再次报错如上
函数如何修改
  1. class CollateCls:
  2.     def __init__(self, cfg):
  3.         self.tokenizer = cfg.tokenizer
  4.         self.cfg = cfg
  5.         
  6.     def __call__(self, batch):
  7.         output = dict()
  8.         output["input_ids"] = [sample["input_ids"] for sample in batch]
  9.         output["attention_mask"] = [sample["attention_mask"] for sample in batch]
  10.         output["token_type_ids"] = [sample["token_type_ids"] for sample in batch]
  11.         output['labels'] = [sample["labels"] for sample in batch]
  12.         # calculate max token length of this batch
  13.         batch_max = max([len(ids) for ids in output["input_ids"]])

  14.         # add padding
  15.         if self.tokenizer.padding_side == "right":
  16.             output["input_ids"] = [
  17.                 list(s) + (batch_max - len(s)) * [self.tokenizer.pad_token_id]
  18.                 for s in output["input_ids"]
  19.             ]
  20.             output["attention_mask"] = [
  21.                 list(s) + (batch_max - len(s)) * [0] for s in output["attention_mask"]
  22.             ]
  23.             output["token_type_ids"] = [list(s) + (batch_max - len(s)) * [0] for s in output["token_type_ids"]]

  24.         else:
  25.             output["input_ids"] = [
  26.                 (batch_max - len(s)) * [self.tokenizer.pad_token_id] + list(s)
  27.                 for s in output["input_ids"]
  28.             ]
  29.             output["attention_mask"] = [
  30.                 (batch_max - len(s)) * [0] + list(s) for s in output["attention_mask"]
  31.             ]
  32.             output["token_type_ids"] = [(batch_max - len(s)) * [0] + list(s) for s in output["token_type_ids"]]
  33.         
  34.         
  35.         # convert to tensors
  36.         output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
  37.         output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
  38.         output["token_type_ids"] = torch.tensor(output["token_type_ids"], dtype=torch.long)
  39.         output['labels'] = torch.tensor(output['labels'],dtype=torch.float).unsqueeze(dim=0)
  40.         
  41.         return output
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-11-3 16:27:16 | 显示全部楼层
Terence888 发表于 2024-11-03 16:27
-------------------------------------------------- ...


从错误信息来看,`TypeError: only integer tensors of a single element can be converted to an index` 表示在尝试将 `output` 转换为张量时出现了问题。这个问题通常是因为 `output` 的结构不正确,可能是因为 `output` 包含了不正确的数据类型或格式。

以下是如何修改 `CollateCls` 类的 `__call__` 方法,以确保正确处理 `batch` 中的 `input_ids` 和 `labels`:

  1. class CollateCls:
  2.     def __init__(self, cfg):
  3.         self.tokenizer = cfg.tokenizer
  4.         self.cfg = cfg
  5.         
  6.     def __call__(self, batch):
  7.         input_ids = [sample['input_ids'] for sample in batch]
  8.         labels = [sample['labels'] for sample in batch]

  9.         # calculate max token length of this batch
  10.         batch_max = max([len(ids) for ids in input_ids])

  11.         # add padding
  12.         if self.tokenizer.padding_side == "right":
  13.             input_ids = [
  14.                 ids + [self.tokenizer.pad_token_id] * (batch_max - len(ids))
  15.                 for ids in input_ids
  16.             ]
  17.         else:
  18.             input_ids = [
  19.                 [self.tokenizer.pad_token_id] * (batch_max - len(ids)) + ids
  20.                 for ids in input_ids
  21.             ]
  22.         
  23.         # convert to tensors
  24.         input_ids = torch.tensor(input_ids, dtype=torch.long)
  25.         labels = torch.tensor(labels, dtype=torch.float)  # or torch.long, depending on your use case
  26.         
  27.         return {'input_ids': input_ids, 'labels': labels}
复制代码


### 关键修改:

1. **提取 `input_ids` 和 `labels`**:从 `batch` 中分别提取 `input_ids` 和 `labels`。
2. **填充 `input_ids`**:根据 `batch_max` 的长度进行填充,使用 `self.tokenizer.pad_token_id` 作为填充值。
3. **转换为张量**:将 `input_ids` 和 `labels` 转换为 PyTorch 张量。
4. **返回字典**:返回一个包含 `input_ids` 和 `labels` 的字典。

确保在数据集的 `__getitem__` 方法中返回的每个样本是一个包含 `input_ids` 和 `labels` 的字典,这样 `collate_fn` 函数才能正确处理这些数据。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-18 16:38

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表