|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
这个是我跑的代码
from data import build_corpus
from evaluate import bilstm_train_and_eval
from utils import extend_maps,prepocess_data_for_lstmcrf, save_obj, load_obj
print("读取数据中...")
train_word_lists,train_tag_lists,word2id,tag2id = build_corpus("train")
dev_word_lists,dev_tag_lists = build_corpus("dev",make_vocab=False)
test_word_lists,test_tag_lists = build_corpus("test",make_vocab=False)
print("正在训练评估Bi-LSTM+CRF模型...")
crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
save_obj(crf_word2id, 'crf_word2id')
save_obj(crf_tag2id, 'crf_tag2id')
print(' '.join([i[0] for i in crf_tag2id.items()]))
train_word_lists, train_tag_lists = prepocess_data_for_lstmcrf(
train_word_lists, train_tag_lists
)
dev_word_lists, dev_tag_lists = prepocess_data_for_lstmcrf(
dev_word_lists, dev_tag_lists
)
test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
test_word_lists, test_tag_lists, test=True
)
lstmcrf_pred = bilstm_train_and_eval(
(train_word_lists, train_tag_lists),
(dev_word_lists, dev_tag_lists),
(test_word_lists, test_tag_lists),
crf_word2id, crf_tag2id
)
结果报错
Traceback (most recent call last):
File "C:\Users\qwert\Desktop\模型\pytorchbilstmcrf-information-extraction-main\pytorchbilstmcrf-information-extraction-main\main.py", line 56, in <module>
crf_word2id, crf_tag2id
File "C:\Users\qwert\Desktop\模型\pytorchbilstmcrf-information-extraction-main\pytorchbilstmcrf-information-extraction-main\evaluate.py", line 23, in bilstm_train_and_eval
bilstm_operator.train(train_word_lists,train_tag_lists,dev_word_lists,dev_tag_lists,word2id,tag2id)
File "C:\Users\qwert\Desktop\模型\pytorchbilstmcrf-information-extraction-main\pytorchbilstmcrf-information-extraction-main\operate_bilstm.py", line 40, in train
word_lists, tag_lists, _ = sort_by_lengths(word_lists, tag_lists)
File "C:\Users\qwert\Desktop\模型\pytorchbilstmcrf-information-extraction-main\pytorchbilstmcrf-information-extraction-main\utils.py", line 10, in sort_by_lengths
word_lists, tag_lists = list(zip(*pairs))
ValueError: not enough values to unpack (expected 2, got 0)
之后我检查了测试集和训练集,里边都没有空行
根据报错找到utils.py文件
import pickle
import torch
def sort_by_lengths(word_lists,tag_lists):
pairs = list(zip(word_lists, tag_lists))
indices = sorted(range(len(pairs)), key=lambda x: len(pairs[x][0]), reverse=True)
pairs = [pairs[i] for i in indices]
word_lists, tag_lists = list(zip(*pairs))
return word_lists, tag_lists, indices
def tensorized(batch, maps):
PAD = maps.get('<pad>')
UNK = maps.get('<unk>')
max_len = len(batch[0])
batch_size = len(batch)
batch_tensor = torch.ones(batch_size, max_len).long() * PAD
for i, l in enumerate(batch):
for j, e in enumerate(l):
batch_tensor[i][j] = maps.get(e, UNK)
lengths = [len(l) for l in batch]
return batch_tensor, lengths
def save_obj(obj, name):
with open('data/'+ name + '.pkl', 'wb') as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
def load_obj(name):
with open('data/' + name + '.pkl', 'rb') as f:
return pickle.load(f)
def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False):
assert len(word_lists) == len(tag_lists)
for i in range(len(word_lists)):
word_lists[i].append("<end>")
if not test: # 如果是测试数据,就不需要加end token了
tag_lists[i].append("<end>")
return word_lists, tag_lists
def flatten_lists(lists):
"""将list of list 压平成list"""
flatten_list = []
for list_ in lists:
if type(list_) == list:
flatten_list.extend(list_)
else:
flatten_list.append(list_)
return flatten_list
def extend_maps(word2id, tag2id, for_crf=True):
word2id['<unk>'] = len(word2id)
word2id['<pad>'] = len(word2id)
tag2id['<unk>'] = len(tag2id)
tag2id['<pad>'] = len(tag2id)
# 如果是加了CRF的bilstm 那么还要加入<start> 和 <end>token
if for_crf:
word2id['<start>'] = len(word2id)
word2id['<end>'] = len(word2id)
tag2id['<start>'] = len(tag2id)
tag2id['<end>'] = len(tag2id)
return word2id, tag2id
def save_model(model,file_name):
with open(file_name,'wb') as f:
pickle.dump(model,f)
f.close()
根据报错信息,问题出现在`sort_by_lengths`函数里。函数期望的入参是两个列表 `word_lists` 和 `tag_lists`,但是它从 `pairs` 中返回的观测值为空,导致 unpack 时没有获得所需的两个列表。我怀疑是`train_word_lists,train_tag_lists`未正常赋值或加载出来,建议您对其进行打印和检查。同时请确保这些表格已通过正确的格式传递到build_corpus方法中。如果使用的数据不是CSV等常见格式,则可能需要对它们进行特殊预处理以使其能够被正确读取。
|
|