# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""duee finance data predict post-process"""

import os
import sys
import json
import argparse

from utils import read_by_lines, write_by_lines, extract_result

# Event type that receives one extra argument taken from the enum model's
# predicted label (see predict_data_process).
enum_event_type = "公司上市"
# Role name under which that extra argument is stored.
enum_role = "环节"


def event_normalization(doc):
    """Remove repeated arguments and repeated events from one document.

    The document is modified in place and also returned.

    Two passes are made over ``doc["event_list"]`` (treated as empty when the
    key is absent):

    1. Inside every event, arguments that share the same ``role`` and
       ``argument`` values are reduced to their first occurrence.
    2. Events are visited from the largest argument list to the smallest
       (ties keep their original relative order). An event is discarded when
       an already retained event of the same ``event_type`` has an argument
       set that is contained in the current event's argument set.

    NOTE(review): since retained events never have fewer arguments than the
    event being examined, the containment test in pass 2 can only succeed
    when both argument sets are equal, i.e. only exact duplicates are
    dropped. Confirm whether events that are strict subsets of a retained
    event were meant to be dropped as well.

    Args:
        doc: Mapping that may carry an ``event_list``; each event provides
            ``event_type`` and ``arguments`` (a list of mappings with
            ``role`` and ``argument``).

    Returns:
        The same ``doc`` object with ``doc["event_list"]`` replaced by the
        list of retained events.
    """

    def _arg_key(arg):
        # Text key identifying an argument by its role and value.
        return "{}-{}".format(arg["role"], arg["argument"])

    def _arg_keys(event):
        return {_arg_key(arg) for arg in event["arguments"]}

    # Pass 1: keep only the first occurrence of each (role, argument) pair.
    for event in doc.get("event_list", []):
        seen_keys = set()
        unique_args = []
        for arg in event["arguments"]:
            key = _arg_key(arg)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            unique_args.append(arg)
        event["arguments"] = unique_args

    # Pass 2: largest events first; list.sort is stable, so ties keep the
    # order they had in the document.
    ranked = list(doc.get("event_list", []))
    ranked.sort(key=lambda ev: len(ev["arguments"]), reverse=True)

    kept = []
    for candidate in ranked:
        candidate_type = candidate["event_type"]
        candidate_keys = _arg_keys(candidate)
        redundant = any(
            survivor["event_type"] == candidate_type and
            _arg_keys(survivor).issubset(candidate_keys)
            for survivor in kept)
        if not redundant:
            kept.append(candidate)

    doc["event_list"] = kept
    return doc


def predict_data_process(trigger_file, role_file, enum_file, schema_file,
                         save_path):
    """Merge trigger, role and enum predictions into per-document events.

    Every input file is read with ``read_by_lines`` and each line is parsed
    as one JSON object. Sentence-level records are matched across the three
    prediction files by the pair (``id``, ``sent_id``). The events built for
    each sentence are grouped by ``id``, passed through
    ``event_normalization`` and written to ``save_path`` with
    ``write_by_lines``, one JSON document per line.

    Args:
        trigger_file: Path of the trigger model predictions. Each record
            provides ``id``, ``sent_id``, ``text`` and ``pred["labels"]``.
        role_file: Path of the role model predictions; same record fields as
            ``trigger_file``.
        enum_file: Path of the enum model predictions. Each record provides
            ``id``, ``sent_id`` and ``pred["label"]``.
        schema_file: Path of the schema. Each record provides ``event_type``
            and ``role_list`` (a list of objects with a ``role`` field).
        save_path: Destination file for the serialized documents.

    Raises:
        KeyError: If a record lacks one of the fields above, if a predicted
            event type is absent from the schema, if a trigger sentence has
            no matching role record, or if an event of type
            ``enum_event_type`` has no matching enum record.
    """
    trigger_data = read_by_lines(trigger_file)
    role_data = read_by_lines(role_file)
    enum_data = read_by_lines(enum_file)
    schema_data = read_by_lines(schema_file)
    print("trigger predict {} load from {}".format(
        len(trigger_data), trigger_file))
    print("role predict {} load from {}".format(len(role_data), role_file))
    print("enum predict {} load from {}".format(len(enum_data), enum_file))
    print("schema {} load from {}".format(len(schema_data), schema_file))

    def _sent_key(record):
        # Sentence identifier shared by the three prediction files.
        return "{}\t{}".format(record["id"], record["sent_id"])

    # event type -> list of role names the schema allows for that type
    schema = {}
    for line in schema_data:
        entry = json.loads(line)
        schema[entry["event_type"]] = [r["role"] for r in entry["role_list"]]

    # sentence key -> {role name: [argument text, ...]}
    sent_role_mapping = {}
    for line in role_data:
        record = json.loads(line)
        spans = extract_result(record["text"], record["pred"]["labels"])
        roles = {}
        for span in spans:
            roles.setdefault(span["type"], []).append("".join(span["text"]))
        sent_role_mapping[_sent_key(record)] = roles

    # sentence key -> label predicted by the enum model
    sent_enum_mapping = {}
    for line in enum_data:
        record = json.loads(line)
        sent_enum_mapping[_sent_key(record)] = record["pred"]["label"]

    # Build the events of every trigger sentence and collect them per
    # document id, in the order the sentences appear in the trigger file.
    doc_pred = {}
    for line in trigger_data:
        record = json.loads(line)
        triggers = extract_result(record["text"], record["pred"]["labels"])
        # Order-preserving de-duplication. Iterating a set of strings yields
        # an order that can change between interpreter runs (hash
        # randomization), which made the event order in the saved file
        # non-reproducible.
        pred_event_types = list(dict.fromkeys(t["type"] for t in triggers))
        key = _sent_key(record)
        event_list = []
        for event_type in pred_event_types:
            role_list = schema[event_type]
            # Keep only the sentence's arguments whose role belongs to this
            # event type according to the schema.
            arguments = [{
                "role": role_type,
                "argument": arg
            } for role_type, args in sent_role_mapping[key].items()
                         if role_type in role_list for arg in args]
            # Special case: for this event type one more argument is added,
            # whose value is the enum model's label for the sentence.
            if event_type == enum_event_type:
                arguments.append({
                    "role": enum_role,
                    "argument": sent_enum_mapping[key]
                })
            event_list.append({
                "event_type": event_type,
                "arguments": arguments,
                "text": record["text"]
            })
        doc = doc_pred.setdefault(record["id"], {
            "id": record["id"],
            "event_list": []
        })
        doc["event_list"].extend(event_list)

    # Unify all prediction results of each document and save them.
    output_lines = [
        json.dumps(
            event_normalization(doc), ensure_ascii=False)
        for doc in doc_pred.values()
    ]
    print("submit data {} save to {}".format(len(output_lines), save_path))
    write_by_lines(save_path, output_lines)


if __name__ == "__main__":
    # Command-line entry point: all five paths are mandatory and are passed
    # unchanged to predict_data_process.
    parser = argparse.ArgumentParser(
        description="Post-process DuEE finance trigger/role/enum model "
        "predictions into a submission file")
    parser.add_argument(
        "--trigger_file", help="trigger model predict data path", required=True)
    parser.add_argument(
        "--role_file", help="role model predict data path", required=True)
    parser.add_argument(
        "--enum_file", help="enum model predict data path", required=True)
    parser.add_argument("--schema_file", help="schema file path", required=True)
    parser.add_argument("--save_path", help="save file path", required=True)
    args = parser.parse_args()
    predict_data_process(args.trigger_file, args.role_file, args.enum_file,
                         args.schema_file, args.save_path)