# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""duee finance data predict post-process"""

import os
import sys
import json
import argparse

from utils import read_by_lines, write_by_lines, extract_result

# Event type whose extra enum role is filled from the enum model's
# sentence-level prediction instead of the role model's output.
enum_event_type = "公司上市"
enum_role = "环节"


def _argument_key(arg):
    """Return a hashable identity for an argument dict.

    A (role, argument) tuple is used rather than a "role-argument" string so
    that values containing "-" cannot collide with one another.
    """
    return (arg["role"], arg["argument"])


def _unique_in_order(items):
    """Return the distinct items of ``items`` in first-occurrence order."""
    seen = set()
    ordered = []
    for item in items:
        if item not in seen:
            seen.add(item)
            ordered.append(item)
    return ordered


def event_normalization(doc):
    """Deduplicate the arguments and the events of one document in place.

    Args:
        doc: dict with an optional "event_list" entry; every event is a dict
            holding "event_type" and "arguments" (a list of dicts with
            "role" and "argument").

    Returns:
        The same ``doc`` object. Its "event_list" is replaced by the events
        sorted by descending argument count (ties keep their input order)
        with redundant events removed; each surviving event keeps only the
        first occurrence of every (role, argument) pair.
    """
    events = doc.get("event_list", [])

    # Step 1: drop repeated (role, argument) pairs inside each event.
    for event in events:
        seen = set()
        unique_arguments = []
        for arg in event["arguments"]:
            key = _argument_key(arg)
            if key not in seen:
                seen.add(key)
                unique_arguments.append(arg)
        event["arguments"] = unique_arguments

    # Step 2: visit events from the richest to the poorest (stable sort).
    ranked = sorted(events, key=lambda e: len(e["arguments"]), reverse=True)

    kept = []  # (event, set of its argument keys) for every retained event
    for event in ranked:
        event_type = event["event_type"]
        keys = {_argument_key(arg) for arg in event["arguments"]}
        # An event is redundant when an already kept event of the same type
        # has all of its arguments present in this one. Because kept events
        # never have fewer arguments than the current one, this only
        # happens when both argument sets are identical.
        redundant = any(
            event_type == kept_event["event_type"] and kept_keys <= keys
            for kept_event, kept_keys in kept)
        if not redundant:
            kept.append((event, keys))

    doc["event_list"] = [event for event, _ in kept]
    return doc


def predict_data_process(trigger_file, role_file, enum_file, schema_file,
                         save_path):
    """Merge trigger, role and enum predictions into document-level events.

    Each input file is read with ``read_by_lines`` and every line is parsed
    as JSON. Sentence-level predictions are joined on (id, sent_id), grouped
    by document id, normalised with ``event_normalization`` and written to
    ``save_path`` as one JSON document per line.

    Args:
        trigger_file: path of the trigger model predictions; every record
            provides "id", "sent_id", "text" and "pred"["labels"].
        role_file: path of the role model predictions, same record layout
            as ``trigger_file``.
        enum_file: path of the enum model predictions; every record provides
            "id", "sent_id" and "pred"["label"].
        schema_file: path of the schema; every record provides "event_type"
            and "role_list" (a list of dicts with "role").
        save_path: path of the output file.

    Raises:
        KeyError: if a trigger sentence has no matching role prediction, a
            predicted event type is missing from the schema, or an
            ``enum_event_type`` event has no matching enum prediction.
    """
    pred_ret = []
    trigger_data = read_by_lines(trigger_file)
    role_data = read_by_lines(role_file)
    enum_data = read_by_lines(enum_file)
    schema_data = read_by_lines(schema_file)
    print("trigger predict {} load from {}".format(
        len(trigger_data), trigger_file))
    print("role predict {} load from {}".format(len(role_data), role_file))
    print("enum predict {} load from {}".format(len(enum_data), enum_file))
    print("schema {} load from {}".format(len(schema_data), schema_file))

    # event type -> names of the roles the schema allows for it
    schema = {}
    for line in schema_data:
        d_json = json.loads(line)
        schema[d_json["event_type"]] = [
            r["role"] for r in d_json["role_list"]
        ]

    # "<id>\t<sent_id>" -> {role type: [argument text, ...]}
    sent_role_mapping = {}
    for line in role_data:
        d_json = json.loads(line)
        r_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
        role_ret = {}
        for r in r_ret:
            role_ret.setdefault(r["type"], []).append("".join(r["text"]))
        _id = "{}\t{}".format(d_json["id"], d_json["sent_id"])
        sent_role_mapping[_id] = role_ret

    # "<id>\t<sent_id>" -> label predicted by the enum model
    sent_enum_mapping = {}
    for line in enum_data:
        d_json = json.loads(line)
        _id = "{}\t{}".format(d_json["id"], d_json["sent_id"])
        sent_enum_mapping[_id] = d_json["pred"]["label"]

    # Build one event per distinct trigger type found in each sentence.
    for line in trigger_data:
        d_json = json.loads(line)
        t_ret = extract_result(d_json["text"], d_json["pred"]["labels"])
        # Deduplicate in first-occurrence order; going through a plain set
        # would make the event order (and so the saved file) depend on
        # per-process string hash randomisation.
        pred_event_types = _unique_in_order(t["type"] for t in t_ret)
        event_list = []
        _id = "{}\t{}".format(d_json["id"], d_json["sent_id"])
        for event_type in pred_event_types:
            role_list = schema[event_type]
            arguments = []
            for role_type, ags in sent_role_mapping[_id].items():
                # Keep only the roles the schema defines for this type.
                if role_type not in role_list:
                    continue
                for arg in ags:
                    arguments.append({"role": role_type, "argument": arg})
            # Special case: the enum role of the enum event type is taken
            # from the enum model's prediction for this sentence.
            if event_type == enum_event_type:
                arguments.append({
                    "role": enum_role,
                    "argument": sent_enum_mapping[_id]
                })
            event_list.append({
                "event_type": event_type,
                "arguments": arguments,
                "text": d_json["text"]
            })
        pred_ret.append({
            "id": d_json["id"],
            "sent_id": d_json["sent_id"],
            "text": d_json["text"],
            "event_list": event_list
        })

    # Gather the sentence-level events of every document under its id.
    doc_pred = {}
    for d in pred_ret:
        if d["id"] not in doc_pred:
            doc_pred[d["id"]] = {"id": d["id"], "event_list": []}
        doc_pred[d["id"]]["event_list"].extend(d["event_list"])

    # Unify all the prediction results and save them.
    doc_pred = [
        json.dumps(
            event_normalization(r), ensure_ascii=False)
        for r in doc_pred.values()
    ]
    print("submit data {} save to {}".format(len(doc_pred), save_path))
    write_by_lines(save_path, doc_pred)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Post-process DuEE-fin trigger, role and enum "
        "predictions into the submission file")
    parser.add_argument(
        "--trigger_file", help="trigger model predict data path", required=True)
    parser.add_argument(
        "--role_file", help="role model predict data path", required=True)
    parser.add_argument(
        "--enum_file", help="enum model predict data path", required=True)
    parser.add_argument("--schema_file", help="schema file path", required=True)
    parser.add_argument("--save_path", help="save file path", required=True)
    args = parser.parse_args()
    predict_data_process(args.trigger_file, args.role_file, args.enum_file,
                         args.schema_file, args.save_path)