From ee0a5a497e0a626615c19f6ca6cb32d3eb8aa45a Mon Sep 17 00:00:00 2001
From: 20210828028 <yunwanx@foxmail.com>
Date: Wed, 9 Mar 2022 14:55:10 +0800
Subject: [PATCH] DuEE 1.0

---
 DuEE/duee_1_data_prepare.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 DuEE/duee_1_postprocess.py  |  91 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 DuEE/run_duee_1.sh          |  60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 289 insertions(+)
 create mode 100644 DuEE/duee_1_data_prepare.py
 create mode 100644 DuEE/duee_1_postprocess.py
 create mode 100644 DuEE/run_duee_1.sh

diff --git a/DuEE/duee_1_data_prepare.py b/DuEE/duee_1_data_prepare.py
new file mode 100644
index 0000000..a64363b
--- /dev/null
+++ b/DuEE/duee_1_data_prepare.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""duee 1.0 dataset process"""
+import os
+import sys
+import json
+from utils import read_by_lines, write_by_lines
+
+
+def data_process(path, model="trigger", is_predict=False):
+    """data_process"""
+
+    def label_data(data, start, l, _type):
+        """label_data"""
+        for i in range(start, start + l):
+            suffix = "B-" if i == start else "I-"
+            data[i] = "{}{}".format(suffix, _type)
+        return data
+
+    sentences = []
+    output = ["text_a"] if is_predict else ["text_a\tlabel"]
+    with open(path) as f:
+        for line in f:
+            d_json = json.loads(line.strip())
+            _id = d_json["id"]
+            text_a = [
+                "," if t == " " or t == "\n" or t == "\t" else t
+                for t in list(d_json["text"].lower())
+            ]
+            if is_predict:
+                sentences.append({"text": d_json["text"], "id": _id})
+                output.append('\002'.join(text_a))
+            else:
+                if model == "trigger":
+                    labels = ["O"] * len(text_a)
+                    for event in d_json.get("event_list", []):
+                        event_type = event["event_type"]
+                        start = event["trigger_start_index"]
+                        trigger = event["trigger"]
+                        labels = label_data(labels, start,
+                                            len(trigger), event_type)
+                    output.append("{}\t{}".format('\002'.join(text_a),
+                                                  '\002'.join(labels)))
+                elif model == "role":
+                    for event in d_json.get("event_list", []):
+                        labels = ["O"] * len(text_a)
+                        for arg in event["arguments"]:
+                            role_type = arg["role"]
+                            argument = arg["argument"]
+                            start = arg["argument_start_index"]
+                            labels = label_data(labels, start,
+                                                len(argument), role_type)
+                        output.append("{}\t{}".format('\002'.join(text_a),
+                                                      '\002'.join(labels)))
+    return output
+
+
+def schema_process(path, model="trigger"):
+    """schema_process"""
+
+    def label_add(labels, _type):
+        """label_add"""
+        if "B-{}".format(_type) not in labels:
+            labels.extend(["B-{}".format(_type), "I-{}".format(_type)])
+        return labels
+
+    labels = []
+    for line in read_by_lines(path):
+        d_json = json.loads(line.strip())
+        if model == "trigger":
+            labels = label_add(labels, d_json["event_type"])
+        elif model == "role":
+            for role in d_json["role_list"]:
+                labels = label_add(labels, role["role"])
+    labels.append("O")
+    tags = []
+    for index, label in enumerate(labels):
+        tags.append("{}\t{}".format(index, label))
+    return tags
+
+
+if __name__ == "__main__":
+    print("\n=================DUEE 1.0 DATASET==============")
+    conf_dir = "./conf/DuEE1.0"
+    schema_path = "{}/event_schema.json".format(conf_dir)
+    tags_trigger_path = "{}/trigger_tag.dict".format(conf_dir)
+    tags_role_path = "{}/role_tag.dict".format(conf_dir)
+    print("\n=================start schema process==============")
+    print('input path {}'.format(schema_path))
+    tags_trigger = schema_process(schema_path, "trigger")
+    write_by_lines(tags_trigger_path, tags_trigger)
+    print("save trigger tag {} at {}".format(
+        len(tags_trigger), tags_trigger_path))
+    tags_role = schema_process(schema_path, "role")
+    write_by_lines(tags_role_path, tags_role)
+    print("save role tag {} at {}".format(len(tags_role), tags_role_path))
+    print("=================end schema process===============")
+
+    # data process
+    data_dir = "./data/DuEE1.0"
+    trigger_save_dir = "{}/trigger".format(data_dir)
+    role_save_dir = "{}/role".format(data_dir)
+    print("\n=================start data process==============")
+    if not os.path.exists(trigger_save_dir):
+        os.makedirs(trigger_save_dir)
+    if not os.path.exists(role_save_dir):
+        os.makedirs(role_save_dir)
+    print("\n----trigger------for dir {} to {}".format(data_dir,
+                                                       trigger_save_dir))
+    train_tri = data_process("{}/duee_train.json".format(data_dir), "trigger")
+    write_by_lines("{}/train.tsv".format(trigger_save_dir), train_tri)
+    dev_tri = data_process("{}/duee_dev.json".format(data_dir), "trigger")
+    write_by_lines("{}/dev.tsv".format(trigger_save_dir), dev_tri)
+    test_tri = data_process("{}/duee_test1.json".format(data_dir), "trigger")
+    write_by_lines("{}/test.tsv".format(trigger_save_dir), test_tri)
+    print("train {} dev {} test {}".format(
+        len(train_tri), len(dev_tri), len(test_tri)))
+    print("\n----role------for dir {} to {}".format(data_dir, role_save_dir))
+    train_role = data_process("{}/duee_train.json".format(data_dir), "role")
+    write_by_lines("{}/train.tsv".format(role_save_dir), train_role)
+    dev_role = data_process("{}/duee_dev.json".format(data_dir), "role")
+    write_by_lines("{}/dev.tsv".format(role_save_dir), dev_role)
+    test_role = data_process("{}/duee_test1.json".format(data_dir), "role")
+    write_by_lines("{}/test.tsv".format(role_save_dir), test_role)
+    print("train {} dev {} test {}".format(
+        len(train_role), len(dev_role), len(test_role)))
+    print("=================end data process==============")
\ No newline at end of file
diff --git a/DuEE/duee_1_postprocess.py b/DuEE/duee_1_postprocess.py
new file mode 100644
index 0000000..2966415
--- /dev/null
+++ b/DuEE/duee_1_postprocess.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""duee 1.0 data predict post-process"""
+
+import os
+import sys
+import json
+import argparse
+
+from utils import read_by_lines, write_by_lines, extract_result
+
+
+def predict_data_process(trigger_file, role_file, schema_file, save_path):
+    """predict_data_process"""
+    pred_ret = []
+    trigger_datas = read_by_lines(trigger_file)
+    role_data = read_by_lines(role_file)
+    schema_datas = read_by_lines(schema_file)
+    print("trigger predict {} load from {}".format(
+        len(trigger_datas), trigger_file))
+    print("role predict {} load from {}".format(len(role_data), role_file))
+    print("schema {} load from {}".format(len(schema_datas), schema_file))
+
+    schema = {}
+    for s in schema_datas:
+        d_json = json.loads(s)
+        schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
+
+    # process the role data
+    sent_role_mapping = {}
+    for d in role_data:
+        d_json = json.loads(d)
+        r_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
+        role_ret = {}
+        for r in r_ret:
+            role_type = r["type"]
+            if role_type not in role_ret:
+                role_ret[role_type] = []
+            role_ret[role_type].append("".join(r["text"]))
+        sent_role_mapping[d_json["id"]] = role_ret
+
+    for d in trigger_datas:
+        d_json = json.loads(d)
+        t_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
+        pred_event_types = list(set([t["type"] for t in t_ret]))
+        event_list = []
+        for event_type in pred_event_types:
+            role_list = schema[event_type]
+            arguments = []
+            for role_type, ags in sent_role_mapping[d_json["id"]].items():
+                if role_type not in role_list:
+                    continue
+                for arg in ags:
+                    if len(arg) == 1:
+                        continue
+                    arguments.append({"role": role_type, "argument": arg})
+            event = {"event_type": event_type, "arguments": arguments}
+            event_list.append(event)
+        pred_ret.append({
+            "id": d_json["id"],
+            "text": d_json["text"],
+            "event_list": event_list
+        })
+    pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
+    print("submit data {} save to {}".format(len(pred_ret), save_path))
+    write_by_lines(save_path, pred_ret)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Official evaluation script for DuEE version 1.0")
+    parser.add_argument(
+        "--trigger_file", help="trigger model predict data path", required=True)
+    parser.add_argument(
+        "--role_file", help="role model predict data path", required=True)
+    parser.add_argument("--schema_file", help="schema file path", required=True)
+    parser.add_argument("--save_path", help="save file path", required=True)
+    args = parser.parse_args()
+    predict_data_process(args.trigger_file, args.role_file, args.schema_file,
+                         args.save_path)
\ No newline at end of file
diff --git a/DuEE/run_duee_1.sh b/DuEE/run_duee_1.sh
new file mode 100644
index 0000000..4d59959
--- /dev/null
+++ b/DuEE/run_duee_1.sh
@@ -0,0 +1,60 @@
+dataset_name=DuEE1.0
+data_dir=./data/${dataset_name}
+conf_dir=./conf/${dataset_name}
+ckpt_dir=./ckpt/${dataset_name}
+submit_data_path=./submit/test_duee_1.json
+pred_data=${data_dir}/duee_test1.json  # change this path to predict on other data
+
+learning_rate=5e-5
+max_seq_len=300
+batch_size=16
+epoch=20
+
+echo -e "check and create directory"
+dir_list=(./ckpt ${ckpt_dir} ./submit)
+for item in ${dir_list[*]}
+do
+    if [ ! -d ${item} ]; then
+        mkdir ${item}
+        echo "create dir * ${item} *"
+    else
+        echo "dir ${item} exist"
+    fi
+done
+
+process_name=${1}
+
+run_sequence_labeling_model(){
+    model=${1}
+    is_train=${2}
+    pred_save_path=${ckpt_dir}/${model}/test_pred.json
+    sh run_sequence_labeling.sh ${data_dir}/${model} ${conf_dir}/${model}_tag.dict ${ckpt_dir}/${model} ${pred_data} ${learning_rate} ${is_train} ${max_seq_len} ${batch_size} ${epoch} ${pred_save_path}
+}
+
+if [ ${process_name} == data_prepare ]; then
+    echo -e "\nstart ${dataset_name} data prepare"
+    python duee_1_data_prepare.py
+    echo -e "end ${dataset_name} data prepare"
+elif [ ${process_name} == trigger_train ]; then
+    echo -e "\nstart ${dataset_name} trigger train"
+    run_sequence_labeling_model trigger True
+    echo -e "end ${dataset_name} trigger train"
+elif [ ${process_name} == trigger_predict ]; then
+    echo -e "\nstart ${dataset_name} trigger predict"
+    run_sequence_labeling_model trigger False
+    echo -e "end ${dataset_name} trigger predict"
+elif [ ${process_name} == role_train ]; then
+    echo -e "\nstart ${dataset_name} role train"
+    run_sequence_labeling_model role True
+    echo -e "end ${dataset_name} role train"
+elif [ ${process_name} == role_predict ]; then
+    echo -e "\nstart ${dataset_name} role predict"
+    run_sequence_labeling_model role False
+    echo -e "end ${dataset_name} role predict"
+elif [ ${process_name} == pred_2_submit ]; then
+    echo -e "\nstart ${dataset_name} predict data merge to submit format"
+    python duee_1_postprocess.py --trigger_file ${ckpt_dir}/trigger/test_pred.json --role_file ${ckpt_dir}/role/test_pred.json --schema_file ${conf_dir}/event_schema.json --save_path ${submit_data_path}
+    echo -e "end ${dataset_name} predict data merge"
+else
+    echo "unknown process name ${process_name}"
+fi
\ No newline at end of file
-- 
libgit2 0.26.0
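
Usage sketch (not part of the patch): run_duee_1.sh dispatches on its first argument, so a full DuEE 1.0 run is assumed to chain the stages below in this order. Note that run_sequence_labeling.sh, which performs the actual training and prediction, is referenced by run_duee_1.sh but not included in this patch, so it must already exist alongside these scripts.

    sh run_duee_1.sh data_prepare     # build tag dictionaries and trigger/role TSV files
    sh run_duee_1.sh trigger_train    # train the trigger sequence-labeling model
    sh run_duee_1.sh trigger_predict  # predict triggers on duee_test1.json
    sh run_duee_1.sh role_train       # train the argument-role sequence-labeling model
    sh run_duee_1.sh role_predict     # predict argument roles on duee_test1.json
    sh run_duee_1.sh pred_2_submit    # merge trigger/role predictions into submit/test_duee_1.json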