Commit 15907570 by 20210828028

DuIE

parent 5edccf13
#!/usr/bin/env python
# coding: utf-8
# # Baidu information extraction baseline: apply the baseline to extract the relation information that may exist in news sentences.
# Import the basic packages required by the relation extraction model.
# In[17]:
import os
import random
import sys
import time
import codecs
import zipfile
import re
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import DataLoader
from tqdm import tqdm
from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification, LinearDecayWithWarmup
from typing import Optional, List, Union, Dict
from dataclasses import dataclass
import json
from extract_chinese_and_punct import ChineseAndPunctuationExtractor
from paddlenlp.utils.log import logger
import collections
# In[18]:
device = 'cpu'
@dataclass
class DataCollator:
"""
Collator for DuIE.
"""
def __call__(self, examples: List[Dict[str, Union[list, np.ndarray]]]):
batched_input_ids = np.stack([x['input_ids'] for x in examples])
seq_lens = np.stack([x['seq_lens'] for x in examples])
tok_to_orig_start_index = np.stack(
[x['tok_to_orig_start_index'] for x in examples])
tok_to_orig_end_index = np.stack(
[x['tok_to_orig_end_index'] for x in examples])
labels = np.stack([x['labels'] for x in examples])
return (batched_input_ids, seq_lens, tok_to_orig_start_index,
tok_to_orig_end_index, labels)
# In[19]:
class DuIEDataset(paddle.io.Dataset):
"""
Dataset of DuIE.
"""
def __init__(
self,
input_ids: List[Union[List[int], np.ndarray]],
seq_lens: List[Union[List[int], np.ndarray]],
tok_to_orig_start_index: List[Union[List[int], np.ndarray]],
tok_to_orig_end_index: List[Union[List[int], np.ndarray]],
labels: List[Union[List[int], np.ndarray, List[str], List[Dict]]]):
super(DuIEDataset, self).__init__()
self.input_ids = input_ids
self.seq_lens = seq_lens
self.tok_to_orig_start_index = tok_to_orig_start_index
self.tok_to_orig_end_index = tok_to_orig_end_index
self.labels = labels
def __len__(self):
if isinstance(self.input_ids, np.ndarray):
return self.input_ids.shape[0]
else:
return len(self.input_ids)
def __getitem__(self, item):
return {
"input_ids": np.array(self.input_ids[item]),
"seq_lens": np.array(self.seq_lens[item]),
"tok_to_orig_start_index":
np.array(self.tok_to_orig_start_index[item]),
"tok_to_orig_end_index": np.array(self.tok_to_orig_end_index[item]),
# If model inputs are generated in `collate_fn`, delete the data type casting.
"labels": np.array(
self.labels[item], dtype=np.float32),
}
@classmethod
def from_file(cls,
file_path: Union[str, os.PathLike],
tokenizer: ErnieTokenizer,
max_length: Optional[int]=512,
pad_to_max_length: Optional[bool]=None):
assert os.path.exists(file_path) and os.path.isfile(
file_path), f"{file_path} dose not exists or is not a file."
label_map_path = os.path.join(
os.path.dirname(file_path), "predicate2id.json")
assert os.path.exists(label_map_path) and os.path.isfile(
label_map_path
), f"{label_map_path} dose not exists or is not a file."
with open(label_map_path, 'r', encoding='utf8') as fp:
label_map = json.load(fp)
chineseandpunctuationextractor = ChineseAndPunctuationExtractor()
input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, labels = (
[] for _ in range(5))
dataset_scale = sum(1 for line in open(
file_path, 'r', encoding="UTF-8"))
logger.info("Preprocessing data, loaded from %s" % file_path)
with open(file_path, "r", encoding="utf-8") as fp:
lines = fp.readlines()
for line in tqdm(lines):
example = json.loads(line)
input_feature = convert_example_to_feature(
example, tokenizer, chineseandpunctuationextractor,
label_map, max_length, pad_to_max_length)
input_ids.append(input_feature.input_ids)
seq_lens.append(input_feature.seq_len)
tok_to_orig_start_index.append(
input_feature.tok_to_orig_start_index)
tok_to_orig_end_index.append(
input_feature.tok_to_orig_end_index)
labels.append(input_feature.labels)
return cls(input_ids, seq_lens, tok_to_orig_start_index,
tok_to_orig_end_index, labels)
# Parameters used for relation extraction
# Number of training samples fed to the model per batch
# batch_size=8
# Path to the training data
# data_path='./data'
# Compute device
# device='gpu'
# Whether to run prediction
# do_predict=False
# Whether to run training
# do_train=False
# Optional checkpoint of previously trained parameters, e.g. to resume training from an existing model
# init_checkpoint=None
# Learning rate
# learning_rate=5e-05
# Maximum input sequence length
# max_seq_length=128
# Number of training epochs
# num_train_epochs=3
# Output directory for the trained model
# output_dir='./checkpoints'
# Path to the file to be predicted
# predict_data_file='./data/test_data.json'
# Random seed
# seed=42
# Linear warmup ratio for the learning rate, used to speed up model convergence
# warmup_ratio=0
# Weight decay
# weight_decay=0.0
# One problem with this style of modification (module-level settings instead of argparse) is the deep vs. shallow copy issue; see the illustration after the config cell below.
#
# In[20]:
batch_size=8
data_path='./data'  # directory containing duie_train.json, duie_dev.json and predicate2id.json; a single file path here would break the os.path.join calls below
device='cpu'
do_predict=False
do_train=False
init_checkpoint=None
learning_rate=5e-05
max_seq_length=128
num_train_epochs=128
output_dir='./checkpoints'
predict_data_file='./data/test_data.json'
seed=42
warmup_ratio=0
weight_decay=0.0
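# A minimal illustration (with a hypothetical nested config dict) of the deep vs.
# shallow copy pitfall mentioned above: a shallow copy still shares nested objects.
import copy
_base_config = {"batch_size": 8, "paths": {"data": "./duie_train.json"}}
_shallow_config = copy.copy(_base_config)       # nested dict is shared
_deep_config = copy.deepcopy(_base_config)      # nested dict is duplicated
_shallow_config["paths"]["data"] = "./other.json"
assert _base_config["paths"]["data"] == "./other.json"      # changed through the shallow copy
assert _deep_config["paths"]["data"] == "./duie_train.json"  # deep copy is unaffected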
# The relation-extraction-specific BCE loss used in the baseline
# In[21]:
class BCELossForDuIE(nn.Layer):
def __init__(self, ):
super(BCELossForDuIE, self).__init__()
self.criterion = nn.BCEWithLogitsLoss(reduction='none')
def forward(self, logits, labels, mask):
loss = self.criterion(logits, labels)
mask = paddle.cast(mask, 'float32')
loss = loss * mask.unsqueeze(-1)
loss = paddle.sum(loss.mean(axis=2), axis=1) / paddle.sum(mask, axis=1)
loss = loss.mean()
return loss
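# A minimal, hypothetical usage sketch of BCELossForDuIE (shapes and values are made
# up for illustration): logits and labels are [batch, seq_len, num_classes] tensors
# and mask is [batch, seq_len], marking the tokens that should contribute to the loss.
def _demo_bce_loss():
    criterion_demo = BCELossForDuIE()
    logits_demo = paddle.randn([2, 16, 112])
    labels_demo = paddle.zeros([2, 16, 112], dtype='float32')
    mask_demo = paddle.ones([2, 16], dtype='int64')
    return criterion_demo(logits_demo, labels_demo, mask_demo)  # scalar loss tensor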
# Set the random seed; to a large extent this lets the model reproduce the published results on a dataset shuffled with this seed.
# In[22]:
def set_random_seed(seed):
"""sets random seed"""
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
# !wget https://dataset-bj.cdn.bcebos.com/qianyan/duie_train.json.zip
# !unzip duie_train.json.zip
# The @paddle.no_grad() decorator stops gradient tracking; annotating the evaluation function with it gives more stable (and cheaper) predictions at inference time.
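# A tiny sketch of the decorator described above (model and input_ids here are
# placeholders): no gradients are tracked inside the decorated function.
@paddle.no_grad()
def _demo_no_grad_forward(model, input_ids):
    return model(input_ids=input_ids)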
# In[23]:
def decoding(example_batch,
id2spo,
logits_batch,
seq_len_batch,
tok_to_orig_start_index_batch,
tok_to_orig_end_index_batch):
"""
model output logits -> formatted spo (as in data set file)
"""
formatted_outputs = []
for (i, (example, logits, seq_len, tok_to_orig_start_index, tok_to_orig_end_index)) in \
enumerate(zip(example_batch, logits_batch, seq_len_batch, tok_to_orig_start_index_batch, tok_to_orig_end_index_batch)):
logits = logits[1:seq_len +
1] # slice between [CLS] and [SEP] to get valid logits
logits[logits >= 0.5] = 1
logits[logits < 0.5] = 0
tok_to_orig_start_index = tok_to_orig_start_index[1:seq_len + 1]
tok_to_orig_end_index = tok_to_orig_end_index[1:seq_len + 1]
predictions = []
for token in logits:
predictions.append(np.argwhere(token == 1).tolist())
# format predictions into example-style output
formatted_instance = {}
text_raw = example['text']
complex_relation_label = [8, 10, 26, 32, 46]
complex_relation_affi_label = [9, 11, 27, 28, 29, 33, 47]
# flatten predictions, then retrieve all valid subject ids
flatten_predictions = []
for layer_1 in predictions:
for layer_2 in layer_1:
flatten_predictions.append(layer_2[0])
subject_id_list = []
for cls_label in list(set(flatten_predictions)):
if 1 < cls_label <= 56 and (cls_label + 55) in flatten_predictions:
subject_id_list.append(cls_label)
subject_id_list = list(set(subject_id_list))
# fetch all valid spo by subject id
spo_list = []
for id_ in subject_id_list:
if id_ in complex_relation_affi_label:
continue # do this in the next "else" branch
if id_ not in complex_relation_label:
subjects = find_entity(text_raw, id_, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
objects = find_entity(text_raw, id_ + 55, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
for subject_ in subjects:
for object_ in objects:
spo_list.append({
"predicate": id2spo['predicate'][id_],
"object_type": {
'@value': id2spo['object_type'][id_]
},
'subject_type': id2spo['subject_type'][id_],
"object": {
'@value': object_
},
"subject": subject_
})
else:
# traverse all complex relations and look through their corresponding affiliated objects
subjects = find_entity(text_raw, id_, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
objects = find_entity(text_raw, id_ + 55, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
for subject_ in subjects:
for object_ in objects:
object_dict = {'@value': object_}
object_type_dict = {
'@value': id2spo['object_type'][id_].split('_')[0]
}
if id_ in [8, 10, 32, 46
] and id_ + 1 in subject_id_list:
id_affi = id_ + 1
object_dict[id2spo['object_type'][id_affi].split(
'_')[1]] = find_entity(text_raw, id_affi + 55,
predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)[0]
object_type_dict[id2spo['object_type'][
id_affi].split('_')[1]] = id2spo['object_type'][
id_affi].split('_')[0]
elif id_ == 26:
for id_affi in [27, 28, 29]:
if id_affi in subject_id_list:
object_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
find_entity(text_raw, id_affi + 55, predictions, tok_to_orig_start_index, tok_to_orig_end_index)[0]
object_type_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
id2spo['object_type'][id_affi].split('_')[0]
spo_list.append({
"predicate": id2spo['predicate'][id_],
"object_type": object_type_dict,
"subject_type": id2spo['subject_type'][id_],
"object": object_dict,
"subject": subject_
})
formatted_instance['text'] = example['text']
formatted_instance['spo_list'] = spo_list
formatted_outputs.append(formatted_instance)
return formatted_outputs
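# For reference, each element appended to formatted_outputs above has the shape
# below (text and values are made up for illustration):
# {"text": "...",
#  "spo_list": [{"predicate": "...",
#                "subject_type": "...", "subject": "...",
#                "object_type": {"@value": "..."},
#                "object": {"@value": "..."}}]}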
# In[8]:
def get_precision_recall_f1(golden_file, predict_file):
r = os.popen(
'python3 ./re_official_evaluation.py --golden_file={} --predict_file={}'.
format(golden_file, predict_file))
result = r.read()
r.close()
precision = float(
re.search("\"precision\", \"value\":.*?}", result).group(0).lstrip(
"\"precision\", \"value\":").rstrip("}"))
recall = float(
re.search("\"recall\", \"value\":.*?}", result).group(0).lstrip(
"\"recall\", \"value\":").rstrip("}"))
f1 = float(
re.search("\"f1-score\", \"value\":.*?}", result).group(0).lstrip(
"\"f1-score\", \"value\":").rstrip("}"))
return precision, recall, f1
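# For reference, re_official_evaluation.py prints a single JSON object such as
# {"errorCode": 0, "errorMsg": "success",
#  "data": [{"name": "precision", "value": 0.61},
#           {"name": "recall", "value": 0.57},
#           {"name": "f1-score", "value": 0.59}]}
# (numbers here are illustrative); the regexes above pull the three values out of it.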
# In[9]:
def write_prediction_results(formatted_outputs, file_path):
"""write the prediction results"""
with codecs.open(file_path, 'w', 'utf-8') as f:
for formatted_instance in formatted_outputs:
json_str = json.dumps(formatted_instance, ensure_ascii=False)
f.write(json_str)
f.write('\n')
zipfile_path = file_path + '.zip'
f = zipfile.ZipFile(zipfile_path, 'w', zipfile.ZIP_DEFLATED)
f.write(file_path)
f.close()  # close explicitly so the zip central directory is written
return zipfile_path
# In[10]:
def ie_do_train():
paddle.set_device(device)
rank = paddle.distributed.get_rank()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
# Reads label_map.
label_map_path = os.path.join(data_path, "predicate2id.json")
if not (os.path.exists(label_map_path) and os.path.isfile(label_map_path)):
sys.exit("{} dose not exists or is not a file.".format(label_map_path))
with open(label_map_path, 'r', encoding='utf8') as fp:
label_map = json.load(fp)
num_classes = (len(label_map.keys()) - 2) * 2 + 2
# Loads pretrained model ERNIE
model = ErnieForTokenClassification.from_pretrained(
"ernie-tiny", num_classes=num_classes)
model = paddle.DataParallel(model)
tokenizer = ErnieTokenizer.from_pretrained("ernie-tiny")
criterion = BCELossForDuIE()
# Loads dataset.
train_dataset = DuIEDataset.from_file(
os.path.join(data_path,"duie_train.json"), tokenizer,
max_seq_length, True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
collator = DataCollator()
train_data_loader = DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
collate_fn=collator,
return_list=True)
eval_file_path = os.path.join(data_path,"duie_dev.json")
test_dataset = DuIEDataset.from_file(eval_file_path, tokenizer,
max_seq_length, True)
test_batch_sampler = paddle.io.BatchSampler(
test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_data_loader = DataLoader(
dataset=test_dataset,
batch_sampler=test_batch_sampler,
collate_fn=collator,
return_list=True)
# Defines learning rate strategy.
steps_by_epoch = len(train_data_loader)
num_training_steps = steps_by_epoch * num_train_epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,
warmup_ratio)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
# Starts training.
global_step = 0
logging_steps = 50
save_steps = 10000
tic_train = time.time()
for epoch in range(num_train_epochs):
print("\n=====start training of %d epochs=====" % epoch)
tic_epoch = time.time()
model.train()
for step, batch in enumerate(train_data_loader):
input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
logits = model(input_ids=input_ids)
mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and(
(input_ids != 2))
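# token ids 0, 1 and 2 correspond to [PAD], [CLS] and [SEP] in the ERNIE vocab,
# so this mask excludes the special tokens from the loss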
loss = criterion(logits, labels, mask)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
loss_item = loss.numpy().item()
global_step += 1
if global_step % logging_steps == 0 and rank == 0:
print(
"epoch: %d / %d, steps: %d / %d, loss: %f, speed: %.2f step/s"
% (epoch, num_train_epochs, step, steps_by_epoch,
loss_item, logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % save_steps == 0 and rank == 0:
print("\n=====start evaluating ckpt of %d steps=====" %
global_step)
precision, recall, f1 = evaluate(
model, criterion, test_data_loader, eval_file_path, "eval")
print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
(100 * precision, 100 * recall, 100 * f1))
print("saving checkpoing model_%d.pdparams to %s " %
(global_step, output_dir))
paddle.save(model.state_dict(),
os.path.join(output_dir,
"model_%d.pdparams" % global_step))
model.train() # back to train mode
tic_epoch = time.time() - tic_epoch
print("epoch time footprint: %d hour %d min %d sec" %
(tic_epoch // 3600, (tic_epoch % 3600) // 60, tic_epoch % 60))
# Does final evaluation.
if rank == 0:
print("\n=====start evaluating last ckpt of %d steps=====" %
global_step)
precision, recall, f1 = evaluate(model, criterion, test_data_loader,
eval_file_path, "eval")
print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
(100 * precision, 100 * recall, 100 * f1))
paddle.save(model.state_dict(),
os.path.join(output_dir,
"model_%d.pdparams" % global_step))
print("\n=====training complete=====")
# In[11]:
def parse_label(spo_list, label_map, tokens, tokenizer):
# 2 tags for each predicate + I tag + O tag
num_labels = 2 * (len(label_map.keys()) - 2) + 2
seq_len = len(tokens)
# initialize tag
labels = [[0] * num_labels for i in range(seq_len)]
# find all entities and tag them with corresponding "B"/"I" labels
for spo in spo_list:
for spo_object in spo['object'].keys():
# assign relation label
if spo['predicate'] in label_map.keys():
# simple relation
label_subject = label_map[spo['predicate']]
label_object = label_subject + 55
subject_tokens = tokenizer._tokenize(spo['subject'])
object_tokens = tokenizer._tokenize(spo['object']['@value'])
else:
# complex relation
label_subject = label_map[spo['predicate'] + '_' + spo_object]
label_object = label_subject + 55
subject_tokens = tokenizer._tokenize(spo['subject'])
object_tokens = tokenizer._tokenize(spo['object'][spo_object])
subject_tokens_len = len(subject_tokens)
object_tokens_len = len(object_tokens)
# assign token label
# there are situations where the subject and object entities overlap, e.g. "xyz" established "xyz corporation";
# to prevent a single token from being labeled as two different entities,
# we tag the longer entity first, then match the shorter entity within the remaining text
forbidden_index = None
if subject_tokens_len > object_tokens_len:
for index in range(seq_len - subject_tokens_len + 1):
if tokens[index:index +
subject_tokens_len] == subject_tokens:
labels[index][label_subject] = 1
for i in range(subject_tokens_len - 1):
labels[index + i + 1][1] = 1
forbidden_index = index
break
for index in range(seq_len - object_tokens_len + 1):
if tokens[index:index + object_tokens_len] == object_tokens:
if forbidden_index is None:
labels[index][label_object] = 1
for i in range(object_tokens_len - 1):
labels[index + i + 1][1] = 1
break
# check if labeled already
elif index < forbidden_index or index >= forbidden_index + len(
subject_tokens):
labels[index][label_object] = 1
for i in range(object_tokens_len - 1):
labels[index + i + 1][1] = 1
break
else:
for index in range(seq_len - object_tokens_len + 1):
if tokens[index:index + object_tokens_len] == object_tokens:
labels[index][label_object] = 1
for i in range(object_tokens_len - 1):
labels[index + i + 1][1] = 1
forbidden_index = index
break
for index in range(seq_len - subject_tokens_len + 1):
if tokens[index:index +
subject_tokens_len] == subject_tokens:
if forbidden_index is None:
labels[index][label_subject] = 1
for i in range(subject_tokens_len - 1):
labels[index + i + 1][1] = 1
break
elif index < forbidden_index or index >= forbidden_index + len(
object_tokens):
labels[index][label_subject] = 1
for i in range(subject_tokens_len - 1):
labels[index + i + 1][1] = 1
break
# if a token wasn't assigned any "B"/"I" tag, give it an "O" (outside) tag
for i in range(seq_len):
if labels[i] == [0] * num_labels:
labels[i][0] = 1
return labels
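# For reference (inferred from the decoding logic above): each row of `labels` is a
# multi-hot vector where column 0 is the "O" tag, column 1 the shared "I" tag,
# columns 2..56 the subject "B" tags per predicate, and columns 57..111 (i.e. +55)
# the corresponding object "B" tags.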
# In[12]:
InputFeature = collections.namedtuple("InputFeature", [
"input_ids", "seq_len", "tok_to_orig_start_index", "tok_to_orig_end_index",
"labels"
])
# In[13]:
def convert_example_to_feature(
example,
tokenizer: ErnieTokenizer,
chineseandpunctuationextractor: ChineseAndPunctuationExtractor,
label_map,
max_length: Optional[int]=512,
pad_to_max_length: Optional[bool]=None):
spo_list = example['spo_list'] if "spo_list" in example.keys() else None
text_raw = example['text']
sub_text = []
buff = ""
for char in text_raw:
if chineseandpunctuationextractor.is_chinese_or_punct(char):
if buff != "":
sub_text.append(buff)
buff = ""
sub_text.append(char)
else:
buff += char
if buff != "":
sub_text.append(buff)
tok_to_orig_start_index = []
tok_to_orig_end_index = []
orig_to_tok_index = []
tokens = []
text_tmp = ''
for (i, token) in enumerate(sub_text):
orig_to_tok_index.append(len(tokens))
sub_tokens = tokenizer._tokenize(token)
text_tmp += token
for sub_token in sub_tokens:
tok_to_orig_start_index.append(len(text_tmp) - len(token))
tok_to_orig_end_index.append(len(text_tmp) - 1)
tokens.append(sub_token)
if len(tokens) >= max_length - 2:
break
else:
continue
break
seq_len = len(tokens)
# 2 tags for each predicate + I tag + O tag
num_labels = 2 * (len(label_map.keys()) - 2) + 2
# initialize tag
labels = [[0] * num_labels for i in range(seq_len)]
if spo_list is not None:
labels = parse_label(spo_list, label_map, tokens, tokenizer)
# add [CLS] and [SEP] tokens; they are tagged as "O" (outside)
if seq_len > max_length - 2:
tokens = tokens[0:(max_length - 2)]
labels = labels[0:(max_length - 2)]
tok_to_orig_start_index = tok_to_orig_start_index[0:(max_length - 2)]
tok_to_orig_end_index = tok_to_orig_end_index[0:(max_length - 2)]
tokens = ["[CLS]"] + tokens + ["[SEP]"]
# "O" tag for [PAD], [CLS], [SEP] token
outside_label = [[1] + [0] * (num_labels - 1)]
labels = outside_label + labels + outside_label
tok_to_orig_start_index = [-1] + tok_to_orig_start_index + [-1]
tok_to_orig_end_index = [-1] + tok_to_orig_end_index + [-1]
if seq_len < max_length:
tokens = tokens + ["[PAD]"] * (max_length - seq_len - 2)
labels = labels + outside_label * (max_length - len(labels))
tok_to_orig_start_index = tok_to_orig_start_index + [-1] * (
max_length - len(tok_to_orig_start_index))
tok_to_orig_end_index = tok_to_orig_end_index + [-1] * (
max_length - len(tok_to_orig_end_index))
token_ids = tokenizer.convert_tokens_to_ids(tokens)
return InputFeature(
input_ids=np.array(token_ids),
seq_len=np.array(seq_len),
tok_to_orig_start_index=np.array(tok_to_orig_start_index),
tok_to_orig_end_index=np.array(tok_to_orig_end_index),
labels=np.array(labels), )
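# A hypothetical usage sketch of convert_example_to_feature (the text and the tiny
# label map below are made up for illustration; in practice the label map comes
# from predicate2id.json):
def _demo_convert_example():
    tokenizer_demo = ErnieTokenizer.from_pretrained("ernie-tiny")
    extractor_demo = ChineseAndPunctuationExtractor()
    label_map_demo = {"O": 0, "I": 1, "作者": 2}
    example_demo = {"text": "《红楼梦》的作者是曹雪芹。"}
    feature = convert_example_to_feature(
        example_demo, tokenizer_demo, extractor_demo,
        label_map_demo, max_length=32, pad_to_max_length=True)
    # feature.input_ids, feature.seq_len, feature.tok_to_orig_start_index,
    # feature.tok_to_orig_end_index and feature.labels are all numpy arrays
    return feature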
# In[14]:
@paddle.no_grad()
def evaluate(model, criterion, data_loader, file_path, mode):
"""
mode eval:
eval on development set and compute P/R/F1, called between training.
mode predict:
eval on development / test set, then write predictions to \
predictions.json and predictions.json.zip \
under data_path dir for later submission or evaluation.
"""
example_all = []
with open(file_path, "r", encoding="utf-8") as fp:
for line in fp:
example_all.append(json.loads(line))
id2spo_path = os.path.join(os.path.dirname(file_path), "id2spo.json")
with open(id2spo_path, 'r', encoding='utf8') as fp:
id2spo = json.load(fp)
model.eval()
loss_all = 0
eval_steps = 0
formatted_outputs = []
current_idx = 0
for batch in tqdm(data_loader, total=len(data_loader)):
eval_steps += 1
input_ids, seq_len, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
logits = model(input_ids=input_ids)
mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and((input_ids != 2))
loss = criterion(logits, labels, mask)
loss_all += loss.numpy().item()
probs = F.sigmoid(logits)
logits_batch = probs.numpy()
seq_len_batch = seq_len.numpy()
tok_to_orig_start_index_batch = tok_to_orig_start_index.numpy()
tok_to_orig_end_index_batch = tok_to_orig_end_index.numpy()
formatted_outputs.extend(decoding(example_all[current_idx: current_idx + len(logits)],
id2spo,
logits_batch,
seq_len_batch,
tok_to_orig_start_index_batch,
tok_to_orig_end_index_batch))
current_idx = current_idx + len(logits)
loss_avg = loss_all / eval_steps
print("eval loss: %f" % (loss_avg))
if mode == "predict":
predict_file_path = os.path.join(data_path, 'predictions.json')
else:
predict_file_path = os.path.join(data_path, 'predict_eval.json')
predict_zipfile_path = write_prediction_results(formatted_outputs,
predict_file_path)
if mode == "eval":
precision, recall, f1 = get_precision_recall_f1(file_path,
predict_zipfile_path)
os.system('rm {} {}'.format(predict_file_path, predict_zipfile_path))
return precision, recall, f1
elif mode != "predict":
raise Exception("wrong mode for eval func")
# In[15]:
def find_entity(text_raw, id_, predictions, tok_to_orig_start_index,
tok_to_orig_end_index):
"""
retrieve the entity mention under a given predicate id for a certain prediction.
this is called by the "decoding" func.
"""
entity_list = []
for i in range(len(predictions)):
if [id_] in predictions[i]:
j = 0
while i + j + 1 < len(predictions):
if [1] in predictions[i + j + 1]:
j += 1
else:
break
entity = ''.join(text_raw[tok_to_orig_start_index[i]:
tok_to_orig_end_index[i + j] + 1])
entity_list.append(entity)
return list(set(entity_list))
# Build a closed loop for large-scale joint tasks and push past the current boundaries of understanding.
# In[16]:
if __name__ == '__main__':
ie_do_train()
# Results of the ernie-tiny pretrained language model on the relation extraction task
# Validation results after the first epoch:
# precision: 53.43 recall: 50.57 f1: 51.96
#
# Results after the second epoch:
# precision: 61.35 recall: 57.74 f1: 59.49
#
# Resource monitoring during the run
#
# ![](https://ai-studio-static-online.cdn.bcebos.com/4d8591b126ad47628d4daf796b7cc7eb59652af498e84efb859b0d36bdfeb2d9)
#
# What are the current problems of the project?
# * First, data preprocessing is slow.
#
# Save the preprocessing results and load them directly from file on the next run (see the caching sketch after this list).
#
# * Second, the model cannot yet be deployed as an engineered service.
#
# Solution: understand the code in depth and abstract it into reusable components.
#
# * Third, a new relation type cannot be added.
#
# Solution: incorporate incremental learning.
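# A minimal caching sketch for the first problem above (slow preprocessing), assuming
# the DuIEDataset defined earlier in this notebook; the cache path is hypothetical.
def load_or_build_dataset(json_path, tokenizer, max_len, cache_path="duie_train_cache.npz"):
    if os.path.exists(cache_path):
        cached = np.load(cache_path)
        return DuIEDataset(cached["input_ids"], cached["seq_lens"],
                           cached["tok_start"], cached["tok_end"], cached["labels"])
    dataset = DuIEDataset.from_file(json_path, tokenizer, max_len, True)
    # all features are padded to max_len, so they stack into rectangular arrays
    np.savez(cache_path,
             input_ids=np.array(dataset.input_ids),
             seq_lens=np.array(dataset.seq_lens),
             tok_start=np.array(dataset.tok_to_orig_start_index),
             tok_end=np.array(dataset.tok_to_orig_end_index),
             labels=np.array(dataset.labels))
    return dataset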
set -eux
export CUDA_VISIBLE_DEVICES=0
export BATCH_SIZE=64
export CKPT=./checkpoints/model_10000.pdparams
export DATASET_FILE=DuEE_DuIE_data/data_DuIE/test_data.json
python run_duie.py \
--data_path DuEE_DuIE_data/data_DuIE/ \
--do_predict \
--init_checkpoint $CKPT \
--predict_data_file $DATASET_FILE \
--max_seq_length 128 \
--batch_size $BATCH_SIZE
# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module to calculate precision, recall and f1-value
of the predicated results.
"""
import sys
import json
import os
import zipfile
import traceback
import argparse
SUCCESS = 0
FILE_ERROR = 1
NOT_ZIP_FILE = 2
ENCODING_ERROR = 3
JSON_ERROR = 4
SCHEMA_ERROR = 5
ALIAS_FORMAT_ERROR = 6
CODE_INFO = {
SUCCESS: 'success',
FILE_ERROR: 'file does not exist',
NOT_ZIP_FILE: 'predict file is not a zipfile',
ENCODING_ERROR: 'file encoding error',
JSON_ERROR: 'json parse error',
SCHEMA_ERROR: 'schema error',
ALIAS_FORMAT_ERROR: 'alias dict format error'
}
def del_bookname(entity_name):
"""delete the book name"""
if entity_name.startswith(u'《') and entity_name.endswith(u'》'):
entity_name = entity_name[1:-1]
return entity_name
def check_format(line):
"""检查输入行是否格式错误"""
ret_code = SUCCESS
json_info = {}
try:
line = line.strip()
except:
ret_code = ENCODING_ERROR
return ret_code, json_info
try:
json_info = json.loads(line)
except:
ret_code = JSON_ERROR
return ret_code, json_info
if 'text' not in json_info or 'spo_list' not in json_info:
ret_code = SCHEMA_ERROR
return ret_code, json_info
required_key_list = ['subject', 'predicate', 'object']
for spo_item in json_info['spo_list']:
if type(spo_item) is not dict:
ret_code = SCHEMA_ERROR
return ret_code, json_info
if not all(
[required_key in spo_item for required_key in required_key_list]):
ret_code = SCHEMA_ERROR
return ret_code, json_info
if not isinstance(spo_item['subject'], str) or \
not isinstance(spo_item['object'], dict):
ret_code = SCHEMA_ERROR
return ret_code, json_info
return ret_code, json_info
def _parse_structured_ovalue(json_info):
spo_result = []
for item in json_info["spo_list"]:
s = del_bookname(item['subject'].lower())
o = {}
for o_key, o_value in item['object'].items():
o_value = del_bookname(o_value).lower()
o[o_key] = o_value
spo_result.append({"predicate": item['predicate'], \
"subject": s, \
"object": o})
return spo_result
def load_predict_result(predict_filename):
"""Loads the file to be predicted"""
predict_result = {}
ret_code = SUCCESS
if not os.path.exists(predict_filename):
ret_code = FILE_ERROR
return ret_code, predict_result
try:
predict_file_zip = zipfile.ZipFile(predict_filename)
except:
ret_code = NOT_ZIP_FILE
return ret_code, predict_result
for predict_file in predict_file_zip.namelist():
for line in predict_file_zip.open(predict_file):
ret_code, json_info = check_format(line)
if ret_code != SUCCESS:
return ret_code, predict_result
sent = json_info['text']
spo_result = _parse_structured_ovalue(json_info)
predict_result[sent] = spo_result
return ret_code, predict_result
def load_test_dataset(golden_filename):
"""load golden file"""
golden_dict = {}
ret_code = SUCCESS
if not os.path.exists(golden_filename):
ret_code = FILE_ERROR
return ret_code, golden_dict
with open(golden_filename, 'r', encoding="utf-8") as gf:
for line in gf:
ret_code, json_info = check_format(line)
if ret_code != SUCCESS:
return ret_code, golden_dict
sent = json_info['text']
spo_result = _parse_structured_ovalue(json_info)
golden_dict[sent] = spo_result
return ret_code, golden_dict
def load_alias_dict(alias_filename):
"""load alias dict"""
alias_dict = {}
ret_code = SUCCESS
if alias_filename == "":
return ret_code, alias_dict
if not os.path.exists(alias_filename):
ret_code = FILE_ERROR
return ret_code, alias_dict
with open(alias_filename, "r", encoding="utf-8") as af:
for line in af:
line = line.strip()
try:
words = line.split('\t')
alias_dict[words[0].lower()] = set()
for alias_word in words[1:]:
alias_dict[words[0].lower()].add(alias_word.lower())
except:
ret_code = ALIAS_FORMAT_ERROR
return ret_code, alias_dict
return ret_code, alias_dict
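# For reference, each line of the optional alias file is expected to be tab-separated:
# the canonical entity name followed by its aliases, i.e.
# name<TAB>alias_1<TAB>alias_2 ...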
def del_duplicate(spo_list, alias_dict):
"""delete synonyms triples in predict result"""
normalized_spo_list = []
for spo in spo_list:
if not is_spo_in_list(spo, normalized_spo_list, alias_dict):
normalized_spo_list.append(spo)
return normalized_spo_list
def is_spo_in_list(target_spo, golden_spo_list, alias_dict):
"""target spo是否在golden_spo_list中"""
if target_spo in golden_spo_list:
return True
target_s = target_spo["subject"]
target_p = target_spo["predicate"]
target_o = target_spo["object"]
target_s_alias_set = alias_dict.get(target_s, set())
target_s_alias_set.add(target_s)
for spo in golden_spo_list:
s = spo["subject"]
p = spo["predicate"]
o = spo["object"]
if p != target_p:
continue
if s in target_s_alias_set and _is_equal_o(o, target_o, alias_dict):
return True
return False
def _is_equal_o(o_a, o_b, alias_dict):
for key_a, value_a in o_a.items():
if key_a not in o_b:
return False
value_a_alias_set = alias_dict.get(value_a, set())
value_a_alias_set.add(value_a)
if o_b[key_a] not in value_a_alias_set:
return False
for key_b, value_b in o_b.items():
if key_b not in o_a:
return False
value_b_alias_set = alias_dict.get(value_b, set())
value_b_alias_set.add(value_b)
if o_a[key_b] not in value_b_alias_set:
return False
return True
def calc_pr(predict_filename, alias_filename, golden_filename):
"""calculate precision, recall, f1"""
ret_info = {}
#load alias dict
ret_code, alias_dict = load_alias_dict(alias_filename)
if ret_code != SUCCESS:
ret_info['errorCode'] = ret_code
ret_info['errorMsg'] = CODE_INFO[ret_code]
return ret_info
#load test golden dataset
ret_code, golden_dict = load_test_dataset(golden_filename)
if ret_code != SUCCESS:
ret_info['errorCode'] = ret_code
ret_info['errorMsg'] = CODE_INFO[ret_code]
return ret_info
#load predict result
ret_code, predict_result = load_predict_result(predict_filename)
if ret_code != SUCCESS:
ret_info['errorCode'] = ret_code
ret_info['errorMsg'] = CODE_INFO[ret_code]
return ret_info
#evaluation
correct_sum, predict_sum, recall_sum, recall_correct_sum = 0.0, 0.0, 0.0, 0.0
for sent in golden_dict:
golden_spo_list = del_duplicate(golden_dict[sent], alias_dict)
predict_spo_list = predict_result.get(sent, list())
normalized_predict_spo = del_duplicate(predict_spo_list, alias_dict)
recall_sum += len(golden_spo_list)
predict_sum += len(normalized_predict_spo)
for spo in normalized_predict_spo:
if is_spo_in_list(spo, golden_spo_list, alias_dict):
correct_sum += 1
for golden_spo in golden_spo_list:
if is_spo_in_list(golden_spo, predict_spo_list, alias_dict):
recall_correct_sum += 1
sys.stderr.write('correct spo num = {}\n'.format(correct_sum))
sys.stderr.write('submitted spo num = {}\n'.format(predict_sum))
sys.stderr.write('golden set spo num = {}\n'.format(recall_sum))
sys.stderr.write('submitted recall spo num = {}\n'.format(
recall_correct_sum))
precision = correct_sum / predict_sum if predict_sum > 0 else 0.0
recall = recall_correct_sum / recall_sum if recall_sum > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) \
if precision + recall > 0 else 0.0
precision = round(precision, 4)
recall = round(recall, 4)
f1 = round(f1, 4)
ret_info['errorCode'] = SUCCESS
ret_info['errorMsg'] = CODE_INFO[SUCCESS]
ret_info['data'] = []
ret_info['data'].append({'name': 'precision', 'value': precision})
ret_info['data'].append({'name': 'recall', 'value': recall})
ret_info['data'].append({'name': 'f1-score', 'value': f1})
return ret_info
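# A hypothetical usage sketch of calc_pr (file names are placeholders):
# ret = calc_pr("predictions.json.zip", "", "duie_dev.json")
# ret["data"] then holds the precision, recall and f1-score entries printed below.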
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--golden_file", type=str, help="true spo results", required=True)
parser.add_argument(
"--predict_file", type=str, help="spo results predicted", required=True)
parser.add_argument(
"--alias_file", type=str, default='', help="entities alias dictionary")
args = parser.parse_args()
golden_filename = args.golden_file
predict_filename = args.predict_file
alias_filename = args.alias_file
ret_info = calc_pr(predict_filename, alias_filename, golden_filename)
print(json.dumps(ret_info))
# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
import time
import math
import json
from functools import partial
import codecs
import zipfile
import re
from tqdm import tqdm
import sys
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import DataLoader
from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification, LinearDecayWithWarmup
from data_loader import DuIEDataset, DataCollator
from utils import decoding, find_entity, get_precision_recall_f1, write_prediction_results
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--do_train", action='store_true', default=False, help="do train")
parser.add_argument("--do_predict", action='store_true', default=False, help="do predict")
parser.add_argument("--init_checkpoint", default=None, type=str, required=False, help="Path to initialize params from")
parser.add_argument("--data_path", default="./data", type=str, required=False, help="Path to data.")
parser.add_argument("--predict_data_file", default="./data/test_data.json", type=str, required=False, help="Path to data.")
parser.add_argument("--output_dir", default="./checkpoints", type=str, required=False, help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--max_seq_length", default=128, type=int,help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.", )
parser.add_argument("--batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--warmup_ratio", default=0, type=float, help="Linear warmup over warmup_ratio * total_steps.")
parser.add_argument("--seed", default=42, type=int, help="random seed for initialization")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
args = parser.parse_args()
# yapf: enable
class BCELossForDuIE(nn.Layer):
def __init__(self, ):
super(BCELossForDuIE, self).__init__()
self.criterion = nn.BCEWithLogitsLoss(reduction='none')
def forward(self, logits, labels, mask):
loss = self.criterion(logits, labels)
mask = paddle.cast(mask, 'float32')
loss = loss * mask.unsqueeze(-1)
loss = paddle.sum(loss.mean(axis=2), axis=1) / paddle.sum(mask, axis=1)
loss = loss.mean()
return loss
def set_random_seed(seed):
"""sets random seed"""
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
@paddle.no_grad()
def evaluate(model, criterion, data_loader, file_path, mode):
"""
mode eval:
eval on development set and compute P/R/F1, called between training.
mode predict:
eval on development / test set, then write predictions to \
predictions.json and predictions.json.zip \
under args.data_path dir for later submission or evaluation.
"""
example_all = []
with open(file_path, "r", encoding="utf-8") as fp:
for line in fp:
example_all.append(json.loads(line))
id2spo_path = os.path.join(os.path.dirname(file_path), "id2spo.json")
with open(id2spo_path, 'r', encoding='utf8') as fp:
id2spo = json.load(fp)
model.eval()
loss_all = 0
eval_steps = 0
formatted_outputs = []
current_idx = 0
for batch in tqdm(data_loader, total=len(data_loader)):
eval_steps += 1
input_ids, seq_len, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
logits = model(input_ids=input_ids)
mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and(
(input_ids != 2))
loss = criterion(logits, labels, mask)
loss_all += loss.numpy().item()
probs = F.sigmoid(logits)
logits_batch = probs.numpy()
seq_len_batch = seq_len.numpy()
tok_to_orig_start_index_batch = tok_to_orig_start_index.numpy()
tok_to_orig_end_index_batch = tok_to_orig_end_index.numpy()
formatted_outputs.extend(
decoding(example_all[current_idx:current_idx + len(logits)], id2spo,
logits_batch, seq_len_batch, tok_to_orig_start_index_batch,
tok_to_orig_end_index_batch))
current_idx = current_idx + len(logits)
loss_avg = loss_all / eval_steps
print("eval loss: %f" % (loss_avg))
if mode == "predict":
predict_file_path = os.path.join(args.data_path, 'predictions.json')
else:
predict_file_path = os.path.join(args.data_path, 'predict_eval.json')
predict_zipfile_path = write_prediction_results(formatted_outputs,
predict_file_path)
if mode == "eval":
precision, recall, f1 = get_precision_recall_f1(file_path,
predict_zipfile_path)
os.system('rm {} {}'.format(predict_file_path, predict_zipfile_path))
return precision, recall, f1
elif mode != "predict":
raise Exception("wrong mode for eval func")
def do_train():
paddle.set_device(args.device)
rank = paddle.distributed.get_rank()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
# Reads label_map.
label_map_path = os.path.join(args.data_path, "predicate2id.json")
if not (os.path.exists(label_map_path) and os.path.isfile(label_map_path)):
sys.exit("{} dose not exists or is not a file.".format(label_map_path))
with open(label_map_path, 'r', encoding='utf8') as fp:
label_map = json.load(fp)
num_classes = (len(label_map.keys()) - 2) * 2 + 2
# Loads pretrained model ERNIE
model = ErnieForTokenClassification.from_pretrained(
"ernie-1.0", num_classes=num_classes)
model = paddle.DataParallel(model)
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
criterion = BCELossForDuIE()
# Loads dataset.
train_dataset = DuIEDataset.from_file(
os.path.join(args.data_path, 'train_data.json'), tokenizer,
args.max_seq_length, True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
collator = DataCollator()
train_data_loader = DataLoader(
dataset=train_dataset,
batch_sampler=train_batch_sampler,
collate_fn=collator,
return_list=True)
eval_file_path = os.path.join(args.data_path, 'dev_data.json')
test_dataset = DuIEDataset.from_file(eval_file_path, tokenizer,
args.max_seq_length, True)
test_batch_sampler = paddle.io.BatchSampler(
test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)
test_data_loader = DataLoader(
dataset=test_dataset,
batch_sampler=test_batch_sampler,
collate_fn=collator,
return_list=True)
# Defines learning rate strategy.
steps_by_epoch = len(train_data_loader)
num_training_steps = steps_by_epoch * args.num_train_epochs
lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
args.warmup_ratio)
# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
weight_decay=args.weight_decay,
apply_decay_param_fun=lambda x: x in decay_params)
# Starts training.
global_step = 0
logging_steps = 50
save_steps = 10000
tic_train = time.time()
for epoch in range(args.num_train_epochs):
print("\n=====start training of %d epochs=====" % epoch)
tic_epoch = time.time()
model.train()
for step, batch in enumerate(train_data_loader):
input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
logits = model(input_ids=input_ids)
mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and(
(input_ids != 2))
loss = criterion(logits, labels, mask)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
loss_item = loss.numpy().item()
global_step += 1
if global_step % logging_steps == 0 and rank == 0:
print(
"epoch: %d / %d, steps: %d / %d, loss: %f, speed: %.2f step/s"
% (epoch, args.num_train_epochs, step, steps_by_epoch,
loss_item, logging_steps / (time.time() - tic_train)))
tic_train = time.time()
if global_step % save_steps == 0 and rank == 0:
print("\n=====start evaluating ckpt of %d steps=====" %
global_step)
precision, recall, f1 = evaluate(
model, criterion, test_data_loader, eval_file_path, "eval")
print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
(100 * precision, 100 * recall, 100 * f1))
print("saving checkpoing model_%d.pdparams to %s " %
(global_step, args.output_dir))
paddle.save(model.state_dict(),
os.path.join(args.output_dir,
"model_%d.pdparams" % global_step))
model.train() # back to train mode
tic_epoch = time.time() - tic_epoch
print("epoch time footprint: %d hour %d min %d sec" %
(tic_epoch // 3600, (tic_epoch % 3600) // 60, tic_epoch % 60))
# Does final evaluation.
if rank == 0:
print("\n=====start evaluating last ckpt of %d steps=====" %
global_step)
precision, recall, f1 = evaluate(model, criterion, test_data_loader,
eval_file_path, "eval")
print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
(100 * precision, 100 * recall, 100 * f1))
paddle.save(model.state_dict(),
os.path.join(args.output_dir,
"model_%d.pdparams" % global_step))
print("\n=====training complete=====")
def do_predict():
paddle.set_device(args.device)
# Reads label_map.
label_map_path = os.path.join(args.data_path, "predicate2id.json")
if not (os.path.exists(label_map_path) and os.path.isfile(label_map_path)):
sys.exit("{} dose not exists or is not a file.".format(label_map_path))
with open(label_map_path, 'r', encoding='utf8') as fp:
label_map = json.load(fp)
num_classes = (len(label_map.keys()) - 2) * 2 + 2
# Loads pretrained model ERNIE
model = ErnieForTokenClassification.from_pretrained(
"ernie-1.0", num_classes=num_classes)
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
criterion = BCELossForDuIE()
# Loads dataset.
test_dataset = DuIEDataset.from_file(args.predict_data_file, tokenizer,
args.max_seq_length, True)
collator = DataCollator()
test_batch_sampler = paddle.io.BatchSampler(
test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)
test_data_loader = DataLoader(
dataset=test_dataset,
batch_sampler=test_batch_sampler,
collate_fn=collator,
return_list=True)
# Loads model parameters.
if not (os.path.exists(args.init_checkpoint) and
os.path.isfile(args.init_checkpoint)):
sys.exit("wrong directory: init checkpoints {} not exist".format(
args.init_checkpoint))
state_dict = paddle.load(args.init_checkpoint)
model.set_dict(state_dict)
# Does predictions.
print("\n=====start predicting=====")
evaluate(model, criterion, test_data_loader, args.predict_data_file,
"predict")
print("=====predicting complete=====")
if __name__ == "__main__":
if args.do_train:
do_train()
elif args.do_predict:
do_predict()
set -eux
export CUDA_VISIBLE_DEVICES=0
export BATCH_SIZE=64
export CKPT=./checkpoints/
python run_duie.py \
--data_path DuEE_DuIE_data/data_DuIE/ \
--do_train \
--output_dir $CKPT \
--num_train_epochs 2 \
--max_seq_length 128 \
--batch_size $BATCH_SIZE
# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import json
import os
import re
import zipfile
import numpy as np
def find_entity(text_raw, id_, predictions, tok_to_orig_start_index,
tok_to_orig_end_index):
"""
retrieve the entity mention under a given predicate id for a certain prediction.
this is called by the "decoding" func.
"""
entity_list = []
for i in range(len(predictions)):
if [id_] in predictions[i]:
j = 0
while i + j + 1 < len(predictions):
if [1] in predictions[i + j + 1]:
j += 1
else:
break
entity = ''.join(text_raw[tok_to_orig_start_index[i]:
tok_to_orig_end_index[i + j] + 1])
entity_list.append(entity)
return list(set(entity_list))
def decoding(example_batch, id2spo, logits_batch, seq_len_batch,
tok_to_orig_start_index_batch, tok_to_orig_end_index_batch):
"""
model output logits -> formatted spo (as in data set file)
"""
formatted_outputs = []
for (i, (example, logits, seq_len, tok_to_orig_start_index, tok_to_orig_end_index)) in \
enumerate(zip(example_batch, logits_batch, seq_len_batch, tok_to_orig_start_index_batch, tok_to_orig_end_index_batch)):
logits = logits[1:seq_len +
1] # slice between [CLS] and [SEP] to get valid logits
logits[logits >= 0.5] = 1
logits[logits < 0.5] = 0
tok_to_orig_start_index = tok_to_orig_start_index[1:seq_len + 1]
tok_to_orig_end_index = tok_to_orig_end_index[1:seq_len + 1]
predictions = []
for token in logits:
predictions.append(np.argwhere(token == 1).tolist())
# format predictions into example-style output
formatted_instance = {}
text_raw = example['text']
complex_relation_label = [8, 10, 26, 32, 46]
complex_relation_affi_label = [9, 11, 27, 28, 29, 33, 47]
# flatten predictions, then retrieve all valid subject ids
flatten_predictions = []
for layer_1 in predictions:
for layer_2 in layer_1:
flatten_predictions.append(layer_2[0])
subject_id_list = []
for cls_label in list(set(flatten_predictions)):
if 1 < cls_label <= 56 and (cls_label + 55) in flatten_predictions:
subject_id_list.append(cls_label)
subject_id_list = list(set(subject_id_list))
# fetch all valid spo by subject id
spo_list = []
for id_ in subject_id_list:
if id_ in complex_relation_affi_label:
continue # do this in the next "else" branch
if id_ not in complex_relation_label:
subjects = find_entity(text_raw, id_, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
objects = find_entity(text_raw, id_ + 55, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
for subject_ in subjects:
for object_ in objects:
spo_list.append({
"predicate": id2spo['predicate'][id_],
"object_type": {
'@value': id2spo['object_type'][id_]
},
'subject_type': id2spo['subject_type'][id_],
"object": {
'@value': object_
},
"subject": subject_
})
else:
# traverse all complex relations and look through their corresponding affiliated objects
subjects = find_entity(text_raw, id_, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
objects = find_entity(text_raw, id_ + 55, predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)
for subject_ in subjects:
for object_ in objects:
object_dict = {'@value': object_}
object_type_dict = {
'@value': id2spo['object_type'][id_].split('_')[0]
}
if id_ in [8, 10, 32, 46
] and id_ + 1 in subject_id_list:
id_affi = id_ + 1
object_dict[id2spo['object_type'][id_affi].split(
'_')[1]] = find_entity(text_raw, id_affi + 55,
predictions,
tok_to_orig_start_index,
tok_to_orig_end_index)[0]
object_type_dict[id2spo['object_type'][
id_affi].split('_')[1]] = id2spo['object_type'][
id_affi].split('_')[0]
elif id_ == 26:
for id_affi in [27, 28, 29]:
if id_affi in subject_id_list:
object_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
find_entity(text_raw, id_affi + 55, predictions, tok_to_orig_start_index, tok_to_orig_end_index)[0]
object_type_dict[id2spo['object_type'][id_affi].split('_')[1]] = \
id2spo['object_type'][id_affi].split('_')[0]
spo_list.append({
"predicate": id2spo['predicate'][id_],
"object_type": object_type_dict,
"subject_type": id2spo['subject_type'][id_],
"object": object_dict,
"subject": subject_
})
formatted_instance['text'] = example['text']
formatted_instance['spo_list'] = spo_list
formatted_outputs.append(formatted_instance)
return formatted_outputs
def write_prediction_results(formatted_outputs, file_path):
"""write the prediction results"""
with codecs.open(file_path, 'w', 'utf-8') as f:
for formatted_instance in formatted_outputs:
json_str = json.dumps(formatted_instance, ensure_ascii=False)
f.write(json_str)
f.write('\n')
zipfile_path = file_path + '.zip'
f = zipfile.ZipFile(zipfile_path, 'w', zipfile.ZIP_DEFLATED)
f.write(file_path)
f.close()  # close explicitly so the zip central directory is written
return zipfile_path
def get_precision_recall_f1(golden_file, predict_file):
r = os.popen(
'python3 ./re_official_evaluation.py --golden_file={} --predict_file={}'.
format(golden_file, predict_file))
result = r.read()
r.close()
precision = float(
re.search("\"precision\", \"value\":.*?}", result).group(0).lstrip(
"\"precision\", \"value\":").rstrip("}"))
recall = float(
re.search("\"recall\", \"value\":.*?}", result).group(0).lstrip(
"\"recall\", \"value\":").rstrip("}"))
f1 = float(
re.search("\"f1-score\", \"value\":.*?}", result).group(0).lstrip(
"\"f1-score\", \"value\":").rstrip("}"))
return precision, recall, f1