鱼C论坛

 找回密码
 立即注册
查看: 3687|回复: 7

[技术交流] python实现双向LSTM+CRF【tensorflow2.4.0】【命名实体识别】

[复制链接]
发表于 2021-11-14 15:03:49 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2021-12-28 19:52 编辑

一、数据预处理
  1. char_vocab_path = "E:/.../1.NLP/zh-nlp-demo-master/data/char_vocabs.txt" # 字典文件
  2. train_data_path = "E:/.../1.NLP/地址识别项目/data/train.conll" # 训练数据
  3. test_data_path = "E:/.../1.NLP/地址识别项目/data/dev.conll" # 测试数据

  4. special_words = ['<PAD>', '<UNK>'] # 特殊词表示
  5. '''
  6. <UNK>: 低频词或未在词表中的词
  7. <PAD>: 补全字符
  8. <GO>/<SOS>: 句子起始标识符
  9. <EOS>: 句子结束标识符
  10. [SEP]:两个句子之间的分隔符
  11. [MASK]:填充被掩盖掉的字符
  12. '''

  13. # "BIO"标记的标签
  14. import pandas as pd
  15. store = pd.read_table(r"E:\工作\7.理论学习\1.NLP\地址识别项目\data\mytag.dic", header = None)
  16. store.loc[0, 0] = "O"
  17. store.loc[24, 0] = "B-prov"

  18. store1 = store.to_dict()
  19. idx2label = store1[0]
  20. # 索引和BIO标签对应
  21. label2idx = {idx: label for label, idx in idx2label.items()}
  22. print(label2idx)
  23. # 读取字符词典文件
  24. with open(char_vocab_path, "r", encoding="utf8") as fo:
  25.     char_vocabs = [line.strip() for line in fo]
  26. char_vocabs = special_words + char_vocabs

  27. # 字符和索引编号对应
  28. idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}
  29. vocab2idx = {char: idx for idx, char in idx2vocab.items()}
复制代码

二、数据统计描述
  1. import pandas as pd
  2. temp = pd.read_table(r"E:\...\1.NLP\地址识别项目\data\train.txt", header = None)
  3. temp["长度"] = temp.loc[:, 0].apply(lambda x: len(str(x)))
  4. print(temp.head())
  5. print(max(temp["长度"]))

  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. %matplotlib inline
  9. plt.hist(np.array(temp["长度"]), bins=10, rwidth=0.9, density=True)
复制代码


三、模型定义
  1. import tensorflow as tf
  2. import tensorflow_addons as tfa
  3. print(tf.__version__)
  4. print(tfa.__version__)
  5. from tensorflow import keras
  6. from tensorflow.keras import layers, models
  7. from tensorflow.keras import backend as K

  8. class CRF(layers.Layer):
  9.     def __init__(self, label_size):
  10.         super(CRF, self).__init__()
  11.         self.trans_params = tf.Variable(
  12.             tf.random.uniform(shape=(label_size, label_size)), name="transition")
  13.    
  14.     @tf.function
  15.     def call(self, inputs, labels, seq_lens):
  16.         log_likelihood, self.trans_params = tfa.text.crf_log_likelihood(
  17.                                                 inputs, labels, seq_lens,
  18.                                                 transition_params=self.trans_params)
  19.         loss = tf.reduce_sum(-log_likelihood)
  20.         return loss

  21. from transformers import TFBertForTokenClassification

  22. EPOCHS = 20
  23. BATCH_SIZE = 64
  24. EMBED_DIM = 128
  25. HIDDEN_SIZE = 64
  26. MAX_LEN = 55
  27. VOCAB_SIZE = len(vocab2idx)
  28. CLASS_NUMS = len(label2idx)

  29. inputs = layers.Input(shape=(MAX_LEN,), dtype='int32')
  30. targets = layers.Input(shape=(MAX_LEN,),dtype='int32')
  31. seq_lens = layers.Input(shape=(), dtype='int32')

  32. PRETRAINED_MODEL_NAME = r"D:\bert_model\bert-base-chinese"  # 指定为中文
  33. #x = TFBertForTokenClassification.from_pretrained(PRETRAINED_MODEL_NAME, num_labels = 100)(inputs)

  34. x = layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBED_DIM, mask_zero=True)(inputs)
  35. x = layers.Bidirectional(layers.LSTM(HIDDEN_SIZE, return_sequences=True))(x)
  36. print(x.shape)
  37. logits = layers.Dense(CLASS_NUMS)(x)
  38. loss = CRF(label_size=CLASS_NUMS)(logits, targets, seq_lens)

  39. model = models.Model(inputs=[inputs, targets, seq_lens], outputs=loss)

  40. print(model.summary())
  41. model.compile(loss=lambda y_true, y_pred: y_pred, optimizer='adam')#,  metrics=[metric])
复制代码



四、数据处理
  1. from tensorflow.keras.preprocessing import sequence
  2. import numpy as np

  3. # 读取训练语料
  4. def read_corpus(corpus_path, vocab2idx, label2idx):
  5.     datas, labels = [], []
  6.     with open(corpus_path, encoding='utf-8') as fr:
  7.         lines = fr.readlines()
  8.     sent_, tag_ = [], []
  9.     for line in lines:
  10.         if line != '\n':
  11.             char, label = line.strip().split()
  12.             sent_.append(char)
  13.             tag_.append(label)
  14.         else:
  15.             sent_ids = [vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_]
  16.             tag_ids = [label2idx[label] if label in label2idx else 0 for label in tag_]
  17.             datas.append(sent_ids)
  18.             labels.append(tag_ids)
  19.             sent_, tag_ = [], []
  20.     return datas, labels

  21. # 加载训练集
  22. train_datas, train_labels = read_corpus(train_data_path, vocab2idx, label2idx)
  23. # 加载测试集
  24. test_datas, test_labels = read_corpus(test_data_path, vocab2idx, label2idx)

  25. train_datas = sequence.pad_sequences(train_datas, maxlen=MAX_LEN, padding='post')
  26. train_labels = sequence.pad_sequences(train_labels, maxlen=MAX_LEN, padding='post')
  27. train_seq_lens = np.array([MAX_LEN] * len(train_labels))
  28. labels = np.ones(len(train_labels))
  29. # train_labels = keras.utils.to_categorical(train_labels, CLASS_NUMS)
  30. test_datas = sequence.pad_sequences(test_datas, maxlen=MAX_LEN, padding = "post")
  31. test_labels = sequence.pad_sequences(test_labels, maxlen=MAX_LEN, padding = "post")
  32. test_seq_lens = np.array([MAX_LEN] * len(test_labels))

  33. print(np.shape(train_datas), np.shape(train_labels))
复制代码


五、模型训练
  1. # 训练
  2. import os
  3. #os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  4. history = model.fit(x=[train_datas, train_labels, train_seq_lens], y=labels,validation_split=0.1, batch_size=BATCH_SIZE, epochs=20)#.history

  5. #acc = model.history['sparse_categorical_accuracy']
  6. #val_acc = model.history['val_sparse_categorical_accuracy']
  7. loss = history['loss']
  8. val_loss = history['val_loss']
  9. print('loss:',loss)
  10. print('val_loss:',val_loss)
  11. import matplotlib.pyplot as plt
  12. from matplotlib.ticker import MaxNLocator
  13. plt.plot(list(range(1, 21)) , history["loss"])
  14. plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
复制代码


六、模型效果查看
  1. trans_params = model.get_layer('crf').get_weights()[0]
  2. # 获得BiLSTM的输出logits
  3. sub_model = models.Model(inputs=model.get_layer('input_1').input,
  4.                         outputs=model.get_layer('dense').output)

  5. def predict(model, inputs, input_lens):
  6.     logits = sub_model.predict(inputs)
  7.     # 获取CRF层的转移矩阵
  8.     # crf_decode:viterbi解码获得结果
  9.     pred_seq, viterbi_score = tfa.text.crf_decode(logits, trans_params, input_lens)
  10.     return pred_seq
  11. test_datas, test_labels = read_corpus(test_data_path, vocab2idx, label2idx)

  12. maxlen = 55
  13. sentence = "北京市西城区阜成门外大街0号万通金融中心0-0层"
  14. sent_chars = list(sentence)
  15. sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in sent_chars]
  16. sent2id_new = np.array([sent2id[:maxlen] + [0] * (maxlen-len(sent2id))])
  17. test_lens = np.array([55])

  18. pred_seq = predict(model, sent2id_new, test_lens)
  19. print(pred_seq)

  20. y_label = pred_seq.numpy().reshape(1, -1)[0]
  21. #print(y_label)
  22. y_ner = [idx2label[i] for i in y_label][0:len(sent_chars)]

  23. #print(sent2id)
  24. print(y_ner)
  25. # 对预测结果进行命名实体解析和提取
  26. def get_valid_nertag(input_data, result_tags):
  27.     result_words = []
  28.     start, end =0, 1 # 实体开始结束位置标识
  29.     tag_label = "O" # 实体类型标识
  30.     for i, tag in enumerate(result_tags):
  31.         if tag.startswith("B"):
  32.             if tag_label != "O": # 当前实体tag之前有其他实体
  33.                 result_words.append((input_data[start: end], tag_label)) # 获取实体
  34.             tag_label = tag.split("-")[1] # 获取当前实体类型
  35.             start, end = i, i+1 # 开始和结束位置变更
  36.         elif tag.startswith("I"):
  37.             temp_label = tag.split("-")[1]
  38.             if temp_label == tag_label: # 当前实体tag是之前实体的一部分
  39.                 end += 1 # 结束位置end扩展
  40.         elif tag == "O":
  41.             if tag_label != "O": # 当前位置非实体 但是之前有实体
  42.                 result_words.append((input_data[start: end], tag_label)) # 获取实体
  43.                 tag_label = "O"  # 实体类型置"O"
  44.             start, end = i, i+1 # 开始和结束位置变更
  45.     if tag_label != "O": # 最后结尾还有实体
  46.         result_words.append((input_data[start: end], tag_label)) # 获取结尾的实体
  47.     return result_words

  48. result_words = get_valid_nertag(sent_chars, y_ner)
  49. for (word, tag) in result_words:
  50.     print("".join(word), tag)
复制代码




数据:
游客,如果您要查看本帖隐藏内容请回复

代码参考:https://www.cnblogs.com/huanghaocs/p/14673020.html
预测背景说明:https://tianchi.aliyun.com/compe ... 03.9.493e2448u7nhbg


本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2021-11-14 15:45:18 | 显示全部楼层
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2021-11-16 21:02:14 | 显示全部楼层
66
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2021-11-18 22:08:37 | 显示全部楼层
1

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2021-11-20 13:50:42 | 显示全部楼层
没看懂来学习
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2021-12-28 08:31:21 | 显示全部楼层
数据呢
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2021-12-28 08:32:40 | 显示全部楼层
网盘链接失效了
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2023-12-30 23:28:53 | 显示全部楼层
厉害呀
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-3-29 00:37

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表