鱼C论坛

 找回密码
 立即注册
查看: 419|回复: 11

自然语言处理疑问

[复制链接]
发表于 2024-10-30 23:45:00 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. import re
  2. import string

  3. # Cleaning Functions
  4. def remove_tag(text):
  5.     tag = re.compile(r'@\S+')#匹配连续字符
  6.     return tag.sub(r'',text)#使用sub函数用空串替换

  7. def remove_URL(text):
  8.     # http:... / https:... / www... #匹配网页链接
  9.     url = re.compile(r'https?://\S+|www\.\S+')
  10.     return re.sub(url,'',text)

  11. def remove_html(text):#匹配特殊符号
  12.     # < > / ( )
  13.     html = re.compile(r'<[^>]+>|\([^)]+\)')
  14.     return html.sub(r'',text)

  15. def remove_punct(text):
  16.     # ['!','"','$','%','&',"'",'(',')','*',
  17.     # '+',',','-','.','/',':',';','<','=',
  18.     # '>','?','@','[','\\',']','^','_','`',
  19.     # '{','|','}','~']
  20.     punctuations = list(string.punctuation)
  21.     table = str.maketrans('', '', ''.join(punctuations))
  22.     return text.translate(table)

复制代码


其中,
  1. r'@\S+'
复制代码
是什么意思
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-10-30 23:50:17 | 显示全部楼层
`r'@\S+'` 是一个正则表达式,用于匹配以 `@` 开头的连续非空白字符序列。它通常用于去除像 Twitter 或社交媒体平台上的用户标签(例如 `@username`)。

- `@` 匹配字符 `@`。
- `\S` 匹配任何非空白字符(相当于 `[^ \t\n\r\f\v]`)。
- `+` 表示前面的模式可以重复一次或多次。

因此,`r'@\S+'` 可以匹配一个 `@` 符号后面跟随的一个或多个非空白字符,等效于提取并移除诸如 `@username` 这样的用户标签。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-10-30 23:51:51 | 显示全部楼层
FishC 发表于 2024-10-30 23:50
`r'@\S+'` 是一个正则表达式,用于匹配以 `@` 开头的连续非空白字符序列。它通常用于去除像 Twitter 或社交 ...
  1. text.translate(table)
复制代码

这个语句表示什么意思
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-30 23:51:59 | 显示全部楼层
Terence888 发表于 2024-10-30 23:51
text.translate(table)
这个语句表示什么意思


`text.translate(table)` 是用来处理字符串的,该语法用于替换字符串中的字符。具体来说,这里 `table` 是通过 `str.maketrans()` 函数生成的一个翻译映射表,用于指定每个字符的替换规则。

在你提供的函数 `remove_punct(text)` 中,`table` 是生成用来去除标点符号的映射表。`str.maketrans('', '', ''.join(punctuations))` 创建了一个转换映射,其中所有在 `punctuations` 列表中的字符都会被替换为空字符(即被删除)。

因此,`text.translate(table)` 会遍历 `text` 中的每一个字符,如果这个字符在 `table` 中存在对应的替换(在这个例子中,被替换成空字符),就进行替换,最终返回一个去除了标点符号的字符串。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-10-31 00:08:44 | 显示全部楼层
FishC 发表于 2024-10-30 23:51
`text.translate(table)` 是用来处理字符串的,该语法用于替换字符串中的字符。具体来说,这里 `table` ...
  1. from torch.utils.data import Dataset
  2. import torch

  3. class TweetDataset(Dataset):
  4.     def __init__(self,df,is_grad,tokenizer):
  5.         self.df = df # Pandas.DataFrame
  6.         self.is_grad = is_grad # True: train,valid / False: test
  7.         self.tokenizer = tokenizer

  8.     def __len__(self):
  9.         return len(self.df) # number of samples

  10.     def __getitem__(self,idx):
  11.         text = self.df.loc[idx]['text'] # extracting text from each row

  12.         encoded_dict = self.tokenizer.encode_plus(
  13.             text,
  14.             add_special_tokens=True,#自动在每个文本前后添加特殊标记(如CLS和SEP)
  15.             padding='max_length',#补0
  16.             truncation=True,#句子长度大于max_length时截断
  17.             max_length=84, # given to the max_length of tokenized text
  18.             return_tensors='pt', # PyTorch
  19.             return_attention_mask=True, # We should put it into the model
  20.         )

  21.         if self.is_grad:
  22.             labels = self.df.loc[idx]['target']
  23.             # [batch,1,max_len(84)] -> [batch,max_len]
  24.             return {'input_ids':encoded_dict['input_ids'].squeeze(),
  25.                     'attention_mask':encoded_dict['attention_mask'].squeeze(),
  26.                     # Our loss_fn wants it to be a "float" type
  27.                     'labels':torch.tensor(labels,dtype=torch.float).unsqueeze(dim=0)}
  28.         else:
  29.             # [batch,1,max_len(84)] -> [batch,max_len]
  30.             return {'input_ids':encoded_dict['input_ids'].squeeze(),
  31.                     'attention_mask':encoded_dict['attention_mask'].squeeze()}
复制代码


上述代码中,编码器的参数
  1. return_attention_mask=True
复制代码
表示什么意思
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-31 00:19:44 | 显示全部楼层
Terence888 发表于 2024-10-31 00:08
from torch.utils.data import Dataset
import torch ...


在自然语言处理中,我们经常需要对文本进行填充和截断,以使得所有输入的序列具有相同的长度。当输入序列经过填充(padding)或截断(truncation)处理后,我们需要知道哪些部分是真实的文本,哪些部分是为了调整长度而添加的填充。

`tokenizer.encode_plus()` 函数中的 `return_attention_mask=True` 参数用于生成一个attention mask。attention mask是一个与输入序列长度相同的二进制矩阵,其中:

- `1` 表示实际的文本标记(token)。
- `0` 表示填充值(padding)。

通过这个mask,模型可以在处理填充的输入序列时忽略填充的部分,从而提高效率并确保模型只注意实际的文本内容。

例如,如果你的最大序列长度是8,而输入的文本序列长度是5,那么填充后的序列可能是这样:
  1. [Actual_Token_1, Actual_Token_2, Actual_Token_3, Actual_Token_4, Actual_Token_5, PAD, PAD, PAD]
复制代码


相应的 `attention_mask` 就会是:
  1. [1, 1, 1, 1, 1, 0, 0, 0]
复制代码


这个 `attention_mask` 告诉模型在计算注意力(attention)时忽略那些填充值(即值为0的部分)。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-10-31 10:38:42 | 显示全部楼层
FishC 发表于 2024-10-31 00:19
在自然语言处理中,我们经常需要对文本进行填充和截断,以使得所有输入的序列具有相同的长度。当输入序 ...
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. from transformers import BertForSequenceClassification

  5. # Never Detach Tensor during forward
  6. class TweetsModel(nn.Module):
  7.     '''
  8.     To be honest, under the setting like this, there is no need to inherit.
  9.     It's because I used "BertForSequenceClassification" which has final layer
  10.     that is composed of "hidden size 2" for binary classification.

  11.     So, you can think of this unnecessary inheritance is kind of "practice" for myself :)
  12.     '''   
  13.     def __init__(self,model_name):
  14.         super().__init__()
  15.         self.model = BertForSequenceClassification.from_pretrained(model_name)#from_pretrained方法会加载相应的预训练模型权重

  16.     def forward(self,input_ids,attention_mask):
  17.         output = self.model(input_ids=input_ids,attention_mask=attention_mask)
  18.         logits = output.logits
  19.         return logits
复制代码


最后一行logits是计算什么的
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-31 10:38:48 | 显示全部楼层
Terence888 发表于 2024-10-31 10:38
import numpy as np
import torch
import torch.nn  ...


在这段代码中,`logits` 是模型的输出,表示未经过激活函数处理的原始预测值。具体来说:

- `BertForSequenceClassification` 是一个用于序列分类任务的预训练 BERT 模型。它的最后一层是一个线性层,用于将 BERT 的输出转换为分类任务的输出。
- `logits` 是这个线性层的输出,通常用于二分类或多分类任务。在二分类任务中,`logits` 通常包含两个值,分别对应于两个类别的分数。
- 这些分数可以通过激活函数(如 softmax)转换为概率,用于最终的分类决策。

在训练过程中,`logits` 通常会被传递给损失函数(如交叉熵损失),以计算模型的损失并进行反向传播。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-10-31 10:50:17 | 显示全部楼层
FishC 发表于 2024-10-31 10:38
在这段代码中,`logits` 是模型的输出,表示未经过激活函数处理的原始预测值。具体来说:

- `BertFo ...
  1. # optimizer
  2. from transformers import AdamW

  3. optimizer = AdamW(model.parameters(),
  4.                 lr=6e-6,
  5.                 eps=1e-8,
  6.                 no_deprecation_warning=True)
复制代码


上述是不是gpt模型?优化器最后一个参数
  1. no_deprecation_warning=True
复制代码
表示什么意思
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-31 18:29:23 | 显示全部楼层
Terence888 发表于 2024-10-31 10:50
上述是不是gpt模型?优化器最后一个参数表示什么意思

可以自己查
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-10-31 20:56:42 | 显示全部楼层
FishC 发表于 2024-10-31 10:38
在这段代码中,`logits` 是模型的输出,表示未经过激活函数处理的原始预测值。具体来说:

- `BertFo ...
  1. import gc,os
  2. from tqdm.auto import tqdm # visualizing tool for progress

  3. # They will be used to pick the best model.pt given to the valid loss
  4. best_model_epoch, valid_loss_values = [],[]
  5. valid_loss_min = [1] # arbitrary loss I set here
  6. def train(model,device,train_dataloader,valid_dataloader,epochs,loss_fn,optimizer,metric):

  7.     for epoch in range(epochs):
  8.         gc.collect() # memory cleaning垃圾回收机制,减少占用内存
  9.         model.train()

  10.         train_loss = 0
  11.         train_step = 0
  12.         pbar = tqdm(train_dataloader)#tqdm参数是一个iterable

  13.         for batch in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  14.             optimizer.zero_grad() # initialize
  15.             train_step += 1

  16.             train_input_ids = batch['input_ids'].to(device)
  17.             train_attention_mask = batch['attention_mask'].to(device)
  18.             train_labels = batch['labels'].squeeze().to(device).long()
  19.             
  20.             # You can refer to the class "TweetsModel" for understand
  21.             # what would be logits
  22.             logits = model(train_input_ids, train_attention_mask).to(device)
  23.             predictions = torch.argmax(logits, dim=1) # get an index from larger one
  24.             detached_predictions = predictions.detach().cpu().numpy()
  25.             
  26.             loss = loss_fn(logits, train_labels)
  27.             loss.backward()
  28.             optimizer.step()
  29.             model.zero_grad()

  30.             train_loss += loss.detach().cpu().numpy().item()

  31.             pbar.set_postfix({'train_loss':train_loss/train_step})
  32.         pbar.close()

  33.         with torch.no_grad():
  34.             model.eval()

  35.             valid_loss = 0
  36.             valid_step = 0
  37.             total_valid_score = 0

  38.             y_pred = [] # for getting f1_score that is a metric of the competition
  39.             y_true = []

  40.             pbar = tqdm(valid_dataloader)
  41.             for batch in pbar:
  42.                 valid_step += 1

  43.                 valid_input_ids = batch['input_ids'].to(device)
  44.                 valid_attention_mask = batch['attention_mask'].to(device)
  45.                 valid_labels = batch['labels'].squeeze().to(device).long()

  46.                 logits = model(valid_input_ids, valid_attention_mask).to(device)
  47.                 predictions = torch.argmax(logits, dim=1)
  48.                 detached_predictions = predictions.detach().cpu().numpy()
  49.                
  50.                 loss = loss_fn(logits, valid_labels)
  51.                 valid_loss += loss.detach().cpu().numpy().item()

  52.                 y_pred.extend(predictions.cpu().numpy())
  53.                 y_true.extend(valid_labels.cpu().numpy())

  54.             valid_loss /= valid_step
  55.             f1 = f1_score(y_true,y_pred)

  56.             print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
  57.             print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')

  58.             if valid_loss < min(valid_loss_min):
  59.                 print('model improved!')
  60.             else:
  61.                 print('model not improved')
  62.    
  63.             torch.save(model.state_dict(), f'epoch:{epoch+1}_model.pt')
  64.             print('save checkpoint!')
  65.             valid_loss_min.append(valid_loss)
  66.             print(f'valid_loss_min:{min(valid_loss_min)}')

  67.         best_model_epoch.append(f'/kaggle/working/epoch:{epoch+1}_model.pt')
  68.         valid_loss_values.append(valid_loss)
  69.         print('='*100)

  70.     select_best_model() # refer to below function
  71.     print('Train/Valid Completed!!')
  72.     del train_dataloader, valid_dataloader # memory cleaning
  73.     gc.collect()

  74. def select_best_model():
  75.     best_model = best_model_epoch[np.array(valid_loss_values).argmin()]
  76.     os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')
复制代码


其中最后一行不理解
  1. os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-31 21:03:09 | 显示全部楼层
Terence888 发表于 2024-10-31 20:56
import gc,os
from tqdm.auto import tqdm # visuali ...


在代码的最后一行 `os.rename(best_model, best_model.split('.pt') + '_best.pt')` 中,`os.rename` 函数用于重命名文件或目录。

然而,这行代码有一些问题:

1. `best_model` 是一个列表(`best_model_epoch`),而不是一个字符串。`os.rename` 需要两个字符串参数:源文件名和目标文件名。
2. `best_model.split('.pt')` 试图对列表进行字符串操作,这是不正确的。
3. `split('.pt')` 返回一个列表,因此 `+ '_best.pt'` 也会导致错误。

要修复这段代码,首先需要确保 `best_model` 是一个字符串,代表要重命名的文件路径。假设 `best_model_epoch` 中存储的是文件路径的字符串列表,代码应该如下:

  1. def select_best_model():
  2.     best_model = best_model_epoch[-1]  # 假设最后一个是最好的模型
  3.     new_name = best_model.replace('.pt', '_best.pt')
  4.     os.rename(best_model, new_name)
复制代码


这样,`best_model` 是一个字符串,`replace` 方法用于生成新的文件名,并且 `os.rename` 可以正确执行重命名操作。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-18 16:59

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表