Terence888 发表于 2024-11-2 20:57:51

transformer编码器解码器架构疑问

def prepare_input(cfg, text):
    inputs = cfg.tokenizer.encode_plus(
      text,
      return_tensors=None,
      add_special_tokens=True,
    )
   
    if len(inputs['input_ids']) > CFG.max_len:
      inputs['input_ids'] = inputs['input_ids'][:CFG.max_len]
      inputs['attention_mask'] = inputs['attention_mask'][:CFG.max_len]
      inputs['token_type_ids'] = inputs['token_type_ids'][:CFG.max_len]

    for k, v in inputs.items():
      inputs = torch.tensor(v, dtype=torch.long)
    return inputs


class TestDataset(Dataset):
    def __init__(self, cfg, df):
      self.cfg = cfg
      self.texts = df["full_text"].values

    def __len__(self):
      return len(self.texts)

    def __getitem__(self, item):
      inputs = prepare_input(self.cfg, self.texts)

      return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "token_type_ids": inputs['token_type_ids'],
      }


class CollateCls:
    def __init__(self, cfg):
      self.tokenizer = cfg.tokenizer
      self.cfg = cfg
      
    def __call__(self, batch):
      output = dict()
      output["input_ids"] = for sample in batch]
      output["attention_mask"] = for sample in batch]
      output["token_type_ids"] = for sample in batch]
      # calculate max token length of this batch
      batch_max = max(])

      # add padding
      if self.tokenizer.padding_side == "right":
            output["input_ids"] = [
                list(s) + (batch_max - len(s)) *
                for s in output["input_ids"]
            ]
            output["attention_mask"] = [
                list(s) + (batch_max - len(s)) * for s in output["attention_mask"]
            ]
            output["token_type_ids"] = for s in output["token_type_ids"]]

      else:
            output["input_ids"] = [
                (batch_max - len(s)) * + list(s)
                for s in output["input_ids"]
            ]
            output["attention_mask"] = [
                (batch_max - len(s)) * + list(s) for s in output["attention_mask"]
            ]
            output["token_type_ids"] = [(batch_max - len(s)) * + list(s) for s in output["token_type_ids"]]
      
      
      # convert to tensors
      output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
      output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
      output["token_type_ids"] = torch.tensor(output["token_type_ids"], dtype=torch.long)
      
      return output

其中,inputs = cfg.tokenizer.encode_plus(
      text,
      return_tensors=None,
      add_special_tokens=True,
    )
   
该代码生成的inputs是什么形状的,tokenizer.encode_plus这个函数会生成哪些列

FishC 发表于 2024-11-2 20:58:02

`tokenizer.encode_plus` 是一个用于将文本编码为模型可接受格式的函数。它通常会生成一个字典,包含以下几个键值对:

1. **`input_ids`**: 这是文本被编码后的 token ID 列表,表示输入文本的实际内容。

2. **`token_type_ids`**: 这是一个可选的列表,用于区分不同的句子(在 BERT 中用于区分句子 A 和句子 B)。

3. **`attention_mask`**: 这是一个可选的列表,通常用于指示哪些 token 是填充(padding)token,哪些是实际的输入 token。填充值为 0,实际输入值为 1。

4. **`special_tokens_mask`**: 这是一个可选的列表,用于指示哪些 token 是特殊 token(如 , )。

5. **`offset_mapping`**: 这是一个可选的列表,表示每个 token 在原始文本中的起始和结束位置。

在你的代码中,`inputs` 的形状取决于输入文本的长度和 `max_len` 的设置。通常情况下,`input_ids`、`token_type_ids` 和 `attention_mask` 的长度会与输入文本的 token 数量一致。

如果 `inputs` 的长度超过 `CFG.max_len`,代码中似乎没有处理逻辑(可能是代码片段不完整或有误),需要根据实际需求进行截断或填充。

如果你有更多问题或需要进一步的帮助,请随时告诉我!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
页: [1]
查看完整版本: transformer编码器解码器架构疑问