Commit cf1a4f65 by 20210828028

v1

parent da8ff096
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
*.sh
.venv
.vscode
data
!data/k-shot/checksum
log*
runs
result
wandb
ensemble_predict_results
auto*
my*
slurm
MIT License
Copyright (c) 2021 Princeton Natural Language Processing
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
dataclasses
filelock==3.0.12
flake8==3.8.4
future==0.18.2
idna==2.10
importlib-metadata==3.3.0
joblib==1.0.0
mccabe==0.6.1
nltk==3.5
numpy==1.19.4
packaging==20.8
pandas==1.1.5
protobuf==3.14.0
pycodestyle==2.6.0
pyflakes==2.2.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.5
regex==2020.11.13
requests==2.25.1
sacremoses==0.0.43
scikit-learn==0.24.0
scipy==1.5.4
sentence-transformers==0.4.0
sentencepiece==0.1.94
six==1.15.0
threadpoolctl==2.1.0
tokenizers==0.9.2
torch==1.6.0
tqdm==4.48.2
transformers==3.4.0
typing-extensions==3.7.4.3
urllib3>=1.26.4
zipp==3.4.0
"""Automatic label search helpers."""
import itertools
import torch
import tqdm
import multiprocessing
import numpy as np
import scipy.spatial as spatial
import scipy.special as special
import scipy.stats as stats
import logging
logger = logging.getLogger(__name__)
def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False):
"""Pre-select likely words based on conditional likelihood."""
indices = []
if is_regression:
median = np.median(train_labels)
train_labels = (train_labels > median).astype(np.int)
num_labels = np.max(train_labels) + 1
for idx in range(num_labels):
label_logits = train_logits[train_labels == idx]
scores = label_logits.mean(axis=0)
kept = []
for i in np.argsort(-scores):
text = vocab[i]
if not text.startswith("Ġ"):
continue
kept.append(i)
indices.append(kept[:k_likely])
return indices
def select_neighbors(distances, k_neighbors, valid):
"""Select k nearest neighbors based on distance (filtered to be within the 'valid' set)."""
indices = np.argsort(distances)
neighbors = []
for i in indices:
if i not in valid:
continue
neighbors.append(i)
if k_neighbors > 0:
return neighbors[:k_neighbors]
return neighbors
def init(train_logits, train_labels):
global logits, labels
logits = train_logits
labels = train_labels
def eval_pairing_acc(pairing):
global logits, labels
label_logits = np.take(logits, pairing, axis=-1)
preds = np.argmax(label_logits, axis=-1)
correct = np.sum(preds == labels)
return correct / len(labels)
def eval_pairing_corr(pairing):
global logits, labels
if pairing[0] == pairing[1]:
return -1
label_logits = np.take(logits, pairing, axis=-1)
label_probs = special.softmax(label_logits, axis=-1)[:, 1]
pearson_corr = stats.pearsonr(label_probs, labels)[0]
return pearson_corr
def find_labels(
model,
train_logits,
train_labels,
seed_labels=None,
k_likely=1000,
k_neighbors=None,
top_n=-1,
vocab=None,
is_regression=False,
):
# Get top indices based on conditional likelihood using the LM.
likely_indices = select_likely_words(
train_logits=train_logits,
train_labels=train_labels,
k_likely=k_likely,
vocab=vocab,
is_regression=is_regression)
logger.info("Top labels (conditional) per class:")
for i, inds in enumerate(likely_indices):
logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]]))
# Convert to sets.
valid_indices = [set(inds) for inds in likely_indices]
# If specified, further re-rank according to nearest neighbors of seed labels.
# Otherwise, keep ranking as is (based on conditional likelihood only).
if seed_labels:
assert vocab is not None
seed_ids = [vocab.index(l) for l in seed_labels]
vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy()
seed_vecs = np.take(vocab_vecs, seed_ids, axis=0)
# [num_labels, vocab_size]
label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine")
# Establish label candidates (as k nearest neighbors).
label_candidates = []
logger.info("Re-ranked by nearest neighbors:")
for i, distances in enumerate(label_distances):
label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i]))
logger.info("\t| Label: %s", seed_labels[i])
logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]]))
else:
label_candidates = likely_indices
# Brute-force search all valid pairings.
pairings = list(itertools.product(*label_candidates))
if is_regression:
eval_pairing = eval_pairing_corr
metric = "corr"
else:
eval_pairing = eval_pairing_acc
metric = "acc"
# Score each pairing.
pairing_scores = []
with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers:
with tqdm.tqdm(total=len(pairings)) as pbar:
chunksize = max(10, int(len(pairings) / 1000))
for score in workers.imap(eval_pairing, pairings, chunksize=chunksize):
pairing_scores.append(score)
pbar.update()
# Take top-n.
best_idx = np.argsort(-np.array(pairing_scores))[:top_n]
best_scores = [pairing_scores[i] for i in best_idx]
best_pairings = [pairings[i] for i in best_idx]
logger.info("Automatically searched pairings:")
for i, indices in enumerate(best_pairings):
logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i])
return best_pairings
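A minimal, self-contained sketch of driving find_labels above with synthetic data. The toy vocabulary, logits, and labels are invented; model may be None here only because no seed_labels are given, so the decoder weights are never read. It assumes the file above is importable as src.label_search.

import numpy as np
from src.label_search import find_labels  # assumed import path for the file above

if __name__ == "__main__":
    rng = np.random.RandomState(0)
    vocab = ["Ġgood", "Ġbad", "Ġgreat", "Ġterrible", "the"]  # toy RoBERTa-style vocab ("Ġ" marks a leading space)
    train_labels = rng.randint(0, 2, size=64)                # a fake binary task
    train_logits = rng.randn(64, len(vocab))                 # fake [num_examples, vocab_size] <mask> logits
    train_logits[np.arange(64), train_labels] += 3.0         # make "Ġgood"/"Ġbad" informative for class 0/1

    pairings = find_labels(
        model=None,            # only read when seed_labels is given
        train_logits=train_logits,
        train_labels=train_labels,
        seed_labels=None,
        k_likely=2,            # keep the top-2 likely words per class
        top_n=5,
        vocab=vocab,
    )
    print(pairings)            # candidate (class-0 token id, class-1 token id) pairs, best first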
"""Custom models for few-shot learning specific operations."""
import torch
import torch.nn as nn
import transformers
from transformers.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead
from transformers.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead
from transformers.modeling_outputs import SequenceClassifierOutput
import logging
logger = logging.getLogger(__name__)
def resize_token_type_embeddings(model, new_num_types: int, random_segment: bool):
"""
Resize the segment (token type) embeddings for BERT
"""
if hasattr(model, 'bert'):
old_token_type_embeddings = model.bert.embeddings.token_type_embeddings
else:
raise NotImplementedError
new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1))
if not random_segment:
new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data
model.config.type_vocab_size = new_num_types
if hasattr(model, 'bert'):
model.bert.embeddings.token_type_embeddings = new_token_type_embeddings
else:
raise NotImplementedError
class BertForPromptFinetuning(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = None
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if self.return_full_softmax:
if labels is not None:
return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
return ((loss,) + output) if loss is not None else output
class RobertaForPromptFinetuning(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
self.lm_head = RobertaLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = None
# For regression
self.lb = None
self.ub = None
# For auto label search.
self.return_full_softmax = None
def forward(
self,
input_ids=None,
attention_mask=None,
mask_pos=None,
labels=None,
):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
outputs = self.roberta(
input_ids,
attention_mask=attention_mask
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.lm_head(sequence_mask_output)
# Exit early and only return mask logits.
if self.return_full_softmax:
if labels is not None:
return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
return ((loss,) + output) if loss is not None else output
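A small stand-alone illustration (not part of the file above) of the label-word indexing that both forward() methods perform: the MLM scores at the <mask> position are gathered at the label-word token ids and treated as classification logits. The sizes and token ids below are hypothetical.

import torch

batch_size, vocab_size = 4, 50265                             # illustrative sizes
prediction_mask_scores = torch.randn(batch_size, vocab_size)  # stands in for the lm_head output at <mask>
label_word_list = torch.tensor([1313, 2362])                  # hypothetical token ids of the two label words

# Equivalent to the loop-and-cat in forward(): gather the label-word scores
# into [batch_size, num_labels] logits and predict by argmax.
logits = prediction_mask_scores[:, label_word_list]
preds = logits.argmax(dim=-1)
print(logits.shape, preds)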
import argparse
import json
import numpy as np
import torch
from torch import device
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
# These options should be kept as their default values
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_best = {}
for item in result_list:
ok = True
for cond in condition:
if isinstance(condition[cond], list):
if cond not in item or (item[cond] not in condition[cond]):
ok = False
break
else:
if cond not in item or (item[cond] != condition[cond]):
ok = False
break
if ok:
seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed'])
if seed not in seed_result:
seed_result[seed] = [item]
seed_best[seed] = item
else:
seed_result[seed].append(item)
if item[args.key] > seed_best[seed][args.key]:
seed_best[seed] = item
final_result_dev = np.zeros((len(seed_best)))
final_result_test = np.zeros((len(seed_best)))
final_result_test2 = np.zeros((len(seed_best)))
for i, seed in enumerate(seed_best):
final_result_dev[i] = seed_best[seed][args.key]
final_result_test[i] = seed_best[seed][args.test_key]
if len(args.test_key2) > 0:
final_result_test2[i] = seed_best[seed][args.test_key2]
print("%s: best dev (%.5f) test (%.5f) %s | total trials: %d" % (
seed,
seed_best[seed][args.key],
seed_best[seed][args.test_key],
"test2 (%.5f)" % (seed_best[seed][args.test_key2]) if len(args.test_key2) > 0 else "",
len(seed_result[seed])
))
s = ''
for k in ['per_device_train_batch_size', 'gradient_accumulation_steps', 'learning_rate', 'eval_steps', 'max_steps']:
s += '| {}: {} '.format(k, seed_best[seed][k])
print(' ' + s)
s = "mean +- std: %.1f (%.1f) (median %.1f)" % (final_result_test.mean() * 100, final_result_test.std() * 100, np.median(final_result_test) * 100)
if len(args.test_key2) > 0:
s += "second metric: %.1f (%.1f) (median %.1f)" % (final_result_test2.mean() * 100, final_result_test2.std() * 100, np.median(final_result_test2) * 100)
print(s)
if __name__ == '__main__':
main()
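For reference, a toy sketch (all values invented) of the log format this script expects: each line of the --log file is a Python dict literal holding the condition keys plus the metric keys, seed, data_dir, and the hyperparameters printed above.

example_line = {
    "task_name": "sst-2",
    "tag": "exp",
    "seed": 42,
    "data_dir": "data/k-shot/SST-2/16-42",
    "sst-2_dev_eval_acc": 0.91,
    "sst-2_test_eval_acc": 0.89,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "learning_rate": 1e-05,
    "eval_steps": 100,
    "max_steps": 1000,
}
print(repr(example_line))  # one such repr per line is what eval(line) above parses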
"""This script samples K examples randomly without replacement from the original data."""
import argparse
import os
import numpy as np
import pandas as pd
from pandas import DataFrame
def get_label(task, line):
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
line = line.strip().split('\t')
if task == 'CoLA':
return line[1]
elif task == 'MNLI':
return line[-1]
elif task == 'MRPC':
return line[0]
elif task == 'QNLI':
return line[-1]
elif task == 'QQP':
return line[-1]
elif task == 'RTE':
return line[-1]
elif task == 'SNLI':
return line[-1]
elif task == 'SST-2':
return line[-1]
elif task == 'STS-B':
return 0 if float(line[-1]) < 2.5 else 1
elif task == 'WNLI':
return line[-1]
else:
raise NotImplementedError
else:
return line[0]
def load_datasets(data_dir, tasks):
datasets = {}
for task in tasks:
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style (tsv)
dataset = {}
dirname = os.path.join(data_dir, task)
if task == "MNLI":
splits = ["train", "dev_matched", "dev_mismatched"]
else:
splits = ["train", "dev"]
for split in splits:
filename = os.path.join(dirname, f"{split}.tsv")
with open(filename, "r") as f:
lines = f.readlines()
dataset[split] = lines
datasets[task] = dataset
else:
# Other datasets (csv)
dataset = {}
dirname = os.path.join(data_dir, task)
splits = ["train", "test"]
for split in splits:
filename = os.path.join(dirname, f"{split}.csv")
dataset[split] = pd.read_csv(filename, header=None)
datasets[task] = dataset
return datasets
def split_header(task, lines):
"""
Split off the header line (if any) from the content lines. Only for GLUE tasks.
"""
if task in ["CoLA"]:
return [], lines
elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
return lines[0:1], lines[1:]
else:
raise ValueError("Unknown GLUE task.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--k", type=int, default=16,
help="Training examples for each class.")
parser.add_argument("--task", type=str, nargs="+",
default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
help="Task names")
parser.add_argument("--seed", type=int, nargs="+",
default=[100, 13, 21, 42, 87],
help="Random seeds")
parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data")
parser.add_argument("--output_dir", type=str, default="data", help="Output path")
parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)")
args = parser.parse_args()
args.output_dir = os.path.join(args.output_dir, args.mode)
k = args.k
print("K =", k)
datasets = load_datasets(args.data_dir, args.task)
for seed in args.seed:
print("Seed = %d" % (seed))
for task, dataset in datasets.items():
# Set random seed
np.random.seed(seed)
# Shuffle the training set
print("| Task = %s" % (task))
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
train_header, train_lines = split_header(task, dataset["train"])
np.random.shuffle(train_lines)
else:
# Other datasets
train_lines = dataset['train'].values.tolist()
np.random.shuffle(train_lines)
# Set up dir
task_dir = os.path.join(args.output_dir, task)
setting_dir = os.path.join(task_dir, f"{k}-{seed}")
os.makedirs(setting_dir, exist_ok=True)
# Write test splits
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
# Use the original development set as the test set (the original test sets are not publicly available)
for split, lines in dataset.items():
if split.startswith("train"):
continue
split = split.replace('dev', 'test')
with open(os.path.join(setting_dir, f"{split}.tsv"), "w") as f:
for line in lines:
f.write(line)
else:
# Other datasets
# Use the original test sets
dataset['test'].to_csv(os.path.join(setting_dir, 'test.csv'), header=False, index=False)
# Get label list for balanced sampling
label_list = {}
for line in train_lines:
label = get_label(task, line)
if label not in label_list:
label_list[label] = [line]
else:
label_list[label].append(line)
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
with open(os.path.join(setting_dir, "train.tsv"), "w") as f:
for line in train_header:
f.write(line)
for label in label_list:
for line in label_list[label][:k]:
f.write(line)
name = "dev.tsv"
if task == 'MNLI':
name = "dev_matched.tsv"
with open(os.path.join(setting_dir, name), "w") as f:
for line in train_header:
f.write(line)
for label in label_list:
dev_rate = 11 if '10x' in args.mode else 2
for line in label_list[label][k:k*dev_rate]:
f.write(line)
else:
new_train = []
for label in label_list:
for line in label_list[label][:k]:
new_train.append(line)
new_train = DataFrame(new_train)
new_train.to_csv(os.path.join(setting_dir, 'train.csv'), header=False, index=False)
new_dev = []
for label in label_list:
dev_rate = 11 if '10x' in args.mode else 2
for line in label_list[label][k:k*dev_rate]:
new_dev.append(line)
new_dev = DataFrame(new_dev)
new_dev.to_csv(os.path.join(setting_dir, 'dev.csv'), header=False, index=False)
if __name__ == "__main__":
main()
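A toy illustration (separate from the script above) of its per-label slicing rule: after shuffling, the first k lines of each label go to train and the next k (or 10k in k-shot-10x mode) go to dev. The label names and lines below are invented.

k = 2
label_list = {"0": ["a0", "a1", "a2", "a3", "a4"], "1": ["b0", "b1", "b2", "b3", "b4"]}

for mode in ("k-shot", "k-shot-10x"):
    dev_rate = 11 if "10x" in mode else 2
    train = [line for lines in label_list.values() for line in lines[:k]]
    dev = [line for lines in label_list.values() for line in lines[k:k * dev_rate]]
    print(mode, "train:", train, "dev:", dev)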
"""Finetuning the library models for sequence classification on GLUE."""
import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import logging
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, TrainingArguments, set_seed
from src.label_search import find_labels
from src.dataset import FewShotDataset
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.trainer import Trainer
from src.processors import output_modes_mapping, num_labels_mapping
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
random_segment: bool = field(
default=False, metadata={"help": "Whether to randomly re-initialize the segment (token type) embeddings (BERT only); read below when resizing them."}
)
@dataclass
class DynamicDataTrainingArguments(DataTrainingArguments):
"""
Arguments for dynamic training.
"""
# For prompting
template: str = field(
default=None,
metadata={"help": "Template"}
)
mapping: str = field(
default=None,
metadata={"help": "Label word mapping"}
)
debug_mode: bool = field(
default=False,
metadata={"help": "Debug mode"}
)
first_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
)
other_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of sentences other than the first sentence"}
)
use_full_length: bool = field(
default=None,
metadata={"help": "Use the full length (512)"}
)
truncate_head: bool = field(
default=False,
metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
)
use_space_word: bool = field(
default=True,
metadata={"help": "Use space words (e.g., Gpositive) instead of original words."}
)
use_seed_labels: bool = field(
default=False,
metadata={"help": "Regularize using seed labels"},
)
k_likely: int = field(
default=100,
metadata={"help": "Take the top-k most (conditionally) likely labels per class."}
)
k_neighbors: int = field(
default=50,
metadata={"help": "Re-rank by nearest neighbor, and take the top k."}
)
n_pairs: int = field(
default=32,
metadata={"help": "Number of label pairings to use."}
)
output_file: str = field(
default="out",
metadata={"help": "Output file"}
)
append_output_file: bool = field(
default=False,
)
write_template: bool = field(
default=False,
)
def main():
parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Fix prompt to be true.
data_args.prompt = True
data_args.num_sample = 1
data_args.template_list = None
data_args.gpt3_in_context_head = False
data_args.gpt3_in_context_tail = False
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
# Check save path
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Set seed
set_seed(training_args.seed)
try:
num_labels = num_labels_mapping[data_args.task_name]
output_mode = output_modes_mapping[data_args.task_name]
logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
# Create config
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
if config.model_type == 'roberta':
model_fn = RobertaForPromptFinetuning
elif config.model_type == 'bert':
model_fn = BertForPromptFinetuning
else:
raise NotImplementedError
special_tokens = []
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
additional_special_tokens=special_tokens,
cache_dir=model_args.cache_dir,
)
set_seed(training_args.seed)
model = model_fn.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# For BERT, increase the size of the segment (token type) embeddings
if config.model_type == 'bert':
model.resize_token_embeddings(len(tokenizer))
resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
# Pass dataset and argument information to the model
model.model_args = model_args
model.data_args = data_args
model.tokenizer = tokenizer
model.return_full_softmax = True
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=None,
)
# First we compute zero-shot logits on all of the examples.
dataset = FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=False)
# Predict logits.
dataloader = trainer.get_eval_dataloader(dataset)
output = trainer.prediction_loop(dataloader, description="Evaluation")
logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions
labels = output.label_ids
# Assign words to labels.
if data_args.use_seed_labels:
if data_args.use_space_word:
seed_labels = {k: "Ġ" + v for k, v in eval(data_args.mapping).items()}
else:
seed_labels = eval(data_args.mapping)  # same mapping field as above, without the "Ġ" prefix
seed_labels = [seed_labels[label] for label in dataset.get_labels()]
else:
seed_labels = None
vocab = list(tokenizer.get_vocab())
# Find best labels.
label_pairings = find_labels(
model=trainer.model,
train_logits=logits,
train_labels=labels,
seed_labels=seed_labels,
k_likely=data_args.k_likely,
k_neighbors=data_args.k_neighbors,
top_n=data_args.n_pairs,
vocab=vocab,
is_regression=config.num_labels == 1)
labels = dataset.get_labels()
if config.num_labels == 1:
labels = ['0', '1']
os.makedirs(os.path.dirname(data_args.output_file) or ".", exist_ok=True)
if data_args.append_output_file:
mode = "a"
else:
mode = "w"
# Write to output.
with open(data_args.output_file, mode) as f:
for pairing in label_pairings:
words = [vocab[i][len("Ġ"):] for i in pairing]
mapping = {labels[i]: words[i] for i in range(len(labels))}
if data_args.write_template:
f.write(data_args.template + "\t")
f.write(json.dumps(mapping) + "\n")
if __name__ == "__main__":
main()
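A minimal sketch of reading the mapping file the script above writes (here the default --output_file name "out"): each line is a JSON label-word mapping, preceded by the template and a tab when --write_template is set.

import json

with open("out") as f:                      # default --output_file of the script above
    for line in f:
        line = line.rstrip("\n")
        if "\t" in line:                    # written with --write_template
            template, mapping = line.split("\t", 1)
        else:
            template, mapping = None, line
        print(template, json.loads(mapping))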
from sentence_transformers import SentenceTransformer, util
import argparse
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
def get_sentence(task, line):
if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
# Text classification tasks
if line[1] is None or pd.isna(line[1]):
return ''
else:
return line[1]
else:
# GLUE tasks
line = line.strip().split('\t')
if task == 'CoLA':
return line[-1]
elif task == 'MNLI':
return line[8] + ' ' + line[9]
elif task == 'MRPC':
return line[-2] + ' ' + line[-1]
elif task == 'QNLI':
return line[1] + ' ' + line[2]
elif task == 'QQP':
return line[3] + ' ' + line[4]
elif task == 'RTE':
return line[1] + ' ' + line[2]
elif task == 'SNLI':
return line[7] + ' ' + line[8]
elif task == 'SST-2':
return line[0]
elif task == 'STS-B':
return line[-3] + ' ' + line[-2]
elif task == 'WNLI':
return line[1] + ' ' + line[2]
else:
raise NotImplementedError
def split_header(task, lines):
"""Returns if the task file has a header or not."""
if task in ["CoLA"]:
return [], lines
elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
return lines[0:1], lines[1:]
else:
raise ValueError("Unknown GLUE task.")
def load_datasets(data_dir, task, do_test=False):
dataset = {}
if task == "MNLI":
splits = ["train", "dev_matched"]
if do_test:
splits += ['test_matched', 'test_mismatched']
else:
splits = ["train", "dev"]
if do_test:
splits.append('test')
for split in splits:
if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
filename = os.path.join(data_dir, f"{split}.csv")
dataset[split] = pd.read_csv(filename, header=None).values.tolist()
else:
filename = os.path.join(data_dir, f"{split}.tsv")
with open(filename, "r") as f:
lines = f.readlines()
header, content = split_header(task, lines)
dataset[split] = content
return dataset
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--do_test", action='store_true', help="Generate embeddings for test splits (test set is usually large, so we don't want to repeatedly generate embeddings for them)")
parser.add_argument("--sbert_model", type=str, default='roberta-large', help="Sentence BERT model name")
parser.add_argument("--k", type=int, help="Number of training instances per label", default=16)
parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Path to few-shot data")
parser.add_argument("--seed", type=int, nargs="+", default=[42, 13, 21, 87, 100], help="Seeds for data splits")
parser.add_argument("--task", type=str, nargs="+", default=["SST-2", "sst-5", "mr", "cr", "mpqa", "subj", "trec", "CoLA", "MRPC", "QQP", "STS-B", "MNLI", "SNLI", "QNLI", "RTE"], help="Tasks")
args = parser.parse_args()
model = SentenceTransformer('{}-nli-stsb-mean-tokens'.format(args.sbert_model))
model = model.cuda()
for task in args.task:
for seed in args.seed:
folder = os.path.join(args.data_dir, task, '{}-{}'.format(args.k, seed))
dataset = load_datasets(folder, task, do_test=args.do_test)
for split in dataset:
print('{}-{}-{}-{}'.format(task, args.k, seed, split))
lines = dataset[split]
embeddings = []
for line_id, line in tqdm(enumerate(lines)):
sent = get_sentence(task, line)
if line_id == 0:
print('|', sent)
emb = model.encode(sent)
embeddings.append(emb)
embeddings = np.stack(embeddings)
np.save(os.path.join(folder, "{}_sbert-{}.npy".format(split, args.sbert_model)), embeddings)
if __name__ == '__main__':
main()
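A short sketch of loading the embeddings saved above; the path follows the script's naming pattern and is only an example.

import numpy as np

emb = np.load("data/k-shot/SST-2/16-42/train_sbert-roberta-large.npy")  # {split}_sbert-{model}.npy inside the k-seed folder
print(emb.shape)  # (num_examples, embedding_dim)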
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--mapping_dir', type=str, help='Mapping directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_mapping_id = {} # avoid duplication
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_mapping_id[seed] = {item['mapping_id']: 1}
else:
if item['mapping_id'] not in seed_result_mapping_id[seed]:
seed_result[seed].append(item)
seed_result_mapping_id[seed][item['mapping_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all mappings
with open(os.path.join(args.mapping_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
mappings = []
for line in f:
mappings.append(line.strip())
# Write sorted mappings
fsort = open(os.path.join(args.mapping_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.mapping_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(mappings[item['mapping_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], mappings[item['mapping_id']]))
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--prompt_dir', type=str, help='Prompt directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_prompt_id = {} # avoid duplication
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_prompt_id[seed] = {item['prompt_id']: 1}
else:
if item['prompt_id'] not in seed_result_prompt_id[seed]:
seed_result[seed].append(item)
seed_result_prompt_id[seed][item['prompt_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all prompts
with open(os.path.join(args.prompt_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
prompts = []
for line in f:
prompts.append(line.strip())
# Write sorted prompts
fsort = open(os.path.join(args.prompt_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.prompt_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(prompts[item['prompt_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], prompts[item['prompt_id']]))
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--template_dir', type=str, help='Template directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_template_id = {} # avoid duplication
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_template_id[seed] = {item['template_id']: 1}
else:
if item['template_id'] not in seed_result_template_id[seed]:
seed_result[seed].append(item)
seed_result_template_id[seed][item['template_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all templates
with open(os.path.join(args.template_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
templates = []
for line in f:
templates.append(line.strip())
# Write sorted templates
fsort = open(os.path.join(args.template_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.template_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(templates[item['template_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], templates[item['template_id']]))
if __name__ == '__main__':
main()