Commit ee0a5a49 by 20210828028

DuEE 1.0

parent b4649acc
# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""duee 1.0 dataset process"""
import os
import sys
import json
from utils import read_by_lines, write_by_lines
def data_process(path, model="trigger", is_predict=False):
"""data_process"""
def label_data(data, start, l, _type):
"""label_data"""
for i in range(start, start + l):
suffix = "B-" if i == start else "I-"
data[i] = "{}{}".format(suffix, _type)
return data
sentences = []
output = ["text_a"] if is_predict else ["text_a\tlabel"]
with open(path) as f:
for line in f:
d_json = json.loads(line.strip())
_id = d_json["id"]
text_a = [
"," if t == " " or t == "\n" or t == "\t" else t
for t in list(d_json["text"].lower())
]
if is_predict:
sentences.append({"text": d_json["text"], "id": _id})
output.append('\002'.join(text_a))
else:
if model == "trigger":
labels = ["O"] * len(text_a)
for event in d_json.get("event_list", []):
event_type = event["event_type"]
start = event["trigger_start_index"]
trigger = event["trigger"]
labels = label_data(labels, start,
len(trigger), event_type)
output.append("{}\t{}".format('\002'.join(text_a),
'\002'.join(labels)))
elif model == "role":
for event in d_json.get("event_list", []):
labels = ["O"] * len(text_a)
for arg in event["arguments"]:
role_type = arg["role"]
argument = arg["argument"]
start = arg["argument_start_index"]
labels = label_data(labels, start,
len(argument), role_type)
output.append("{}\t{}".format('\002'.join(text_a),
'\002'.join(labels)))
return output
def schema_process(path, model="trigger"):
"""schema_process"""
def label_add(labels, _type):
"""label_add"""
if "B-{}".format(_type) not in labels:
labels.extend(["B-{}".format(_type), "I-{}".format(_type)])
return labels
labels = []
for line in read_by_lines(path):
d_json = json.loads(line.strip())
if model == "trigger":
labels = label_add(labels, d_json["event_type"])
elif model == "role":
for role in d_json["role_list"]:
labels = label_add(labels, role["role"])
labels.append("O")
tags = []
for index, label in enumerate(labels):
tags.append("{}\t{}".format(index, label))
return tags
if __name__ == "__main__":
print("\n=================DUEE 1.0 DATASET==============")
conf_dir = "./conf/DuEE1.0"
schema_path = "{}/event_schema.json".format(conf_dir)
tags_trigger_path = "{}/trigger_tag.dict".format(conf_dir)
tags_role_path = "{}/role_tag.dict".format(conf_dir)
print("\n=================start schema process==============")
print('input path {}'.format(schema_path))
tags_trigger = schema_process(schema_path, "trigger")
write_by_lines(tags_trigger_path, tags_trigger)
print("save trigger tag {} at {}".format(
len(tags_trigger), tags_trigger_path))
tags_role = schema_process(schema_path, "role")
write_by_lines(tags_role_path, tags_role)
print("save trigger tag {} at {}".format(len(tags_role), tags_role_path))
print("=================end schema process===============")
# data process
data_dir = "./data/DuEE1.0"
trigger_save_dir = "{}/trigger".format(data_dir)
role_save_dir = "{}/role".format(data_dir)
print("\n=================start schema process==============")
if not os.path.exists(trigger_save_dir):
os.makedirs(trigger_save_dir)
if not os.path.exists(role_save_dir):
os.makedirs(role_save_dir)
print("\n----trigger------for dir {} to {}".format(data_dir,
trigger_save_dir))
train_tri = data_process("{}/duee_train.json".format(data_dir), "trigger")
write_by_lines("{}/train.tsv".format(trigger_save_dir), train_tri)
dev_tri = data_process("{}/duee_dev.json".format(data_dir), "trigger")
write_by_lines("{}/dev.tsv".format(trigger_save_dir), dev_tri)
test_tri = data_process("{}/duee_test1.json".format(data_dir), "trigger")
write_by_lines("{}/test.tsv".format(trigger_save_dir), test_tri)
print("train {} dev {} test {}".format(
len(train_tri), len(dev_tri), len(test_tri)))
print("\n----role------for dir {} to {}".format(data_dir, role_save_dir))
train_role = data_process("{}/duee_train.json".format(data_dir), "role")
write_by_lines("{}/train.tsv".format(role_save_dir), train_role)
dev_role = data_process("{}/duee_dev.json".format(data_dir), "role")
write_by_lines("{}/dev.tsv".format(role_save_dir), dev_role)
test_role = data_process("{}/duee_test1.json".format(data_dir), "role")
write_by_lines("{}/test.tsv".format(role_save_dir), test_role)
print("train {} dev {} test {}".format(
len(train_role), len(dev_role), len(test_role)))
print("=================end schema process==============")
\ No newline at end of file
# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""duee 1.0 data predict post-process"""
import os
import sys
import json
import argparse
from utils import read_by_lines, write_by_lines, extract_result
def predict_data_process(trigger_file, role_file, schema_file, save_path):
"""predict_data_process"""
pred_ret = []
trigger_datas = read_by_lines(trigger_file)
role_data = read_by_lines(role_file)
schema_datas = read_by_lines(schema_file)
print("trigger predict {} load from {}".format(
len(trigger_datas), trigger_file))
print("role predict {} load from {}".format(len(role_data), role_file))
print("schema {} load from {}".format(len(schema_datas), schema_file))
schema = {}
for s in schema_datas:
d_json = json.loads(s)
schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]
# process the role data
sent_role_mapping = {}
for d in role_data:
d_json = json.loads(d)
r_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
role_ret = {}
for r in r_ret:
role_type = r["type"]
if role_type not in role_ret:
role_ret[role_type] = []
role_ret[role_type].append("".join(r["text"]))
sent_role_mapping[d_json["id"]] = role_ret
for d in trigger_datas:
d_json = json.loads(d)
t_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
pred_event_types = list(set([t["type"] for t in t_ret]))
event_list = []
for event_type in pred_event_types:
role_list = schema[event_type]
arguments = []
for role_type, ags in sent_role_mapping[d_json["id"]].items():
if role_type not in role_list:
continue
for arg in ags:
if len(arg) == 1:
continue
arguments.append({"role": role_type, "argument": arg})
event = {"event_type": event_type, "arguments": arguments}
event_list.append(event)
pred_ret.append({
"id": d_json["id"],
"text": d_json["text"],
"event_list": event_list
})
pred_ret = [json.dumps(r, ensure_ascii=False) for r in pred_ret]
print("submit data {} save to {}".format(len(pred_ret), save_path))
write_by_lines(save_path, pred_ret)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Official evaluation script for DuEE version 1.0")
parser.add_argument(
"--trigger_file", help="trigger model predict data path", required=True)
parser.add_argument(
"--role_file", help="role model predict data path", required=True)
parser.add_argument("--schema_file", help="schema file path", required=True)
parser.add_argument("--save_path", help="save file path", required=True)
args = parser.parse_args()
predict_data_process(args.trigger_file, args.role_file, args.schema_file,
args.save_path)
\ No newline at end of file
dataset_name=DuEE1.0
data_dir=./data/${dataset_name}
conf_dir=./conf/${dataset_name}
ckpt_dir=./ckpt/${dataset_name}
submit_data_path=./submit/test_duee_1.json
pred_data=${data_dir}/duee_test1.json # 换其他数据,需要修改它
learning_rate=5e-5
max_seq_len=300
batch_size=16
epoch=20
echo -e "check and create directory"
dir_list=(./ckpt ${ckpt_dir} ./submit)
for item in ${dir_list[*]}
do
if [ ! -d ${item} ]; then
mkdir ${item}
echo "create dir * ${item} *"
else
echo "dir ${item} exist"
fi
done
process_name=${1}
run_sequence_labeling_model(){
model=${1}
is_train=${2}
pred_save_path=${ckpt_dir}/${model}/test_pred.json
sh run_sequence_labeling.sh ${data_dir}/${model} ${conf_dir}/${model}_tag.dict ${ckpt_dir}/${model} ${pred_data} ${learning_rate} ${is_train} ${max_seq_len} ${batch_size} ${epoch} ${pred_save_path}
}
if [ ${process_name} == data_prepare ]; then
echo -e "\nstart ${dataset_name} data prepare"
python duee_1_data_prepare.py
echo -e "end ${dataset_name} data prepare"
elif [ ${process_name} == trigger_train ]; then
echo -e "\nstart ${dataset_name} trigger train"
run_sequence_labeling_model trigger True
echo -e "end ${dataset_name} trigger train"
elif [ ${process_name} == trigger_predict ]; then
echo -e "\nstart ${dataset_name} trigger predict"
run_sequence_labeling_model trigger False
echo -e "end ${dataset_name} trigger predict"
elif [ ${process_name} == role_train ]; then
echo -e "\nstart ${dataset_name} role train"
run_sequence_labeling_model role True
echo -e "end ${dataset_name} role train"
elif [ ${process_name} == role_predict ]; then
echo -e "\nstart ${dataset_name} role predict"
run_sequence_labeling_model role False
echo -e "end ${dataset_name} role predict"
elif [ ${process_name} == pred_2_submit ]; then
echo -e "\nstart ${dataset_name} predict data merge to submit fotmat"
python duee_1_postprocess.py --trigger_file ${ckpt_dir}/trigger/test_pred.json --role_file ${ckpt_dir}/role/test_pred.json --schema_file ${conf_dir}/event_schema.json --save_path ${submit_data_path}
echo -e "end ${dataset_name} role predict data merge"
else
echo "no process name ${process_name}"
fi
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment