"""Rank the automatically generated templates for each seed by their dev-set metric
(read from the experiment log) and write <k>-<seed>.sort.txt / <k>-<seed>.score.txt
next to the original template files."""

import argparse
import json
import numpy as np
import torch
from torch import device
import os


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", type=str,
                        help="A dictionary containing conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
    parser.add_argument('--template_dir', type=str, help='Template directory')

    # These options should be kept as their default values
    parser.add_argument("--k", type=int, default=16)
    parser.add_argument("--log", type=str, default="log", help="Log path.")
    parser.add_argument("--key", type=str, default='', help="Validation metric name")
    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")

    args = parser.parse_args()

    # The condition is passed as a Python dict literal and evaluated directly
    condition = eval(args.condition)

    # Map the task name to its dev/test metric keys and template-directory name
    if len(args.key) == 0:
        if condition['task_name'] == 'cola':
            args.key = 'cola_dev_eval_mcc'
            args.test_key = 'cola_test_eval_mcc'
            print_name = 'CoLA'
        elif condition['task_name'] == 'mrpc/acc':
            args.key = 'mrpc_dev_eval_acc'
            args.test_key = 'mrpc_test_eval_acc'
            args.test_key2 = 'mrpc_test_eval_f1'
            condition['task_name'] = 'mrpc'
            print_name = 'MRPC'
        elif condition['task_name'] == 'mrpc/f1':
            args.key = 'mrpc_dev_eval_f1'
            args.test_key2 = 'mrpc_test_eval_acc'
            args.test_key = 'mrpc_test_eval_f1'
            condition['task_name'] = 'mrpc'
            print_name = 'MRPC'
        elif condition['task_name'] == 'qqp/acc':
            args.key = 'qqp_dev_eval_acc'
            args.test_key = 'qqp_test_eval_acc'
            args.test_key2 = 'qqp_test_eval_f1'
            condition['task_name'] = 'qqp'
            print_name = 'QQP'
        elif condition['task_name'] == 'qqp/f1':
            args.key = 'qqp_dev_eval_f1'
            args.test_key2 = 'qqp_test_eval_acc'
            args.test_key = 'qqp_test_eval_f1'
            condition['task_name'] = 'qqp'
            print_name = 'QQP'
        elif condition['task_name'] == 'sts-b/pearson':
            args.key = 'sts-b_dev_eval_pearson'
            args.test_key = 'sts-b_test_eval_pearson'
            args.test_key2 = 'sts-b_test_eval_spearmanr'
            condition['task_name'] = 'sts-b'
            print_name = 'STS-B'
        elif condition['task_name'] == 'sts-b/spearmanr':
            args.key = 'sts-b_dev_eval_spearmanr'
            args.test_key2 = 'sts-b_test_eval_pearson'
            args.test_key = 'sts-b_test_eval_spearmanr'
            condition['task_name'] = 'sts-b'
            print_name = 'STS-B'
        elif condition['task_name'] == 'qnli':
            args.key = 'qnli_dev_eval_acc'
            args.test_key = 'qnli_test_eval_acc'
            print_name = 'QNLI'
        elif condition['task_name'] == 'sst-2':
            args.key = 'sst-2_dev_eval_acc'
            args.test_key = 'sst-2_test_eval_acc'
            print_name = 'SST-2'
        elif condition['task_name'] == 'snli':
            args.key = 'snli_dev_eval_acc'
            args.test_key = 'snli_test_eval_acc'
            print_name = 'SNLI'
        elif condition['task_name'] == 'mnli':
            args.key = 'mnli_dev_eval_mnli/acc'
            args.test_key = 'mnli_test_eval_mnli/acc'
            print_name = 'MNLI'
        elif condition['task_name'] == 'mnli-mm':
            condition['task_name'] = 'mnli'
            args.key = 'mnli_dev_eval_mnli/acc'
            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
            print_name = 'MNLI'
        elif condition['task_name'] == 'rte':
            args.key = 'rte_dev_eval_acc'
            args.test_key = 'rte_test_eval_acc'
            print_name = 'RTE'
        elif condition['task_name'] == 'ag_news':
            args.key = 'ag_news_dev_eval_acc'
            args.test_key = 'ag_news_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'yahoo_answers':
            args.key = 'yahoo_answers_dev_eval_acc'
            args.test_key = 'yahoo_answers_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'yelp_review_full':
            args.key = 'yelp_review_full_dev_eval_acc'
            args.test_key = 'yelp_review_full_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'mr':
            args.key = 'mr_dev_eval_acc'
            args.test_key = 'mr_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'sst-5':
            args.key = 'sst-5_dev_eval_acc'
            args.test_key = 'sst-5_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'subj':
            args.key = 'subj_dev_eval_acc'
            args.test_key = 'subj_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'trec':
            args.key = 'trec_dev_eval_acc'
            args.test_key = 'trec_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'cr':
            args.key = 'cr_dev_eval_acc'
            args.test_key = 'cr_test_eval_acc'
            print_name = condition['task_name']
        elif condition['task_name'] == 'mpqa':
            args.key = 'mpqa_dev_eval_acc'
            args.test_key = 'mpqa_test_eval_acc'
            print_name = condition['task_name']
        else:
            raise NotImplementedError

    # Each log line is a Python dict repr of one experiment result
    with open(args.log) as f:
        result_list = []
        for line in f:
            result_list.append(eval(line))

    seed_result = {}
    seed_result_template_id = {}  # avoid duplication

    # Keep only results matching the condition; deduplicate by (seed, template_id)
    for item in result_list:
        ok = True
        for cond in condition:
            if cond not in item or item[cond] != condition[cond]:
                ok = False
                break
        if ok:
            seed = item['seed']
            if seed not in seed_result:
                seed_result[seed] = [item]
                seed_result_template_id[seed] = {item['template_id']: 1}
            else:
                if item['template_id'] not in seed_result_template_id[seed]:
                    seed_result[seed].append(item)
                    seed_result_template_id[seed][item['template_id']] = 1

    for seed in seed_result:
        print("Seed %d has %d results" % (seed, len(seed_result[seed])))

        # Load all templates
        with open(os.path.join(args.template_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
            templates = []
            for line in f:
                templates.append(line.strip())

        # Write sorted templates, ranked by the dev metric in descending order
        fsort = open(os.path.join(args.template_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
        fscore = open(os.path.join(args.template_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')

        seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
        for item in seed_result[seed]:
            fsort.write(templates[item['template_id']] + '\n')
            fscore.write("%.5f %s\n" % (item[args.key], templates[item['template_id']]))

        fsort.close()
        fscore.close()


if __name__ == '__main__':
    main()
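# The invocation below is an illustrative sketch, not taken from the source: the
# script filename, tag value, and template directory are assumptions, while the
# --condition format (a Python dict literal consumed by eval) and the other flags
# follow the argument parser defined above.
#
#   python sort_template_results.py \
#       --condition "{'tag': 'exp', 'task_name': 'sst-2', 'few_shot_type': 'prompt'}" \
#       --template_dir my_auto_template \
#       --log log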