import tensorflow as tf
import numpy as np
import os
# Load the data
def load_data(file_path):
    """Read a tab-separated corpus of sentences and their label strings.

    Every non-blank line must contain at least two tab-separated fields:
    the sentence text followed by its label string. Any further fields on
    the line are ignored. Blank lines are skipped.

    Args:
        file_path: Path of the UTF-8 encoded text file to read.

    Returns:
        A ``(data, labels)`` tuple of two equally long lists of strings,
        in file order.

    Raises:
        ValueError: If a non-blank line does not contain a label field.
    """
    data = []
    labels = []
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of loading every line up front.
        for line_no, raw_line in enumerate(f, start=1):
            stripped = raw_line.strip()
            if not stripped:
                # Blank lines carry no sample; indexing them used to crash.
                continue
            fields = stripped.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<sentence>\\t<labels>', "
                    f"got {stripped!r}"
                )
            data.append(fields[0])
            labels.append(fields[1])
    return data, labels
# Preprocess the data
def preprocess(data, labels):
    """Encode sentences and label strings as padded integer arrays.

    Each sentence is iterated element by element (character by character
    when it is a string) and every distinct element gets an index starting
    at 1; index 0 is reserved for padding. Labels are encoded with the
    fixed mapping ``{"B": 0, "I": 1, "O": 2}`` and padded with the "O" id.

    Args:
        data: Non-empty list of sentences.
        labels: List of label sequences, one per sentence, each the same
            length as its sentence.

    Returns:
        A tuple ``(padded_data, padded_labels, max_len, word2idx,
        label2idx)`` where both padded arrays have shape
        ``(len(data), max_len)``.

    Raises:
        ValueError: If ``data`` is empty, if ``data`` and ``labels`` differ
            in length, or if a sentence and its labels differ in length.
        KeyError: If a label is not one of "B", "I" or "O".
    """
    if not data:
        raise ValueError("data must contain at least one sentence")
    if len(data) != len(labels):
        raise ValueError(
            f"got {len(data)} sentences but {len(labels)} label sequences"
        )
    # A length mismatch would otherwise be hidden by padding and silently
    # train the model on misaligned tags.
    for index, (sentence, tags) in enumerate(zip(data, labels)):
        if len(sentence) != len(tags):
            raise ValueError(
                f"sample {index}: sentence has {len(sentence)} tokens "
                f"but {len(tags)} labels"
            )

    # Build the vocabulary.
    word2idx = {}
    for sentence in data:
        for word in sentence:
            if word not in word2idx:
                word2idx[word] = len(word2idx) + 1  # 0 is the padding index
    label2idx = {"B": 0, "I": 1, "O": 2}

    # Convert sentences to index sequences and encode the labels.
    data_ids = [[word2idx[word] for word in sentence] for sentence in data]
    label_ids = [[label2idx[label] for label in sentence] for sentence in labels]

    # Longest sentence decides the padded width.
    max_len = max(len(sentence) for sentence in data_ids)

    # Pad at the end; padded label positions are tagged as "O".
    padded_data = tf.keras.preprocessing.sequence.pad_sequences(
        data_ids, maxlen=max_len, padding='post', value=0)
    padded_labels = tf.keras.preprocessing.sequence.pad_sequences(
        label_ids, maxlen=max_len, padding='post', value=label2idx["O"])
    return padded_data, padded_labels, max_len, word2idx, label2idx
# Build the model
def build_model(input_dim, output_dim, embedding_dim=128, lstm_units=128):
    """Build an Embedding -> bidirectional LSTM -> softmax tagging model.

    Args:
        input_dim: Size of the embedding vocabulary (number of distinct
            input indices, including the padding index).
        output_dim: Number of label classes predicted per time step.
        embedding_dim: Size of each embedding vector. Defaults to 128.
        lstm_units: Number of units in each LSTM direction. Defaults to 128.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim, embedding_dim),
        # return_sequences=True keeps one output per time step, which a
        # per-token tagger needs.
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(lstm_units, return_sequences=True)),
        tf.keras.layers.Dense(output_dim, activation='softmax'),
    ])
    return model
# Train the model
def train(model, x_train, y_train, max_len, batch_size=32, epochs=10):
    """Compile and fit the model on the training data.

    The model is compiled with sparse categorical cross-entropy loss, the
    Adam optimizer and an accuracy metric.

    Args:
        model: Model object exposing ``compile`` and ``fit``.
        x_train: Training inputs passed to ``model.fit``.
        y_train: Training targets passed to ``model.fit``.
        max_len: Unused; kept so existing callers keep working.
        batch_size: Samples per gradient update. Defaults to 32.
        epochs: Number of passes over the training data. Defaults to 10.

    Returns:
        Whatever ``model.fit`` returns (the training history for Keras
        models), so callers can inspect the loss and metric curves.
    """
    # Configure the training parameters.
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model and hand the history back instead of discarding it.
    return model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)
# Prediction function
def predict(model, sentence, max_len, word2idx, label2idx):
    """Tag a sentence with the model and return the entities it contains.

    The sentence is split into characters, mapped to vocabulary indices
    (unknown characters map to 0), padded or cut at the end to ``max_len``
    and passed to ``model.predict``. The most likely label per position is
    then decoded with the B/I/O scheme.

    Args:
        model: Object whose ``predict`` takes an int array of shape
            ``(1, max_len)`` and returns per-position label scores of shape
            ``(1, max_len, num_labels)``.
        sentence: Raw input text; surrounding whitespace is stripped.
        max_len: Sequence length the model was trained with.
        word2idx: Mapping from token to vocabulary index.
        label2idx: Mapping from label name to label index; must contain
            the keys "B" and "I".

    Returns:
        List of entity strings in order of appearance. Empty when the
        sentence is blank or no entity is predicted.
    """
    # Tokenise per character: the vocabulary built in preprocess() comes
    # from iterating each training sentence element by element, so
    # splitting on spaces here would map almost everything to "unknown".
    # Cutting at the end keeps token i aligned with prediction i.
    tokens = list(sentence.strip())[:max_len]
    if not tokens:
        return []

    # Post-pad with the padding index 0, same layout as the training data.
    input_ids = np.zeros((1, max_len), dtype='int32')
    input_ids[0, :len(tokens)] = [word2idx.get(token, 0) for token in tokens]

    pred_ids = model.predict(input_ids)[0]
    # argmax already yields label indices; label2idx maps names to indices
    # and must not be indexed with them.
    preds = np.argmax(pred_ids, axis=-1)

    begin_id = label2idx["B"]
    inside_id = label2idx["I"]

    result = []
    entity = ""
    # zip stops at the last real token, so padded positions are ignored.
    for token, pred in zip(tokens, preds):
        if pred == begin_id:
            # A new entity starts: close the previous one and keep the
            # B token as the first character of the new entity.
            if entity:
                result.append(entity)
            entity = token
        elif pred == inside_id:
            entity += token
        else:
            # An O label ends any entity in progress.
            if entity:
                result.append(entity)
                entity = ""
    # Flush an entity that runs to the end of the sentence.
    if entity:
        result.append(entity)
    return result
if __name__ == '__main__':
    # Read the training corpus from disk.
    data, labels = load_data('data.txt')

    # Turn the raw text into padded index arrays plus lookup tables.
    encoded = preprocess(data, labels)
    x_train, y_train, max_len, word2idx, label2idx = encoded

    # One extra vocabulary slot accounts for the padding index 0.
    vocab_size = len(word2idx) + 1
    num_labels = len(label2idx)
    model = build_model(vocab_size, num_labels)

    # Fit the model on the encoded corpus.
    train(model, x_train, y_train, max_len)

    # Ask the user for a sentence and print the recognised entities.
    sentence = input("请输入句子:")
    print("识别结果:", predict(model, sentence, max_len, word2idx, label2idx))