import torch
from torch.utils.data import Dataset


def prepare_input(cfg, text):
    # Tokenize one text; return_tensors=None keeps the outputs as plain Python lists
    inputs = cfg.tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=True,
    )
    # Truncate every field to the configured maximum sequence length
    if len(inputs["input_ids"]) > cfg.max_len:
        inputs["input_ids"] = inputs["input_ids"][:cfg.max_len]
        inputs["attention_mask"] = inputs["attention_mask"][:cfg.max_len]
        inputs["token_type_ids"] = inputs["token_type_ids"][:cfg.max_len]
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


class TestDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.texts = df["full_text"].values

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, self.texts[item])
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "token_type_ids": inputs["token_type_ids"],
        }


class CollateCls:
    def __init__(self, cfg):
        self.tokenizer = cfg.tokenizer
        self.cfg = cfg

    def __call__(self, batch):
        output = dict()
        output["input_ids"] = [sample["input_ids"] for sample in batch]
        output["attention_mask"] = [sample["attention_mask"] for sample in batch]
        output["token_type_ids"] = [sample["token_type_ids"] for sample in batch]

        # Calculate the max token length of this batch
        batch_max = max(len(ids) for ids in output["input_ids"])

        # Pad every sample to batch_max, on the side the tokenizer expects
        if self.tokenizer.padding_side == "right":
            output["input_ids"] = [
                list(s) + (batch_max - len(s)) * [self.tokenizer.pad_token_id]
                for s in output["input_ids"]
            ]
            output["attention_mask"] = [
                list(s) + (batch_max - len(s)) * [0] for s in output["attention_mask"]
            ]
            output["token_type_ids"] = [
                list(s) + (batch_max - len(s)) * [0] for s in output["token_type_ids"]
            ]
        else:
            output["input_ids"] = [
                (batch_max - len(s)) * [self.tokenizer.pad_token_id] + list(s)
                for s in output["input_ids"]
            ]
            output["attention_mask"] = [
                (batch_max - len(s)) * [0] + list(s) for s in output["attention_mask"]
            ]
            output["token_type_ids"] = [
                (batch_max - len(s)) * [0] + list(s) for s in output["token_type_ids"]
            ]

        # Convert the padded lists back to tensors
        output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
        output["attention_mask"] = torch.tensor(output["attention_mask"], dtype=torch.long)
        output["token_type_ids"] = torch.tensor(output["token_type_ids"], dtype=torch.long)

        return output
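For context, here is a minimal usage sketch. The CFG class below is a hypothetical stand-in for the config object the original code assumes (it only needs a tokenizer and a max_len), and the DataFrame only needs a full_text column:

import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class CFG:  # hypothetical stand-in for the post's config object
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    max_len = 512

df = pd.DataFrame({"full_text": ["a short essay", "a somewhat longer second essay"]})

loader = DataLoader(TestDataset(CFG, df), batch_size=2, collate_fn=CollateCls(CFG))
batch = next(iter(loader))
# each value is a (batch_size, batch_max_len) LongTensor
print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["token_type_ids"].shape)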
Here,

    inputs = cfg.tokenizer.encode_plus(
        text,
        return_tensors=None,
        add_special_tokens=True,
    )
What is the shape of the inputs this code produces, and which fields does the function generate?
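A minimal sketch of an answer, assuming a BERT-style Hugging Face tokenizer such as bert-base-uncased: with return_tensors=None, encode_plus returns a dict-like BatchEncoding whose values are flat Python lists (not tensors), with one entry per token including the added [CLS] and [SEP] special tokens, so all three lists share the same length. For BERT-family tokenizers the fields are input_ids, token_type_ids, and attention_mask:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer.encode_plus(
    "hello world",
    return_tensors=None,
    add_special_tokens=True,
)
print(inputs)
# {'input_ids': [101, 7592, 2088, 102],
#  'token_type_ids': [0, 0, 0, 0],
#  'attention_mask': [1, 1, 1, 1]}

Note that not every tokenizer emits token_type_ids (RoBERTa-style tokenizers return only input_ids and attention_mask), which is why the code above, which indexes that key unconditionally, assumes a BERT-style model.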