From ee0a5a497e0a626615c19f6ca6cb32d3eb8aa45a Mon Sep 17 00:00:00 2001
From: 20210828028 <yunwanx@foxmail.com>
Date: Wed, 9 Mar 2022 14:55:10 +0800
Subject: [PATCH] DuEE 1.0

---
 DuEE/duee_1_data_prepare.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 DuEE/duee_1_postprocess.py  |  91 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 DuEE/run_duee_1.sh          |  60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 289 insertions(+)
 create mode 100644 DuEE/duee_1_data_prepare.py
 create mode 100644 DuEE/duee_1_postprocess.py
 create mode 100644 DuEE/run_duee_1.sh

diff --git a/DuEE/duee_1_data_prepare.py b/DuEE/duee_1_data_prepare.py
new file mode 100644
index 0000000..a64363b
--- /dev/null
+++ b/DuEE/duee_1_data_prepare.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""duee 1.0 dataset process"""
+import os
+import sys
+import json
+from utils import read_by_lines, write_by_lines
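+
+# ``read_by_lines`` and ``write_by_lines`` live in this repo's utils module and are
+# not part of this patch. A minimal sketch of what they are assumed to do:
+#
+#     def read_by_lines(path):
+#         """Read a utf-8 file into a list of stripped lines."""
+#         with open(path, encoding="utf-8") as f:
+#             return [line.rstrip("\n") for line in f]
+#
+#     def write_by_lines(path, data):
+#         """Write a list of lines to a utf-8 file, one line per item."""
+#         with open(path, "w", encoding="utf-8") as f:
+#             f.write("\n".join(data) + "\n")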
+
+
+def data_process(path, model="trigger", is_predict=False):
+    """data_process"""
+
+    def label_data(data, start, length, _type):
+        """Tag ``length`` characters starting at ``start`` with B-/I- labels of ``_type``."""
+        for i in range(start, start + length):
+            suffix = "B-" if i == start else "I-"
+            data[i] = "{}{}".format(suffix, _type)
+        return data
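+    # Illustrative example (hypothetical values): for data = ["O"] * 5, start = 2,
+    # length = 2 and _type = "X", label_data returns ["O", "O", "B-X", "I-X", "O"].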
+
+    sentences = []
+    output = ["text_a"] if is_predict else ["text_a\tlabel"]
+    with open(path) as f:
+        for line in f:
+            d_json = json.loads(line.strip())
+            _id = d_json["id"]
+            text_a = [
+                "," if t == " " or t == "\n" or t == "\t" else t
+                for t in list(d_json["text"].lower())
+            ]
+            if is_predict:
+                sentences.append({"text": d_json["text"], "id": _id})
+                output.append('\002'.join(text_a))
+            else:
+                if model == "trigger":
+                    labels = ["O"] * len(text_a)
+                    for event in d_json.get("event_list", []):
+                        event_type = event["event_type"]
+                        start = event["trigger_start_index"]
+                        trigger = event["trigger"]
+                        labels = label_data(labels, start,
+                                            len(trigger), event_type)
+                    output.append("{}\t{}".format('\002'.join(text_a),
+                                                  '\002'.join(labels)))
+                elif model == "role":
+                    for event in d_json.get("event_list", []):
+                        labels = ["O"] * len(text_a)
+                        for arg in event["arguments"]:
+                            role_type = arg["role"]
+                            argument = arg["argument"]
+                            start = arg["argument_start_index"]
+                            labels = label_data(labels, start,
+                                                len(argument), role_type)
+                        output.append("{}\t{}".format('\002'.join(text_a),
+                                                      '\002'.join(labels)))
+    return output
+
+
+def schema_process(path, model="trigger"):
+    """schema_process"""
+
+    def label_add(labels, _type):
+        """label_add"""
+        if "B-{}".format(_type) not in labels:
+            labels.extend(["B-{}".format(_type), "I-{}".format(_type)])
+        return labels
+
+    labels = []
+    for line in read_by_lines(path):
+        d_json = json.loads(line.strip())
+        if model == "trigger":
+            labels = label_add(labels, d_json["event_type"])
+        elif model == "role":
+            for role in d_json["role_list"]:
+                labels = label_add(labels, role["role"])
+    labels.append("O")
+    tags = []
+    for index, label in enumerate(labels):
+        tags.append("{}\t{}".format(index, label))
+    return tags
+
+
+if __name__ == "__main__":
+    print("\n=================DUEE 1.0 DATASET==============")
+    conf_dir = "./conf/DuEE1.0"
+    schema_path = "{}/event_schema.json".format(conf_dir)
+    tags_trigger_path = "{}/trigger_tag.dict".format(conf_dir)
+    tags_role_path = "{}/role_tag.dict".format(conf_dir)
+    print("\n=================start schema process==============")
+    print('input path {}'.format(schema_path))
+    tags_trigger = schema_process(schema_path, "trigger")
+    write_by_lines(tags_trigger_path, tags_trigger)
+    print("save trigger tag {} at {}".format(
+        len(tags_trigger), tags_trigger_path))
+    tags_role = schema_process(schema_path, "role")
+    write_by_lines(tags_role_path, tags_role)
+    print("save trigger tag {} at {}".format(len(tags_role), tags_role_path))
+    print("=================end schema process===============")
+
+    # data process
+    data_dir = "./data/DuEE1.0"
+    trigger_save_dir = "{}/trigger".format(data_dir)
+    role_save_dir = "{}/role".format(data_dir)
+    print("\n=================start schema process==============")
+    if not os.path.exists(trigger_save_dir):
+        os.makedirs(trigger_save_dir)
+    if not os.path.exists(role_save_dir):
+        os.makedirs(role_save_dir)
+    print("\n----trigger------for dir {} to {}".format(data_dir,
+                                                       trigger_save_dir))
+    train_tri = data_process("{}/duee_train.json".format(data_dir), "trigger")
+    write_by_lines("{}/train.tsv".format(trigger_save_dir), train_tri)
+    dev_tri = data_process("{}/duee_dev.json".format(data_dir), "trigger")
+    write_by_lines("{}/dev.tsv".format(trigger_save_dir), dev_tri)
+    test_tri = data_process("{}/duee_test1.json".format(data_dir), "trigger")
+    write_by_lines("{}/test.tsv".format(trigger_save_dir), test_tri)
+    print("train {} dev {} test {}".format(
+        len(train_tri), len(dev_tri), len(test_tri)))
+    print("\n----role------for dir {} to {}".format(data_dir, role_save_dir))
+    train_role = data_process("{}/duee_train.json".format(data_dir), "role")
+    write_by_lines("{}/train.tsv".format(role_save_dir), train_role)
+    dev_role = data_process("{}/duee_dev.json".format(data_dir), "role")
+    write_by_lines("{}/dev.tsv".format(role_save_dir), dev_role)
+    test_role = data_process("{}/duee_test1.json".format(data_dir), "role")
+    write_by_lines("{}/test.tsv".format(role_save_dir), test_role)
+    print("train {} dev {} test {}".format(
+        len(train_role), len(dev_role), len(test_role)))
+    print("=================end schema process==============")
\ No newline at end of file
diff --git a/DuEE/duee_1_postprocess.py b/DuEE/duee_1_postprocess.py
new file mode 100644
index 0000000..2966415
--- /dev/null
+++ b/DuEE/duee_1_postprocess.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""duee 1.0 data predict post-process"""
+
+import os
+import sys
+import json
+import argparse
+
+from utils import read_by_lines, write_by_lines, extract_result
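+
+# ``extract_result`` comes from this repo's utils module and is not part of this
+# patch. Given the raw text and a predicted label sequence, it is assumed to return
+# a list of span dicts shaped like {"start": ..., "text": [chars...], "type": label},
+# which is how its results are consumed below.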
+
+
+def predict_data_process(trigger_file, role_file, schema_file, save_path):
+    """predict_data_process"""
+    pred_ret = []
+    trigger_datas = read_by_lines(trigger_file)
+    role_data = read_by_lines(role_file)
+    schema_datas = read_by_lines(schema_file)
+    print("trigger predict {} load from {}".format(
+        len(trigger_datas), trigger_file))
+    print("role predict {} load from {}".format(len(role_data), role_file))
+    print("schema {} load from {}".format(len(schema_datas), schema_file))
+
+    schema = {}
+    for s in schema_datas:
+        d_json = json.loads(s)
+        schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
+
+    # process the role data
+    sent_role_mapping = {}
+    for d in role_data:
+        d_json = json.loads(d)
+        r_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
+        role_ret = {}
+        for r in r_ret:
+            role_type = r["type"]
+            if role_type not in role_ret:
+                role_ret[role_type] = []
+            role_ret[role_type].append("".join(r["text"]))
+        sent_role_mapping[d_json["id"]] = role_ret
+
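+    # Merge step: for each sentence, pair every predicted event type with the role
+    # arguments extracted for the same sentence id; single-character arguments are
+    # skipped below (the len(arg) == 1 check), presumably as noise.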
+    for d in trigger_datas:
+        d_json = json.loads(d)
+        t_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
+        pred_event_types = list(set([t["type"] for t in t_ret]))
+        event_list = []
+        for event_type in pred_event_types:
+            role_list = schema[event_type]
+            arguments = []
+            for role_type, ags in sent_role_mapping[d_json["id"]].items():
+                if role_type not in role_list:
+                    continue
+                for arg in ags:
+                    if len(arg) == 1:
+                        continue
+                    arguments.append({"role": role_type, "argument": arg})
+            event = {"event_type": event_type, "arguments": arguments}
+            event_list.append(event)
+        pred_ret.append({
+            "id": d_json["id"],
+            "text": d_json["text"],
+            "event_list": event_list
+        })
+    pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
+    print("submit data {} save to {}".format(len(pred_ret), save_path))
+    write_by_lines(save_path, pred_ret)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Official evaluation script for DuEE version 1.0")
+    parser.add_argument(
+        "--trigger_file", help="trigger model predict data path", required=True)
+    parser.add_argument(
+        "--role_file", help="role model predict data path", required=True)
+    parser.add_argument("--schema_file", help="schema file path", required=True)
+    parser.add_argument("--save_path", help="save file path", required=True)
+    args = parser.parse_args()
+    predict_data_process(args.trigger_file, args.role_file, args.schema_file,
+                         args.save_path)
\ No newline at end of file
diff --git a/DuEE/run_duee_1.sh b/DuEE/run_duee_1.sh
new file mode 100644
index 0000000..4d59959
--- /dev/null
+++ b/DuEE/run_duee_1.sh
@@ -0,0 +1,60 @@
+dataset_name=DuEE1.0
+data_dir=./data/${dataset_name}
+conf_dir=./conf/${dataset_name}
+ckpt_dir=./ckpt/${dataset_name}
+submit_data_path=./submit/test_duee_1.json
+pred_data=${data_dir}/duee_test1.json   # change this path to predict on other data
+
+learning_rate=5e-5
+max_seq_len=300
+batch_size=16
+epoch=20
+
+echo -e "check and create directory"
+dir_list=(./ckpt ${ckpt_dir} ./submit)
+for item in ${dir_list[*]}
+do
+    if [ ! -d ${item} ]; then
+        mkdir ${item}
+        echo "create dir * ${item} *"
+    else
+        echo "dir ${item} exist"
+    fi
+done
+
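+# Usage (as implied by the dispatch below): sh run_duee_1.sh <process_name>, where
+# <process_name> is one of data_prepare, trigger_train, trigger_predict, role_train,
+# role_predict or pred_2_submit.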
+process_name=${1}
+
+run_sequence_labeling_model(){
+    model=${1}
+    is_train=${2}
+    pred_save_path=${ckpt_dir}/${model}/test_pred.json
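+    # positional args passed to run_sequence_labeling.sh, in order: data dir, tag
+    # dict, ckpt dir, pred data, learning rate, is_train, max_seq_len, batch_size,
+    # epoch, pred save path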
+    sh run_sequence_labeling.sh ${data_dir}/${model} ${conf_dir}/${model}_tag.dict ${ckpt_dir}/${model} ${pred_data} ${learning_rate} ${is_train} ${max_seq_len} ${batch_size} ${epoch} ${pred_save_path}
+}
+
+if [ "${process_name}" == data_prepare ]; then
+    echo -e "\nstart ${dataset_name} data prepare"
+    python duee_1_data_prepare.py
+    echo -e "end ${dataset_name} data prepare"
+elif [ "${process_name}" == trigger_train ]; then
+    echo -e "\nstart ${dataset_name} trigger train"
+    run_sequence_labeling_model trigger True
+    echo -e "end ${dataset_name} trigger train"
+elif [ "${process_name}" == trigger_predict ]; then
+    echo -e "\nstart ${dataset_name} trigger predict"
+    run_sequence_labeling_model trigger False
+    echo -e "end ${dataset_name} trigger predict"
+elif [ "${process_name}" == role_train ]; then
+    echo -e "\nstart ${dataset_name} role train"
+    run_sequence_labeling_model role True
+    echo -e "end ${dataset_name} role train"
+elif [ "${process_name}" == role_predict ]; then
+    echo -e "\nstart ${dataset_name} role predict"
+    run_sequence_labeling_model role False
+    echo -e "end ${dataset_name} role predict"
+elif [ "${process_name}" == pred_2_submit ]; then
+    echo -e "\nstart ${dataset_name} predict data merge to submit fotmat"
+    python duee_1_postprocess.py --trigger_file ${ckpt_dir}/trigger/test_pred.json --role_file ${ckpt_dir}/role/test_pred.json --schema_file ${conf_dir}/event_schema.json --save_path ${submit_data_path}
+    echo -e "end ${dataset_name} role predict data merge"
+else
+    echo "no process name ${process_name}"
+fi
\ No newline at end of file
--
libgit2 0.26.0