鱼C论坛

 找回密码
 立即注册
查看: 350|回复: 1

transformer编码器解码器架构疑问

[复制链接]
发表于 2024-11-2 20:57:51 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. def prepare_input(cfg, text):
  2.     inputs = cfg.tokenizer.encode_plus(
  3.         text,
  4.         return_tensors=None,
  5.         add_special_tokens=True,
  6.     )
  7.    
  8.     if len(inputs['input_ids']) > CFG.max_len:
  9.         inputs['input_ids'] = inputs['input_ids'][:CFG.max_len]
  10.         inputs['attention_mask'] = inputs['attention_mask'][:CFG.max_len]
  11.         inputs['token_type_ids'] = inputs['token_type_ids'][:CFG.max_len]

  12.     for k, v in inputs.items():
  13.         inputs[k] = torch.tensor(v, dtype=torch.long)
  14.     return inputs


  15. class TestDataset(Dataset):
  16.     def __init__(self, cfg, df):
  17.         self.cfg = cfg
  18.         self.texts = df["full_text"].values

  19.     def __len__(self):
  20.         return len(self.texts)

  21.     def __getitem__(self, item):
  22.         inputs = prepare_input(self.cfg, self.texts[item])

  23.         return {
  24.             "input_ids": inputs["input_ids"],
  25.             "attention_mask": inputs["attention_mask"],
  26.             "token_type_ids": inputs['token_type_ids'],
  27.         }


  28. class CollateCls:
  29.     def __init__(self, cfg):
  30.         self.tokenizer = cfg.tokenizer
  31.         self.cfg = cfg
  32.         
  33.     def __call__(self, batch):
  34.         output = dict()
  35.         output["input_ids"] = [sample["input_ids"] for sample in batch]
  36.         output["attention_mask"] = [sample["attention_mask"] for sample in batch]
  37.         output["token_type_ids"] = [sample["token_type_ids"] for sample in batch]
  38.         # calculate max token length of this batch
  39.         batch_max = max([len(ids) for ids in output["input_ids"]])

  40.         # add padding
  41.         if self.tokenizer.padding_side == "right":
  42.             output["input_ids"] = [
  43.                 list(s) + (batch_max - len(s)) * [self.tokenizer.pad_token_id]
  44.                 for s in output["input_ids"]
  45.             ]
  46.             output["attention_mask"] = [
  47.                 list(s) + (batch_max - len(s)) * [0] for s in output["attention_mask"]
  48.             ]
  49.             output["token_type_ids"] = [list(s) + (batch_max - len(s)) * [0] for s in output["token_type_ids"]]

  50.         else:
  51.             output["input_ids"] = [
  52.                 (batch_max - len(s)) * [self.tokenizer.pad_token_id] + list(s)
  53.                 for s in output["input_ids"]
  54.             ]
  55.             output["attention_mask"] = [
  56.                 (batch_max - len(s)) * [0] + list(s) for s in output["attention_mask"]
  57.             ]
  58.             output["token_type_ids"] = [(batch_max - len(s)) * [0] + list(s) for s in output["token_type_ids"]]
  59.         
  60.         
  61.         # convert to tensors
  62.         output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
  63.         output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
  64.         output["token_type_ids"] = torch.tensor(output["token_type_ids"], dtype=torch.long)
  65.         
  66.         return output
复制代码


其中,
  1. inputs = cfg.tokenizer.encode_plus(
  2.         text,
  3.         return_tensors=None,
  4.         add_special_tokens=True,
  5.     )
  6.    
复制代码

该代码生成的inputs是什么形状的,
  1. tokenizer.encode_plus
复制代码
这个函数会生成哪些列
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-11-2 20:58:02 | 显示全部楼层
`tokenizer.encode_plus` 是一个用于将文本编码为模型可接受格式的函数。它通常会生成一个字典,包含以下几个键值对:

1. **`input_ids`**: 这是文本被编码后的 token ID 列表,表示输入文本的实际内容。

2. **`token_type_ids`**: 这是一个可选的列表,用于区分不同的句子(在 BERT 中用于区分句子 A 和句子 B)。

3. **`attention_mask`**: 这是一个可选的列表,通常用于指示哪些 token 是填充(padding)token,哪些是实际的输入 token。填充值为 0,实际输入值为 1。

4. **`special_tokens_mask`**: 这是一个可选的列表,用于指示哪些 token 是特殊 token(如 [CLS], [SEP])。

5. **`offset_mapping`**: 这是一个可选的列表,表示每个 token 在原始文本中的起始和结束位置。

在你的代码中,`inputs` 的形状取决于输入文本的长度和 `max_len` 的设置。通常情况下,`input_ids`、`token_type_ids` 和 `attention_mask` 的长度会与输入文本的 token 数量一致。

如果 `inputs` 的长度超过 `CFG.max_len`,代码中似乎没有处理逻辑(可能是代码片段不完整或有误),需要根据实际需求进行截断或填充。

如果你有更多问题或需要进一步的帮助,请随时告诉我!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-18 17:09

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表