Commit cf1a4f65 by 20210828028

v1

parent da8ff096
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
*.sh
.venv
.vscode
data
!data/k-shot/checksum
log*
runs
result
wandb
ensemble_predict_results
auto*
my*
slurm
MIT License
Copyright (c) 2021 Princeton Natural Language Processing
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# LM-BFF (**B**etter **F**ew-shot **F**ine-tuning of **L**anguage **M**odels)
This is the implementation of the paper [Making Pre-trained Language Models Better Few-shot Learners](https://arxiv.org/pdf/2012.15723.pdf). LM-BFF is short for **b**etter **f**ew-shot **f**ine-tuning of **l**anguage **m**odels.
## Quick links
* [Overview](#overview)
* [Requirements](#requirements)
* [Prepare the data](#prepare-the-data)
* [Run the model](#run-lm-bff)
* [Quick start](#quick-start)
* [Experiments with multiple runs](#experiments-with-multiple-runs)
* [Using demonstrations with filtering](#using-demonstrations-with-filtering)
* [Automatically searched prompt](#automatically-searched-prompt)
* [Ensemble](#ensemble-model)
* [Zero-shot experiments](#zero-shot-experiments)
* [How to design your own templates](#how-to-design-your-own-templates)
* [Citation](#citation)
## Overview
![](./figs/lmbff.png)
In this work we present LM-BFF, a suite of simple and complementary techniques for fine-tuning pre-trained language models on a small number of training examples. Our approach includes:
1. Prompt-based fine-tuning together with a novel pipeline for automating prompt generation.
2. A refined strategy for incorporating demonstrations into context.
You can find more details of this work in our [paper](https://arxiv.org/pdf/2012.15723.pdf).
## Requirements
To run our code, please install all the dependency packages by using the following command:
```
pip install -r requirements.txt
```
**NOTE**: Different versions of packages (like `pytorch`, `transformers`, etc.) may lead to different results from the paper. However, the trend should still hold no matter what versions of packages you use.
## Prepare the data
We pack the original datasets (SST-2, SST-5, MR, CR, MPQA, Subj, TREC, CoLA, MNLI, SNLI, QNLI, RTE, MRPC, QQP, STS-B) [here](https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar). Please download it and extract the files to `./data/original`, or run the following commands:
```bash
cd data
bash download_dataset.sh
```
Then use the following command (in the root directory) to generate the few-shot data we need:
```bash
python tools/generate_k_shot_data.py
```
See `tools/generate_k_shot_data.py` for more options. For the results in the paper, we use the default options: `K=16` and 5 different seeds (13, 21, 42, 87, 100). The few-shot data will be generated in `data/k-shot`. In the directory of each dataset, there will be folders named `$K-$SEED` indicating different dataset samples. You can use the following command to check whether the generated data are exactly the same as ours:
```bash
cd data/k-shot
md5sum -c checksum
```
**NOTE**: During training, the model will generate/load cache files in the data folder. If your data have changed, make sure to clean all the cache files (starting with "cache").
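For example, a minimal sketch for clearing the cache under the default few-shot data directory (the cache files live next to each data split and are named starting with `cached_`, with accompanying `.lock` files):
```bash
# Delete cached feature files and their lock files under data/k-shot.
# Adjust the path if your data lives elsewhere.
find data/k-shot -name "cached_*" -type f -delete
find data/k-shot -name "*.lock" -type f -delete
```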
## Run LM-BFF
### Quick start
Our code is built on [transformers](https://github.com/huggingface/transformers) and we use its `3.4.0` version. Other versions of `transformers` might cause unexpected errors.
Before running any experiments, create the result folder by `mkdir result` to save checkpoints. Then you can run our code with the following example:
```bash
python run.py \
--task_name SST-2 \
--data_dir data/k-shot/SST-2/16-42 \
--overwrite_output_dir \
--do_train \
--do_eval \
--do_predict \
--evaluate_during_training \
--model_name_or_path roberta-large \
--few_shot_type prompt-demo \
--num_k 16 \
--max_steps 1000 \
--eval_steps 100 \
--per_device_train_batch_size 2 \
--learning_rate 1e-5 \
--num_train_epochs 0 \
--output_dir result/tmp \
--seed 42 \
--template "*cls**sent_0*_It_was*mask*.*sep+*" \
--mapping "{'0':'terrible','1':'great'}" \
--num_sample 16
```
Most arguments are inherited from `transformers` and are easy to understand. We further explain some of LM-BFF's arguments:
* `few_shot_type`: There are three modes:
  * `finetune`: Standard fine-tuning.
  * `prompt`: Prompt-based fine-tuning.
  * `prompt-demo`: Prompt-based fine-tuning with demonstrations.
* `num_k`: Number of training instances for each class. We take `num_k`=16 in our paper. This argument is mainly used for indexing logs afterwards (the actual number of training examples is determined by the data split you use).
* `template`: Template for prompt-based fine-tuning. We will introduce the template format later.
* `mapping`: Label word mapping for prompt-based fine-tuning. It is a string representation of a dictionary mapping label names to label words. **NOTE**: For RoBERTa, the model will automatically add a space before each label word. See the paper appendix for details.
* `num_sample`: When using demonstrations during inference, the number of demonstration sets sampled for each input query. Say `num_sample`=16; then we sample 16 different sets of demonstrations for one input, run the forward pass separately for each set, and average the logits over all 16 samples as the final prediction.
Also, this codebase supports BERT-series and RoBERTa-series pre-trained models in Huggingface's `transformers`. You can check [Huggingface's website](https://huggingface.co/models) for available models and pass any model whose name contains "bert" or "roberta" to `--model_name_or_path`. Some examples are `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `roberta-large`, etc.
To easily run our experiments, you can also use `run_experiment.sh` (this command runs prompt-based fine-tuning with demonstrations, no filtering, manual prompt):
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh
```
We have already defined the templates and label word mappings in it, so you only need to adjust a few hyper-parameters and `TAG` (you can use whatever tag you want; it just makes finding results easier). See `run_experiment.sh` for more options for these environment variables. You can also append extra arguments as a quoted string:
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--output_dir result/exp --max_seq_length 512"
```
### Experiments with multiple runs
To carry out experiments with multiple data splits, following the evaluation protocol detailed in §3.3 of [our paper](https://arxiv.org/pdf/2012.15723.pdf) (grid search for each seed and aggregating the results over 5 different seeds), you can use the following scripts:
```bash
for seed in 13 21 42 87 100
do
for bs in 2 4 8
do
for lr in 1e-5 2e-5 5e-5
do
TAG=exp \
TYPE=prompt-demo \
TASK=SST-2 \
BS=$bs \
LR=$lr \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh
done
done
done
```
All the results will be stored in `./log`. To gather all the results, run the following command:
```bash
python tools/gather_result.py --condition "{'tag': 'exp', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
```
Then the program will find all the trials that satisfy the condition in `./log` and print the mean/std of the final results. Note that the task names are all lower-cased, and if a task has more than one metric, you need to specify the major metric (used for selecting the best validation trial) in the name (e.g., `mnli`, `mnli-mm`, `mrpc/acc`, `mrpc/f1`, `qqp/acc`, `qqp/f1`, `sts-b/pearson`, `sts-b/spearman`).
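For example, to aggregate MRPC runs using F1 as the major metric (assuming those runs were launched with `TASK=MRPC` and the same tag):
```bash
python tools/gather_result.py --condition "{'tag': 'exp', 'task_name': 'mrpc/f1', 'few_shot_type': 'prompt-demo'}"
```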
### Using demonstrations with filtering
To use the filtering mechanism when using demonstrations, we need to first generate [Sentence-BERT](https://github.com/UKPLab/sentence-transformers) embeddings. To generate embeddings for datasets in our paper, you can directly run
```
bash tools/get_sbert_embedding.sh roberta-large
```
`roberta-large` can also be replaced by `bert-base`, `bert-large`, `roberta-base` and `distilbert-base` (see [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) for details). See `tools/get_sbert_embedding.sh` and `tools/get_sbert_embedding.py` if you want to add more datasets.
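For example, to generate the embeddings with a BERT-based Sentence-BERT model instead:
```bash
bash tools/get_sbert_embedding.sh bert-base
```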
After generating the embeddings (they are saved as numpy files in the data folders), we can run the following command to do prompt-based fine-tuning with filtered demonstrations:
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--demo_filter --demo_filter_model sbert-roberta-large"
```
### Automatically searched prompt
We provide our automatic search results in `auto_template` and `auto_label_mapping`. There are three types of files:
* `SST-2/16-42.txt`: Initial search results for SST-2 dataset, K=16 and SEED=42.
* `SST-2/16-42.sort.txt`: The initial results, ranked by dev set performance after prompt-based fine-tuning.
* `SST-2/16-42.score.txt`: Same as above, but with the dev set scores included.
To use the best automatic template (`auto-T` in the paper), use the following command:
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--template_path auto_template/SST-2/16-42.sort.txt --template_id 0"
```
You can also use the _i_-th automatic result by specifying a different `template_id`.
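For example, to use the third-ranked template instead (the index here is purely illustrative):
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--template_path auto_template/SST-2/16-42.sort.txt --template_id 2"
```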
Similarly, to use the automatic label word mapping (`auto-L` in the paper), use the following command:
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--mapping_path auto_label_mapping/SST-2/16-42.sort.txt --mapping_id 0"
```
**NOTE**: Make sure to use the automatic search results that correspond to each data split seed.
**Our final results (LM-BFF) use prompt-based fine-tuning with demonstrations, filtering, and automatic templates. For example**:
```bash
for seed in 13 21 42 87 100
do
for bs in 2 4 8
do
for lr in 1e-5 2e-5 5e-5
do
TAG=LM-BFF \
TYPE=prompt-demo \
TASK=SST-2 \
BS=$bs \
LR=$lr \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id 0 --demo_filter --demo_filter_model sbert-roberta-large"
done
done
done
python tools/gather_result.py --condition "{'tag': 'LM-BFF', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
```
#### Search for automatic templates
If you want to try automatically generating templates by yourself, here are the instructions. Note that it is an extremely long process :)
To get automatic templates, we first generate template candidates by using T5:
```bash
python tools/generate_template.py \
--output_dir my_auto_template \
--task_name SST-2 \
--seed 13 21 42 87 100 \
--t5_model t5-3b \
--beam 100
```
Here `--t5_model` specifies the pre-trained T5 checkpoint to use and `--beam` specifies the beam search width. Note that the `t5-3b` model takes approximately 15GB of GPU memory; if your GPU cannot fit it, you can try smaller T5 models (e.g., `t5-base`).
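For instance, the same search with a smaller checkpoint only changes `--t5_model` (a sketch; the resulting templates may differ from ours):
```bash
python tools/generate_template.py \
--output_dir my_auto_template \
--task_name SST-2 \
--seed 13 21 42 87 100 \
--t5_model t5-base \
--beam 100
```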
Then we do prompt-based fine-tuning with all the templates:
```bash
for template_id in {0..99}
do
for seed in 13 21 42 87 100
do
# To save time, we fix these hyper-parameters
bs=8
lr=1e-5
# Since we only use dev performance here, use --no_predict to skip testing
TAG=exp-template \
TYPE=prompt \
TASK=SST-2 \
BS=$bs \
LR=$lr \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh "--template_path my_auto_template/SST-2/16-$seed.txt --template_id $template_id --no_predict"
done
done
```
... and sort them based on dev set performance:
```bash
python tools/sort_template.py --condition "{'tag': 'exp-template', 'task_name': 'sst-2'}" --template_dir my_auto_template
```
The sorted results will be saved in `my_auto_template`, with the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
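You can then plug your own sorted templates into fine-tuning exactly as with the provided ones, for example (assuming you ran the search for SST-2 with seed 42):
```bash
TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--template_path my_auto_template/SST-2/16-42.sort.txt --template_id 0"
```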
#### Search for automatic label word mappings
Similar to the process of automatic template search, we first generate candidate label word mappings by running:
```bash
bash tools/run_generate_labels.sh
```
You can modify the options in `tools/run_generate_labels.sh` to run this for different datasets or save mappings to different directories. After running the generation, the candidate label mappings will be saved in `my_auto_label_mapping/manual_template`.
Then we do prompt-based fine-tuning with all the mappings:
```bash
for mapping_id in {0..99}
do
for seed in 13 21 42 87 100
do
# To save time, we fix these hyper-parameters
bs=8
lr=1e-5
# Since we only use dev performance here, use --no_predict to skip testing
TAG=exp-mapping \
TYPE=prompt \
TASK=SST-2 \
BS=$bs \
LR=$lr \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh "--mapping_path my_auto_label_mapping/manual_template/SST-2/16-$seed.txt --mapping_id $mapping_id --no_predict"
done
done
```
... and sort them based on dev set performance:
```bash
python tools/sort_mapping.py --condition "{'tag': 'exp-mapping', 'task_name': 'sst-2'}" --mapping_dir my_auto_label_mapping/manual_template
```
The sorted results will be saved in `my_auto_label_mapping/manual_template`, with the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
**Auto T + L**: We can also do a joint search of templates and label word mappings following these steps (a sketch of the adapted fine-tuning command is given after the list):
1. First, do the automatic template search following [Search for automatic templates](#search-for-automatic-templates).
2. The following steps are similar to the automatic label mapping search, except for a few arguments. When running `tools/run_generate_labels.sh`, set `LOAD_TEMPLATES` to `true` in it; the template + mapping candidates will then be written to `my_auto_label_mapping/auto_template`.
3. For the subsequent fine-tuning, change `--mapping_path` and `--mapping_id` to `--prompt_path` and `--prompt_id`.
4. In the end, to re-rank all the prompts, change `tools/sort_mapping.py` to `tools/sort_prompt.py` to get the final lists.
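A sketch of the adapted fine-tuning loop for step 3 (the tag `exp-prompt` is just an example name, and the file naming under `my_auto_label_mapping/auto_template` is assumed to follow the same `16-$seed.txt` pattern as before):
```bash
for prompt_id in {0..99}
do
for seed in 13 21 42 87 100
do
TAG=exp-prompt \
TYPE=prompt \
TASK=SST-2 \
BS=8 \
LR=1e-5 \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh "--prompt_path my_auto_label_mapping/auto_template/SST-2/16-$seed.txt --prompt_id $prompt_id --no_predict"
done
done
```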
### Ensemble model
First we need to train models with different templates:
```bash
mkdir ensemble_predict_results
for template_id in {0..19} # Use top 20 templates
do
array_id=0
for seed in 13 21 42 87 100
do
for bs in 2 4 8
do
for lr in 1e-5 2e-5 5e-5
do
TAG=exp-ensemble \
TYPE=prompt-demo \
TASK=SST-2 \
BS=$bs \
LR=$lr \
SEED=$seed \
MODEL=roberta-large \
bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id $template_id --model_id $template_id --array_id $array_id --save_logit --save_logit_dir ensemble_predict_results"
array_id=$(expr $array_id + 1)
done
done
done
done
```
Looks a little complicated? It's actually pretty easy to understand: `--model_id` and `--array_id` are used to distinguish different runs, and `--save_logit` tells the program to save the prediction logits for ensembling.
After finishing the experiments, use the following command to get the ensemble results:
```bash
python tools/ensemble.py --condition "{'tag': 'exp-ensemble', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}" --n_models 20
```
where `--n_models` specifies how many models to use for the ensemble (it should match the number of templates used in the experiments).
### Zero-shot experiments
It's easy to run zero-shot experiments: just add the `--no_train` argument:
```bash
TAG=zero-shot TYPE=prompt TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train"
```
To do "GPT-3 style" in-context learning:
```bash
TAG=gpt3-in-context TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train --num_sample 1 --gpt3_in_context_head --gpt3_in_context_num 32 --truncate_head --use_full_length"
```
### How to design your own templates
Here are two template examples:
For SST-2: `*cls**sent_0*_It_was*mask*.*sep+*` => `[CLS] {S0} It was [MASK]. [SEP]`
For MNLI: `*cls**sent-_0*?*mask*,*+sentl_1**sep+*` => `[CLS] {S0}? [MASK], {S1} [SEP]`
The template is composed of special tokens, variables (surrounded by `*`), and plain text (e.g., `It_was`, where each space is written as `_`). The special tokens and variables include the following (a usage example with a custom template appears after the list):
* `*cls*`, `*sep*`, `*sep+*` and `*mask*`: Special tokens of CLS, SEP and MASK (different for different pre-trained models and tokenizers). `*sep+*` means the contents before and after this token have different segment embeddings (only for BERT).
* `*sent_i*`: The i-th sentence.
* `*sent-_i*`: The i-th sentence, discarding the last character.
* `*sentl_i*`: The i-th sentence, lower-casing the first letter.
* `*sentl-_i*`: The i-th sentence, discarding the last character and lower-casing the first letter.
* `*+sent_i*`: The i-th sentence, adding an extra space at the beginning.
* `*+sentl_i*`: The i-th sentence, adding an extra space at the beginning and lower-casing the first letter.
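As a quick illustration, a custom template and label word mapping can be passed directly to `run.py` like this (the template wording, label words, and output directory below are only an example, not the ones used in the paper; for RoBERTa, each label word must tokenize to a single token when preceded by a space):
```bash
python run.py \
--task_name SST-2 \
--data_dir data/k-shot/SST-2/16-42 \
--overwrite_output_dir \
--do_train \
--do_eval \
--evaluate_during_training \
--model_name_or_path roberta-large \
--few_shot_type prompt \
--num_k 16 \
--max_steps 1000 \
--eval_steps 100 \
--per_device_train_batch_size 2 \
--learning_rate 1e-5 \
--output_dir result/custom-template \
--seed 42 \
--template "*cls**sent_0*_This_one_is*mask*.*sep+*" \
--mapping "{'0':'bad','1':'good'}"
```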
## Bugs or questions?
If you have any questions related to the code or the paper, feel free to email Tianyu (`tianyug@cs.princeton.edu`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and quicker!
## Citation
Please cite our paper if you use LM-BFF in your work:
```bibtex
@inproceedings{gao2021making,
title={Making Pre-trained Language Models Better Few-shot Learners},
author={Gao, Tianyu and Fisch, Adam and Chen, Danqi},
booktitle={Association for Computational Linguistics (ACL)},
year={2021}
}
```
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
dataclasses
filelock==3.0.12
flake8==3.8.4
future==0.18.2
idna==2.10
importlib-metadata==3.3.0
joblib==1.0.0
mccabe==0.6.1
nltk==3.5
numpy==1.19.4
packaging==20.8
pandas==1.1.5
protobuf==3.14.0
pycodestyle==2.6.0
pyflakes==2.2.0
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.5
regex==2020.11.13
requests==2.25.1
sacremoses==0.0.43
scikit-learn==0.24.0
scipy==1.5.4
sentence-transformers==0.4.0
sentencepiece==0.1.94
six==1.15.0
threadpoolctl==2.1.0
tokenizers==0.9.2
torch==1.6.0
tqdm==4.48.2
transformers==3.4.0
typing-extensions==3.7.4.3
urllib3>=1.26.4
zipp==3.4.0
"""Finetuning the library models for sequence classification on GLUE."""
import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
import torch
import numpy as np
import transformers
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, TrainingArguments, set_seed
from src.dataset import FewShotDataset
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.trainer import Trainer
from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, bound_mapping
from filelock import FileLock
from datetime import datetime
from copy import deepcopy
from tqdm import tqdm
import json
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
# Few-shot type
# - finetune: standard fine-tuning
# - prompt: prompt-based fine-tuning
# - prompt-demo: prompt-based fine-tuning with demonstrations
few_shot_type: str = field(
default='prompt-demo',
metadata={"help": "Few-shot learning model type. Choice: finetune, prompt, prompt-demo"}
)
# Only for BERT-type model
random_segment: bool = field(
default=False,
metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."}
)
@dataclass
class DynamicDataTrainingArguments(DataTrainingArguments):
"""
Arguments for dynamic training.
"""
num_k: Optional[int] = field(
default=16,
metadata={"help": "Number of training instances per class"}
)
num_sample: Optional[int] = field(
default=16,
metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"}
)
num_demo: Optional[int] = field(
default=1,
metadata={"help": "Number of demonstrations from each class"}
)
auto_demo: bool = field(
default=True,
metadata={"help": "Automatically generate template for using demonstrations"}
)
# For prompting
template: str = field(
default=None,
metadata={"help": "Template"}
)
mapping: str = field(
default=None,
metadata={"help": "Label word mapping"}
)
template_path: str = field(
default=None,
metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"}
)
mapping_path: str = field(
default=None,
metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. Do not set this when prompt_path is used"}
)
prompt_path: str = field(
default=None,
metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"}
)
template_id: int = field(
default=None,
metadata={"help": "Template id if using template_path"}
)
mapping_id: int = field(
default=None,
metadata={"help": "Mapping id if using template_path"}
)
prompt_id: int = field(
default=None,
metadata={"help": "Prompt id if using prompt_path"}
)
top_n_template: int = field(
default=None,
metadata={"help": "Use top-n template in the template path"}
)
# For logging
tag: str = field(
default='',
metadata={"help": "Set the tag and find the result easier in the log."}
)
# For filtering when using demonstrations
demo_filter: bool = field(
default=False,
metadata={"help": "Only use similar instances in demonstrations"}
)
demo_filter_rate: float = field(
default=0.5,
metadata={"help": "Only use top-x\% similar instances in demonstrations"}
)
demo_filter_model: str = field(
default=None,
metadata={"help": "Model name for demonstration filter embeddings. Will load embeddings based on the model name."}
)
debug_mode: bool = field(
default=False,
metadata={"help": "Debug mode"}
)
# For max length
double_demo: bool = field(
default=False,
metadata={"help": "Use double length for using demonstrations"}
)
first_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
)
other_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of sentences other than the first sentence"}
)
use_full_length: bool = field(
default=None,
metadata={"help": "Use the full length (512)"}
)
# GPT-3's in-context learning
gpt3_in_context_head: bool = field(
default=False,
metadata={"help": "GPT-3's in-context learning (context at the beginning)"}
)
gpt3_in_context_tail: bool = field(
default=False,
metadata={"help": "GPT-3's in-context learning (context at the end)"}
)
gpt3_in_context_num: int = field(
default=32,
metadata={"help": "Number of context examples"}
)
truncate_head: bool = field(
default=False,
metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
)
# Do not set up the following fields. They are set up automatically.
prompt: bool = field(
default=False,
metadata={"help": "Whether to use prompt-based fine-tuning"}
)
template_list: list = field(
default=None,
metadata={"help": "(DO NOT List of templates (only initialized after the program starts."}
)
@dataclass
class DynamicTrainingArguments(TrainingArguments):
# For ensemble
array_id: int = field(
default=-1,
metadata={"help": "Array ID (contains seed and hyper-paramter search) to idenfity the model"}
)
model_id: int = field(
default=-1,
metadata={"help": "Model ID (contains template information) to identify the model"}
)
save_logit: bool = field(
default=False,
metadata={"help": "Save test file logit with name $TASK-$MODEL_ID-$ARRAY_ID.npy"}
)
save_logit_dir: str = field(
default=None,
metadata={"help": "Where to save the prediction result"}
)
# Regularization
fix_layers: int = field(
default=0,
metadata={"help": "Fix bottom-n layers when optimizing"}
)
# Training
save_at_last: bool = field(
default=False,
metadata={"help": "Instead of saving the best (dev performance) checkpoint, save the last checkpoint"}
)
# Turn off train/test
no_train: bool = field(
default=False,
metadata={"help": "No training"}
)
no_predict: bool = field(
default=False,
metadata={"help": "No test"}
)
def main():
parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if 'prompt' in model_args.few_shot_type:
data_args.prompt = True
if training_args.no_train:
training_args.do_train = False
if training_args.no_predict:
training_args.do_predict = False
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
# Load prompt/template/mapping file
if data_args.prompt:
if data_args.prompt_path is not None:
assert data_args.prompt_id is not None
prompt_list = []
with open(data_args.prompt_path) as f:
for line in f:
line = line.strip()
template, mapping = line.split('\t')
prompt_list.append((template, mapping))
data_args.template, data_args.mapping = prompt_list[data_args.prompt_id]
logger.info("Specify load the %d-th prompt: %s | %s" % (data_args.prompt_id, data_args.template, data_args.mapping))
else:
if data_args.template_path is not None:
with open(data_args.template_path) as f:
data_args.template_list = []
for line in f:
line = line.strip()
if len(line) > 0:
data_args.template_list.append(line)
# Load top-n templates
if data_args.top_n_template is not None:
data_args.template_list = data_args.template_list[:data_args.top_n_template]
logger.info("Load top-%d templates from %s" % (len(data_args.template_list), data_args.template_path))
# ... or load i-th template
if data_args.template_id is not None:
data_args.template = data_args.template_list[data_args.template_id]
data_args.template_list = None
logger.info("Specify load the %d-th template: %s" % (data_args.template_id, data_args.template))
if data_args.mapping_path is not None:
assert data_args.mapping_id is not None # Only can use one label word mapping
with open(data_args.mapping_path) as f:
mapping_list = []
for line in f:
line = line.strip()
mapping_list.append(line)
data_args.mapping = mapping_list[data_args.mapping_id]
logger.info("Specify using the %d-th mapping: %s" % (data_args.mapping_id, data_args.mapping))
# Check save path
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Set seed
set_seed(training_args.seed)
try:
num_labels = num_labels_mapping[data_args.task_name]
output_mode = output_modes_mapping[data_args.task_name]
logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
# Automatically generate template for using demonstrations
if data_args.auto_demo and model_args.few_shot_type == 'prompt-demo':
# GPT-3's in-context learning
if data_args.gpt3_in_context_head or data_args.gpt3_in_context_tail:
logger.info("Automatically convert the template to GPT-3's in-context learning.")
assert data_args.template_list is None
old_template = data_args.template
new_template = old_template + ''
old_template = old_template.replace('*cls*', '')
# Single sentence or sentence pair?
sent_num = 1
if "_1" in old_template:
sent_num = 2
for instance_id in range(data_args.gpt3_in_context_num):
sub_template = old_template + ''
# Replace sent_id
for sent_id in range(sent_num):
sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * instance_id + sent_id))
# Replace mask
sub_template = sub_template.replace("*mask*", "*labelx_{}*".format(instance_id))
if data_args.gpt3_in_context_tail:
new_template = new_template + sub_template # Put context at the end
else:
new_template = sub_template + new_template # Put context at the beginning
logger.info("| {} => {}".format(data_args.template, new_template))
data_args.template = new_template
else:
logger.info("Automatically convert the template to using demonstrations.")
if data_args.template_list is not None:
for i in range(len(data_args.template_list)):
old_template = data_args.template_list[i]
new_template = old_template + ''
old_template = old_template.replace('*cls*', '')
# Single sentence or sentence pair?
sent_num = 1
if "_1" in old_template:
sent_num = 2
for label_id in range(num_labels):
sub_template = old_template + ''
# Replace sent id
for sent_id in range(sent_num):
sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * label_id + sent_id))
# Replace mask
sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id))
new_template = new_template + sub_template
logger.info("| {} => {}".format(data_args.template_list[i], new_template))
data_args.template_list[i] = new_template
else:
old_template = data_args.template
new_template = old_template + ''
old_template = old_template.replace('*cls*', '')
# Single sentence or sentence pair?
sent_num = 1
if "_1" in old_template:
sent_num = 2
for label_id in range(num_labels):
sub_template = old_template + ''
# Replace sent id
for sent_id in range(sent_num):
sub_template = sub_template.replace("_{}".format(sent_id), "_{}".format(sent_num + sent_num * label_id + sent_id))
# Replace mask
sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id))
new_template = new_template + sub_template
logger.info("| {} => {}".format(data_args.template, new_template))
data_args.template = new_template
# Create config
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
if 'prompt' in model_args.few_shot_type:
if config.model_type == 'roberta':
model_fn = RobertaForPromptFinetuning
elif config.model_type == 'bert':
model_fn = BertForPromptFinetuning
else:
raise NotImplementedError
elif model_args.few_shot_type == 'finetune':
model_fn = AutoModelForSequenceClassification
else:
raise NotImplementedError
special_tokens = []
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
additional_special_tokens=special_tokens,
cache_dir=model_args.cache_dir,
)
# Get our special datasets.
train_dataset = (
FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=("demo" in model_args.few_shot_type))
)
eval_dataset = (
FewShotDataset(data_args, tokenizer=tokenizer, mode="dev", use_demo=("demo" in model_args.few_shot_type))
if training_args.do_eval
else None
)
test_dataset = (
FewShotDataset(data_args, tokenizer=tokenizer, mode="test", use_demo=("demo" in model_args.few_shot_type))
if training_args.do_predict
else None
)
set_seed(training_args.seed)
model = model_fn.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# For BERT, increase the size of the segment (token type) embeddings
if config.model_type == 'bert':
model.resize_token_embeddings(len(tokenizer))
resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
# Pass dataset and argument information to the model
if data_args.prompt:
model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda()
if output_modes_mapping[data_args.task_name] == 'regression':
# lower / upper bounds
model.lb, model.ub = bound_mapping[data_args.task_name]
model.model_args = model_args
model.data_args = data_args
model.tokenizer = tokenizer
# Build metric
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction):
# Note: the eval dataloader is sequential, so the examples are in order.
# We average the logits over each sample for using demonstrations.
predictions = p.predictions
num_logits = predictions.shape[-1]
logits = predictions.reshape([eval_dataset.num_sample, -1, num_logits])
logits = logits.mean(axis=0)
if num_logits == 1:
preds = np.squeeze(logits)
else:
preds = np.argmax(logits, axis=1)
# Just for sanity, assert label ids are the same.
label_ids = p.label_ids.reshape([eval_dataset.num_sample, -1])
label_ids_avg = label_ids.mean(axis=0)
label_ids_avg = label_ids_avg.astype(p.label_ids.dtype)
assert (label_ids_avg - label_ids[0]).mean() < 1e-2
label_ids = label_ids[0]
return compute_metrics_mapping[task_name](task_name, preds, label_ids)
return compute_metrics_fn
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=build_compute_metrics_fn(data_args.task_name)
)
# Training
if training_args.do_train:
trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)
# Early stopping is used, so do not save the model at the end (unless save_at_last is specified)
if training_args.save_at_last:
trainer.save_model(training_args.output_dir)
if trainer.is_world_master():
tokenizer.save_pretrained(training_args.output_dir)
torch.save(model_args, os.path.join(training_args.output_dir, "model_args.bin"))
torch.save(data_args, os.path.join(training_args.output_dir, "data_args.bin"))
# Reload the best checkpoint (for eval)
model = model_fn.from_pretrained(training_args.output_dir)
model = model.to(training_args.device)
trainer.model = model
if data_args.prompt:
model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda()
if output_modes_mapping[data_args.task_name] == 'regression':
# lower / upper bounds
model.lb, model.ub = bound_mapping[data_args.task_name]
model.model_args = model_args
model.data_args = data_args
model.tokenizer = tokenizer
# Evaluation
final_result = {
'time': str(datetime.today()),
}
eval_results = {}
if training_args.do_eval:
logger.info("*** Validate ***")
eval_datasets = [eval_dataset]
for eval_dataset in eval_datasets:
trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
output = trainer.evaluate(eval_dataset=eval_dataset)
eval_result = output.metrics
output_eval_file = os.path.join(
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
)
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
for key, value in eval_result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
final_result[eval_dataset.args.task_name + '_dev_' + key] = value
eval_results.update(eval_result)
test_results = {}
if training_args.do_predict:
logging.info("*** Test ***")
test_datasets = [test_dataset]
if data_args.task_name == "mnli":
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
test_datasets.append(
FewShotDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", use_demo=('demo' in model_args.few_shot_type))
)
for test_dataset in test_datasets:
trainer.compute_metrics = build_compute_metrics_fn(test_dataset.args.task_name)
output = trainer.evaluate(eval_dataset=test_dataset)
test_result = output.metrics
output_test_file = os.path.join(
training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
)
if trainer.is_world_master():
with open(output_test_file, "w") as writer:
logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
for key, value in test_result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
final_result[test_dataset.args.task_name + '_test_' + key] = value
if training_args.save_logit:
predictions = output.predictions
num_logits = predictions.shape[-1]
logits = predictions.reshape([test_dataset.num_sample, -1, num_logits]).mean(axis=0)
np.save(os.path.join(training_args.save_logit_dir, "{}-{}-{}.npy".format(test_dataset.task_name, training_args.model_id, training_args.array_id)), logits)
test_results.update(test_result)
with FileLock('log.lock'):
with open('log', 'a') as f:
final_result.update(vars(model_args))
final_result.update(vars(training_args))
final_result.update(vars(data_args))
if 'evaluation_strategy' in final_result:
final_result.pop('evaluation_strategy')
f.write(str(final_result) + '\n')
return eval_results
if __name__ == "__main__":
main()
"""Dataset utils for different data settings for GLUE."""
import os
import copy
import logging
import torch
import numpy as np
import time
from filelock import FileLock
import json
import itertools
import random
import transformers
from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, median_mapping
from transformers.data.processors.utils import InputFeatures
from transformers import DataProcessor, InputExample
import dataclasses
from dataclasses import dataclass
from typing import List, Optional, Union
from sentence_transformers import SentenceTransformer, util
from copy import deepcopy
import pandas as pd
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class OurInputFeatures(InputFeatures):
"""
Inherit from Transformers' InputFeatures.
"""
input_ids: List[int]
attention_mask: Optional[List[int]] = None
token_type_ids: Optional[List[int]] = None
label: Optional[Union[int, float]] = None
mask_pos: Optional[List[int]] = None # Position of the mask token
label_word_list: Optional[List[int]] = None # Label word mapping (dynamic)
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(dataclasses.asdict(self)) + "\n"
def input_example_to_string(example, sep_token):
if example.text_b is None:
return example.text_a
else:
# Warning: very simple hack here
return example.text_a + ' ' + sep_token + ' ' + example.text_b
def input_example_to_tuple(example):
if example.text_b is None:
if pd.isna(example.text_a) or example.text_a is None:
logger.warn("Empty input")
return ['']
else:
return [example.text_a]
else:
return [example.text_a, example.text_b]
def tokenize_multipart_input(
input_text_list,
max_length,
tokenizer,
task_name=None,
prompt=False,
template=None,
label_word_list=None,
first_sent_limit=None,
other_sent_limit=None,
gpt3=False,
truncate_head=False,
support_labels=None,
):
def enc(text):
return tokenizer.encode(text, add_special_tokens=False)
input_ids = []
attention_mask = []
token_type_ids = [] # Only for BERT
mask_pos = None # Position of the mask token
if prompt:
"""
Concatenate all sentences and prompts based on the provided template.
Template example: '*cls*It was*mask*.*sent_0**<sep>*label_0:*sent_1**<sep>**label_1*:*sent_2**<sep>*'
*xx* represent variables:
*cls*: cls_token
*mask*: mask_token
*sep*: sep_token
*sep+*: sep_token, also means +1 for segment id
*sent_i*: sentence i (input_text_list[i])
*sent-_i*: same as above, but delete the last token
*sentl_i*: same as above, but use lower case for the first word
*sentl-_i*: same as above, but use lower case for the first word and delete the last token
*+sent_i*: same as above, but add a space before the sentence
*+sentl_i*: same as above, but add a space before the sentence and use lower case for the first word
*label_i*: label_word_list[i]
*label_x*: label depends on the example id (support_labels needed). this is only used in GPT-3's in-context learning
Use "_" to replace space.
PAY ATTENTION TO SPACE!! DO NOT leave space before variables, for this will lead to extra space token.
"""
assert template is not None
special_token_mapping = {
'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id,
}
template_list = template.split('*') # Get variable list in the template
segment_id = 0 # Current segment id. Segment id +1 if encountering sep+.
for part_id, part in enumerate(template_list):
new_tokens = []
segment_plus_1_flag = False
if part in special_token_mapping:
if part == 'cls' and 'T5' in type(tokenizer).__name__:
# T5 does not have cls token
continue
new_tokens.append(special_token_mapping[part])
if part == 'sep+':
segment_plus_1_flag = True
elif part[:6] == 'label_':
# Note that label_word_list already has extra space, so do not add more space ahead of it.
label_id = int(part.split('_')[1])
label_word = label_word_list[label_id]
new_tokens.append(label_word)
elif part[:7] == 'labelx_':
instance_id = int(part.split('_')[1])
label_id = support_labels[instance_id]
label_word = label_word_list[label_id]
new_tokens.append(label_word)
elif part[:5] == 'sent_':
sent_id = int(part.split('_')[1])
new_tokens += enc(input_text_list[sent_id])
elif part[:6] == '+sent_':
# Add space
sent_id = int(part.split('_')[1])
new_tokens += enc(' ' + input_text_list[sent_id])
elif part[:6] == 'sent-_':
# Delete the last token
sent_id = int(part.split('_')[1])
new_tokens += enc(input_text_list[sent_id][:-1])
elif part[:6] == 'sentl_':
# Lower case the first token
sent_id = int(part.split('_')[1])
text = input_text_list[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(text)
elif part[:7] == '+sentl_':
# Lower case the first token and add space
sent_id = int(part.split('_')[1])
text = input_text_list[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(' ' + text)
elif part[:7] == 'sentl-_':
# Lower case the first token and discard the last token
sent_id = int(part.split('_')[1])
text = input_text_list[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(text[:-1])
elif part[:6] == 'sentu_':
# Upper case the first token
sent_id = int(part.split('_')[1])
text = input_text_list[sent_id]
text = text[:1].upper() + text[1:]
new_tokens += enc(text)
elif part[:7] == '+sentu_':
# Upper case the first token and add space
sent_id = int(part.split('_')[1])
text = input_text_list[sent_id]
text = text[:1].upper() + text[1:]
new_tokens += enc(' ' + text)
else:
# Just natural language prompt
part = part.replace('_', ' ')
# handle special case when T5 tokenizer might add an extra space
if len(part) == 1:
new_tokens.append(tokenizer._convert_token_to_id(part))
else:
new_tokens += enc(part)
if part[:4] == 'sent' or part[1:5] == 'sent':
# If this part is the sentence, limit the sentence length
sent_id = int(part.split('_')[1])
if sent_id == 0:
if first_sent_limit is not None:
new_tokens = new_tokens[:first_sent_limit]
else:
if other_sent_limit is not None:
new_tokens = new_tokens[:other_sent_limit]
input_ids += new_tokens
attention_mask += [1 for i in range(len(new_tokens))]
token_type_ids += [segment_id for i in range(len(new_tokens))]
if segment_plus_1_flag:
segment_id += 1
else:
input_ids = [tokenizer.cls_token_id]
attention_mask = [1]
token_type_ids = [0]
for sent_id, input_text in enumerate(input_text_list):
if input_text is None:
# Do not have text_b
continue
if pd.isna(input_text) or input_text is None:
# Empty input
input_text = ''
input_tokens = enc(input_text) + [tokenizer.sep_token_id]
input_ids += input_tokens
attention_mask += [1 for i in range(len(input_tokens))]
token_type_ids += [sent_id for i in range(len(input_tokens))]
if 'T5' in type(tokenizer).__name__: # T5 does not have CLS token
input_ids = input_ids[1:]
attention_mask = attention_mask[1:]
token_type_ids = token_type_ids[1:]
# Padding
if first_sent_limit is not None and len(input_ids) > max_length:
# If using sentence limit, the total length still exceeds the maximum limit, report a warning
logger.warn("Input exceeds max_length limit: {}".format(tokenizer.decode(input_ids)))
while len(input_ids) < max_length:
input_ids.append(tokenizer.pad_token_id)
attention_mask.append(0)
token_type_ids.append(0)
# Truncate
if len(input_ids) > max_length:
if truncate_head:
input_ids = input_ids[-max_length:]
attention_mask = attention_mask[-max_length:]
token_type_ids = token_type_ids[-max_length:]
else:
# Default is to truncate the tail
input_ids = input_ids[:max_length]
attention_mask = attention_mask[:max_length]
token_type_ids = token_type_ids[:max_length]
# Find mask token
if prompt:
mask_pos = [input_ids.index(tokenizer.mask_token_id)]
# Make sure that the masked position is inside the max_length
assert mask_pos[0] < max_length
result = {'input_ids': input_ids, 'attention_mask': attention_mask}
if 'BERT' in type(tokenizer).__name__:
# Only provide token type ids for BERT
result['token_type_ids'] = token_type_ids
if prompt:
result['mask_pos'] = mask_pos
return result
class FewShotDataset(torch.utils.data.Dataset):
"""Few-shot dataset."""
def __init__(self, args, tokenizer, cache_dir=None, mode="train", use_demo=False):
self.args = args
self.task_name = args.task_name
self.processor = processors_mapping[args.task_name]
self.tokenizer = tokenizer
self.mode = mode
# use_demo is True when demonstrations are used
self.use_demo = use_demo
if self.use_demo:
logger.info("Use demonstrations")
assert mode in ["train", "dev", "test"]
# Get label list and (for prompt) label word list
self.label_list = self.processor.get_labels()
self.num_labels = len(self.label_list)
if args.prompt:
assert args.mapping is not None
self.label_to_word = eval(args.mapping)
for key in self.label_to_word:
# For RoBERTa/BART/T5, tokenization also considers space, so we use space+word as label words.
if self.label_to_word[key][0] not in ['<', '[', '.', ',']:
# Make sure space+word is in the vocabulary
assert len(tokenizer.tokenize(' ' + self.label_to_word[key])) == 1
self.label_to_word[key] = tokenizer._convert_token_to_id(tokenizer.tokenize(' ' + self.label_to_word[key])[0])
else:
self.label_to_word[key] = tokenizer._convert_token_to_id(self.label_to_word[key])
logger.info("Label {} to word {} ({})".format(key, tokenizer._convert_id_to_token(self.label_to_word[key]), self.label_to_word[key]))
if len(self.label_list) > 1:
self.label_word_list = [self.label_to_word[label] for label in self.label_list]
else:
# Regression task
# '0' represents low polarity and '1' represents high polarity.
self.label_word_list = [self.label_to_word[label] for label in ['0', '1']]
else:
self.label_to_word = None
self.label_word_list = None
# Multiple sampling: when using demonstrations, we sample different combinations of demonstrations during
# inference and aggregate the results by averaging the logits. The number of different samples is num_sample.
if (mode == "train") or not self.use_demo:
# We do not do multiple sampling when not using demonstrations or when it's the training mode
self.num_sample = 1
else:
self.num_sample = args.num_sample
# If we use multiple templates, we also need to do multiple sampling during inference.
if args.prompt and args.template_list is not None:
logger.info("There are %d templates. Multiply num_sample by %d" % (len(args.template_list), len(args.template_list)))
self.num_sample *= len(args.template_list)
logger.info("Total num_sample for mode %s: %d" % (mode, self.num_sample))
# Load cache
# Cache name distinguishes mode, task name, tokenizer, and length. So if you change anything beyond these elements, make sure to clear your cache.
cached_features_file = os.path.join(
cache_dir if cache_dir is not None else args.data_dir,
"cached_{}_{}_{}_{}".format(
mode,
tokenizer.__class__.__name__,
str(args.max_seq_length),
args.task_name,
),
)
logger.info(f"Creating/loading examples from dataset file at {args.data_dir}")
lock_path = cached_features_file + ".lock"
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
self.support_examples, self.query_examples = torch.load(cached_features_file)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
)
else:
logger.info(f"Creating features from dataset file at {args.data_dir}")
# The support examples are sourced from the training set.
self.support_examples = self.processor.get_train_examples(args.data_dir)
if mode == "dev":
self.query_examples = self.processor.get_dev_examples(args.data_dir)
elif mode == "test":
self.query_examples = self.processor.get_test_examples(args.data_dir)
else:
self.query_examples = self.support_examples
start = time.time()
torch.save([self.support_examples, self.query_examples], cached_features_file)
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
logger.info(
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
# For filtering in using demonstrations, load pre-calculated embeddings
if self.use_demo and args.demo_filter:
split_name = ''
if mode == 'train':
split_name = 'train'
elif mode == 'dev':
if args.task_name == 'mnli':
split_name = 'dev_matched'
elif args.task_name == 'mnli-mm':
split_name = 'dev_mismatched'
else:
split_name = 'dev'
elif mode == 'test':
if args.task_name == 'mnli':
split_name = 'test_matched'
elif args.task_name == 'mnli-mm':
split_name = 'test_mismatched'
else:
split_name = 'test'
else:
raise NotImplementedError
self.support_emb = np.load(os.path.join(args.data_dir, "train_{}.npy".format(args.demo_filter_model)))
self.query_emb = np.load(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model)))
logger.info("Load embeddings (for demonstration filtering) from {}".format(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model))))
assert len(self.support_emb) == len(self.support_examples)
assert len(self.query_emb) == len(self.query_examples)
# Size is expanded by num_sample
self.size = len(self.query_examples) * self.num_sample
# Prepare examples (especially for using demonstrations)
support_indices = list(range(len(self.support_examples)))
self.example_idx = []
for sample_idx in range(self.num_sample):
for query_idx in range(len(self.query_examples)):
# If training, exclude the current example. Else keep all.
if self.use_demo and args.demo_filter:
# Demonstration filtering
candidate = [support_idx for support_idx in support_indices
if support_idx != query_idx or mode != "train"]
sim_score = []
for support_idx in candidate:
sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx])))
sim_score.sort(key=lambda x: x[1], reverse=True)
if self.num_labels == 1:
# Regression task
limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate)
count_each_label = {'0': 0, '1': 0}
context_indices = []
if args.debug_mode:
print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
for support_idx, score in sim_score:
if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label:
count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1
context_indices.append(support_idx)
if args.debug_mode:
print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
else:
limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate)
count_each_label = {label: 0 for label in self.label_list}
context_indices = []
if args.debug_mode:
print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
for support_idx, score in sim_score:
if count_each_label[self.support_examples[support_idx].label] < limit_each_label:
count_each_label[self.support_examples[support_idx].label] += 1
context_indices.append(support_idx)
if args.debug_mode:
print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
else:
# Using demonstrations without filtering
context_indices = [support_idx for support_idx in support_indices
if support_idx != query_idx or mode != "train"]
# We'll subsample context_indices further later.
self.example_idx.append((query_idx, context_indices, sample_idx))
# If it is not training, we pre-process the data; otherwise, we process the data online.
if mode != "train":
self.features = []
_ = 0
for query_idx, context_indices, bootstrap_idx in self.example_idx:
# The input (query) example
example = self.query_examples[query_idx]
# The demonstrations
supports = self.select_context([self.support_examples[i] for i in context_indices])
if args.template_list is not None:
template = args.template_list[bootstrap_idx % len(args.template_list)] # Use templates in order (bootstrap_idx is the sample index)
else:
template = args.template
self.features.append(self.convert_fn(
example=example,
supports=supports,
use_demo=self.use_demo,
label_list=self.label_list,
prompt=args.prompt,
template=template,
label_word_list=self.label_word_list,
verbose=True if _ == 0 else False,
))
_ += 1
else:
self.features = None
def select_context(self, context_examples):
"""
Select demonstrations from provided examples.
"""
max_demo_per_label = 1
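# Note: with max_demo_per_label = 1, each augmented input carries at most one
# randomly sampled demonstration per label (or per pseudo-label for regression).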
counts = {k: 0 for k in self.label_list}
if len(self.label_list) == 1:
# Regression
counts = {'0': 0, '1': 0}
selection = []
if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
# For GPT-3's in-context learning, we sample gpt3_in_context_num demonstrations randomly.
order = np.random.permutation(len(context_examples))
for i in range(min(self.args.gpt3_in_context_num, len(order))):
selection.append(context_examples[order[i]])
else:
# Our sampling strategy
order = np.random.permutation(len(context_examples))
for i in order:
label = context_examples[i].label
if len(self.label_list) == 1:
# Regression
label = '0' if float(label) <= median_mapping[self.args.task_name] else '1'
if counts[label] < max_demo_per_label:
selection.append(context_examples[i])
counts[label] += 1
if sum(counts.values()) == len(counts) * max_demo_per_label:
break
assert len(selection) > 0
return selection
def __len__(self):
return self.size
def __getitem__(self, i):
if self.features is None:
query_idx, context_indices, bootstrap_idx = self.example_idx[i]
# The input (query) example
example = self.query_examples[query_idx]
# The demonstrations
supports = self.select_context([self.support_examples[i] for i in context_indices])
if self.args.template_list is not None:
template = self.args.template_list[bootstrap_idx % len(self.args.template_list)]
else:
template = self.args.template
features = self.convert_fn(
example=example,
supports=supports,
use_demo=self.use_demo,
label_list=self.label_list,
prompt=self.args.prompt,
template=template,
label_word_list=self.label_word_list,
verbose=False,
)
else:
features = self.features[i]
return features
def get_labels(self):
return self.label_list
def convert_fn(
self,
example,
supports,
use_demo=False,
label_list=None,
prompt=False,
template=None,
label_word_list=None,
verbose=False
):
"""
Returns a list of processed "InputFeatures".
"""
max_length = self.args.max_seq_length
# Prepare labels
label_map = {label: i for i, label in enumerate(label_list)} # Mapping the label names to label ids
if len(label_list) == 1:
# Regression
label_map = {'0': 0, '1': 1}
# Get example's label id (for training/inference)
if example.label is None:
example_label = None
elif len(label_list) == 1:
# Regression
example_label = float(example.label)
else:
example_label = label_map[example.label]
# Prepare other features
if not use_demo:
# Not using demonstrations
inputs = tokenize_multipart_input(
input_text_list=input_example_to_tuple(example),
max_length=max_length,
tokenizer=self.tokenizer,
task_name=self.args.task_name,
prompt=prompt,
template=template,
label_word_list=label_word_list,
first_sent_limit=self.args.first_sent_limit,
other_sent_limit=self.args.other_sent_limit,
)
features = OurInputFeatures(**inputs, label=example_label)
else:
# Using demonstrations
# Max length
if self.args.double_demo:
# When using demonstrations, double the maximum length
# Note that in this case, args.max_seq_length is the maximum length for a single sentence
max_length = max_length * 2
if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
# When using GPT-3's in-context learning, take the maximum tokenization length of the model (512)
max_length = 512
# All input sentences, including the query and the demonstrations, are put into augmented_example
# and are numbered by their order (starting from 0). For single-sentence tasks, the input (query)
# is sentence 0; for sentence-pair tasks, the input (query) is sentences 0 and 1. Note that for GPT-3-style
# in-context learning, the input (query) might be at the end instead of the beginning (gpt3_in_context_head).
augmented_example = []
query_text = input_example_to_tuple(example) # Input sentence list for query
support_by_label = [[] for i in range(len(label_map))]
if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
support_labels = []
augmented_example = query_text
for support_example in supports:
augmented_example += input_example_to_tuple(support_example)
current_label = support_example.label
if len(label_list) == 1:
current_label = '0' if float(current_label) <= median_mapping[self.args.task_name] else '1' # Regression
support_labels.append(label_map[current_label])
else:
# Group support examples by label
for label_name, label_id in label_map.items():
if len(label_list) == 1:
# Regression
for support_example in filter(lambda s: ('0' if float(s.label) <= median_mapping[self.args.task_name] else '1') == label_name, supports):
support_by_label[label_id] += input_example_to_tuple(support_example)
else:
for support_example in filter(lambda s: s.label == label_name, supports):
support_by_label[label_id] += input_example_to_tuple(support_example)
augmented_example = query_text
for label_id in range(len(label_map)):
augmented_example += support_by_label[label_id]
# Tokenization (based on the template)
inputs = tokenize_multipart_input(
input_text_list=augmented_example,
max_length=max_length,
tokenizer=self.tokenizer,
task_name=self.args.task_name,
prompt=prompt,
template=template,
label_word_list=label_word_list,
first_sent_limit=self.args.first_sent_limit,
other_sent_limit=self.args.other_sent_limit,
truncate_head=self.args.truncate_head,
gpt3=self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail,
support_labels=None if not (self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail) else support_labels
)
features = OurInputFeatures(**inputs, label=example_label)
if verbose:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("features: %s" % features)
logger.info("text: %s" % self.tokenizer.decode(features.input_ids))
return features
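# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): conceptually, the
# demonstration-augmented input built above is the query followed by one
# sampled demonstration per label, each filled into the same template. The
# helper below only mimics that layout with plain strings; the real formatting
# and truncation are done by tokenize_multipart_input according to the task
# template. All names and the example template here are hypothetical.
def _sketch_demo_prompt(query, demos, template="{sent} It was {word}."):
    """demos: list of (sentence, label_word) pairs, one per label."""
    parts = [template.format(sent=query, word="<mask>")]
    for sent, word in demos:
        parts.append(template.format(sent=sent, word=word))
    return " ".join(parts)
# Example (hypothetical):
#   _sketch_demo_prompt("A gripping movie.",
#                       [("A real joy to watch.", "great"),
#                        ("A complete mess.", "terrible")])
# -> "A gripping movie. It was <mask>. A real joy to watch. It was great. A complete mess. It was terrible."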
"""Automatic label search helpers."""
import itertools
import torch
import tqdm
import multiprocessing
import numpy as np
import scipy.spatial as spatial
import scipy.special as special
import scipy.stats as stats
import logging
logger = logging.getLogger(__name__)
def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False):
"""Pre-select likely words based on conditional likelihood."""
indices = []
if is_regression:
median = np.median(train_labels)
train_labels = (train_labels > median).astype(int)  # np.int is deprecated/removed in newer NumPy
num_labels = np.max(train_labels) + 1
for idx in range(num_labels):
label_logits = train_logits[train_labels == idx]
scores = label_logits.mean(axis=0)
kept = []
for i in np.argsort(-scores):
text = vocab[i]
if not text.startswith("Ġ"):
continue
kept.append(i)
indices.append(kept[:k_likely])
return indices
def select_neighbors(distances, k_neighbors, valid):
"""Select k nearest neighbors based on distance (filtered to be within the 'valid' set)."""
indices = np.argsort(distances)
neighbors = []
for i in indices:
if i not in valid:
continue
neighbors.append(i)
if k_neighbors > 0:
return neighbors[:k_neighbors]
return neighbors
def init(train_logits, train_labels):
global logits, labels
logits = train_logits
labels = train_labels
def eval_pairing_acc(pairing):
global logits, labels
label_logits = np.take(logits, pairing, axis=-1)
preds = np.argmax(label_logits, axis=-1)
correct = np.sum(preds == labels)
return correct / len(labels)
def eval_pairing_corr(pairing):
global logits, labels
if pairing[0] == pairing[1]:
return -1
label_logits = np.take(logits, pairing, axis=-1)
label_probs = special.softmax(label_logits, axis=-1)[:, 1]
pearson_corr = stats.pearsonr(label_probs, labels)[0]
return pearson_corr
def find_labels(
model,
train_logits,
train_labels,
seed_labels=None,
k_likely=1000,
k_neighbors=None,
top_n=-1,
vocab=None,
is_regression=False,
):
# Get top indices based on conditional likelihood using the LM.
likely_indices = select_likely_words(
train_logits=train_logits,
train_labels=train_labels,
k_likely=k_likely,
vocab=vocab,
is_regression=is_regression)
logger.info("Top labels (conditional) per class:")
for i, inds in enumerate(likely_indices):
logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]]))
# Convert to sets.
valid_indices = [set(inds) for inds in likely_indices]
# If specified, further re-rank according to nearest neighbors of seed labels.
# Otherwise, keep ranking as is (based on conditional likelihood only).
if seed_labels:
assert vocab is not None
seed_ids = [vocab.index(l) for l in seed_labels]
vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy()
seed_vecs = np.take(vocab_vecs, seed_ids, axis=0)
# [num_labels, vocab_size]
label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine")
# Establish label candidates (as k nearest neighbors).
label_candidates = []
logger.info("Re-ranked by nearest neighbors:")
for i, distances in enumerate(label_distances):
label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i]))
logger.info("\t| Label: %s", seed_labels[i])
logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]]))
else:
label_candidates = likely_indices
# Brute-force search all valid pairings.
pairings = list(itertools.product(*label_candidates))
if is_regression:
eval_pairing = eval_pairing_corr
metric = "corr"
else:
eval_pairing = eval_pairing_acc
metric = "acc"
# Score each pairing.
pairing_scores = []
with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers:
with tqdm.tqdm(total=len(pairings)) as pbar:
chunksize = max(10, int(len(pairings) / 1000))
for score in workers.imap(eval_pairing, pairings, chunksize=chunksize):
pairing_scores.append(score)
pbar.update()
# Take top-n.
best_idx = np.argsort(-np.array(pairing_scores))[:top_n]
best_scores = [pairing_scores[i] for i in best_idx]
best_pairings = [pairings[i] for i in best_idx]
logger.info("Automatically searched pairings:")
for i, indices in enumerate(best_pairings):
logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i])
return best_pairings
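# Hypothetical usage sketch (added for illustration, never called): run the
# label search on toy data. Shapes follow the functions above: train_logits is
# a [num_examples, vocab_size] array of mask-token logits and vocab maps token
# ids to strings; only tokens starting with "Ġ" (whole words with a leading
# space) are kept as label-word candidates.
def _example_find_labels():
    rng = np.random.RandomState(0)
    vocab = ["Ġgreat", "Ġgood", "Ġterrible", "Ġbad", "the", "ing"]
    train_logits = rng.randn(32, len(vocab))    # toy mask-token logits
    train_labels = rng.randint(0, 2, size=32)   # toy binary labels
    # With seed_labels=None, candidates are ranked by conditional likelihood
    # only, so no model is needed here.
    return find_labels(
        model=None,
        train_logits=train_logits,
        train_labels=train_labels,
        seed_labels=None,
        k_likely=2,
        top_n=3,
        vocab=vocab,
        is_regression=False,
    )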
"""Custom models for few-shot learning specific operations."""
import torch
import torch.nn as nn
import transformers
from transformers.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead
from transformers.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead
from transformers.modeling_outputs import SequenceClassifierOutput
import logging
logger = logging.getLogger(__name__)
def resize_token_type_embeddings(model, new_num_types: int, random_segment: bool):
"""
Resize the segment (token type) embeddings for BERT
"""
if hasattr(model, 'bert'):
old_token_type_embeddings = model.bert.embeddings.token_type_embeddings
else:
raise NotImplementedError
new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1))
if not random_segment:
new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data
model.config.type_vocab_size = new_num_types
if hasattr(model, 'bert'):
model.bert.embeddings.token_type_embeddings = new_token_type_embeddings
else:
raise NotImplementedError
class BertForPromptFinetuning(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = None
# For regression
self.lb = None
self.ub = None
# For label search.
self.return_full_softmax = None
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
mask_pos=None,
labels=None,
):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.cls(sequence_mask_output)
# Exit early and only return mask logits.
if self.return_full_softmax:
if labels is not None:
return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
return ((loss,) + output) if loss is not None else output
class RobertaForPromptFinetuning(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
self.lm_head = RobertaLMHead(config)
self.init_weights()
# These attributes should be assigned once the model is initialized
self.model_args = None
self.data_args = None
self.label_word_list = None
# For regression
self.lb = None
self.ub = None
# For auto label search.
self.return_full_softmax = None
def forward(
self,
input_ids=None,
attention_mask=None,
mask_pos=None,
labels=None,
):
batch_size = input_ids.size(0)
if mask_pos is not None:
mask_pos = mask_pos.squeeze()
# Encode everything
outputs = self.roberta(
input_ids,
attention_mask=attention_mask
)
# Get <mask> token representation
sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
# Logits over vocabulary tokens
prediction_mask_scores = self.lm_head(sequence_mask_output)
# Exit early and only return mask logits.
if self.return_full_softmax:
if labels is not None:
return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
return prediction_mask_scores
# Return logits for each label
logits = []
for label_id in range(len(self.label_word_list)):
logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
logits = torch.cat(logits, -1)
# Regression task
if self.config.num_labels == 1:
logsoftmax = nn.LogSoftmax(-1)
logits = logsoftmax(logits) # Log prob of right polarity
loss = None
if labels is not None:
if self.num_labels == 1:
# Regression task
loss_fct = nn.KLDivLoss(log_target=True)
labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
loss = loss_fct(logits.view(-1, 2), labels)
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
output = (logits,)
if self.num_labels == 1:
# Regression output
output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
return ((loss,) + output) if loss is not None else output
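# Worked note on the regression branch above (illustration only, not used by the
# code): with bounds (lb, ub) taken from bound_mapping (e.g. (0, 5) for STS-B),
# a gold score y is turned into a two-way soft target
#     ( 1 - (y - lb)/(ub - lb), (y - lb)/(ub - lb) ),
# i.e. the mass placed on the "low" and "high" label words, and the model's
# log-softmax over the two label-word logits is trained against it. At inference
# the score is read back as exp(log p_high) * (ub - lb) + lb. For example,
# y = 4.0 with (lb, ub) = (0, 5) yields the target (0.2, 0.8).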
"""Dataset utils for different data settings for GLUE."""
import os
import copy
import logging
import torch
import numpy as np
import time
from filelock import FileLock
import json
import itertools
import random
import transformers
from transformers.data.processors.utils import InputFeatures
from transformers import DataProcessor, InputExample
from transformers.data.processors.glue import *
from transformers.data.metrics import glue_compute_metrics
import dataclasses
from dataclasses import dataclass, asdict
from typing import List, Optional, Union
from sentence_transformers import SentenceTransformer, util
from copy import deepcopy
import pandas as pd
logger = logging.getLogger(__name__)
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[3]
text_b = line[4]
label = line[0]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["premise"].numpy().decode("utf-8"),
tensor_dict["hypothesis"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[8]
text_b = line[9]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class MnliMismatchedProcessor(MnliProcessor):
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
class SnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["premise"].numpy().decode("utf-8"),
tensor_dict["hypothesis"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[7]
text_b = line[8]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
test_mode = set_type == "test"
text_index = 3
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = line[text_index]
label = line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
text_index = 0
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[text_index]
label = line[1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
class StsbProcessor(DataProcessor):
"""Processor for the STS-B data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return [None]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[7]
text_b = line[8]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class QqpProcessor(DataProcessor):
"""Processor for the QQP data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["question1"].numpy().decode("utf-8"),
tensor_dict["question2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
test_mode = set_type == "test"
q1_index = 3
q2_index = 4
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
try:
text_a = line[q1_index]
text_b = line[q2_index]
label = line[5]
except IndexError:
continue
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class QnliProcessor(DataProcessor):
"""Processor for the QNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["question"].numpy().decode("utf-8"),
tensor_dict["sentence"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["entailment", "not_entailment"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version)."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence1"].numpy().decode("utf-8"),
tensor_dict["sentence2"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[1]
text_b = line[2]
label = line[-1]
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class TextClassificationProcessor(DataProcessor):
"""
Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa).
"""
def __init__(self, task_name):
self.task_name = task_name
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(
tensor_dict["idx"].numpy(),
tensor_dict["sentence"].numpy().decode("utf-8"),
None,
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"), header=None).values.tolist(), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"), header=None).values.tolist(), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"), header=None).values.tolist(), "test")
def get_labels(self):
"""See base class."""
if self.task_name == "mr":
return list(range(2))
elif self.task_name == "sst-5":
return list(range(5))
elif self.task_name == "subj":
return list(range(2))
elif self.task_name == "trec":
return list(range(6))
elif self.task_name == "cr":
return list(range(2))
elif self.task_name == "mpqa":
return list(range(2))
else:
raise Exception("task_name not supported.")
def _create_examples(self, lines, set_type):
"""Creates examples for the training, dev and test sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
if self.task_name == "ag_news":
examples.append(InputExample(guid=guid, text_a=line[1] + '. ' + line[2], short_text=line[1] + ".", label=line[0]))
elif self.task_name == "yelp_review_full":
examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0]))
elif self.task_name == "yahoo_answers":
text = line[1]
if not pd.isna(line[2]):
text += ' ' + line[2]
if not pd.isna(line[3]):
text += ' ' + line[3]
examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0]))
elif self.task_name in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
examples.append(InputExample(guid=guid, text_a=line[1], label=line[0]))
else:
raise Exception("Task_name not supported.")
return examples
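# Expected CSV layout for these text classification datasets (no header row):
# column 0 holds the integer label and column 1 the text, e.g. a line such as
#   1,"a deeply moving film"
# for a binary sentiment task like mr or cr (label ids follow get_labels above).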
def text_classification_metrics(task_name, preds, labels):
return {"acc": (preds == labels).mean()}
# Add your task to the following mappings
processors_mapping = {
"cola": ColaProcessor(),
"mnli": MnliProcessor(),
"mnli-mm": MnliMismatchedProcessor(),
"mrpc": MrpcProcessor(),
"sst-2": Sst2Processor(),
"sts-b": StsbProcessor(),
"qqp": QqpProcessor(),
"qnli": QnliProcessor(),
"rte": RteProcessor(),
"wnli": WnliProcessor(),
"snli": SnliProcessor(),
"mr": TextClassificationProcessor("mr"),
"sst-5": TextClassificationProcessor("sst-5"),
"subj": TextClassificationProcessor("subj"),
"trec": TextClassificationProcessor("trec"),
"cr": TextClassificationProcessor("cr"),
"mpqa": TextClassificationProcessor("mpqa")
}
num_labels_mapping = {
"cola": 2,
"mnli": 3,
"mrpc": 2,
"sst-2": 2,
"sts-b": 1,
"qqp": 2,
"qnli": 2,
"rte": 2,
"wnli": 2,
"snli": 3,
"mr": 2,
"sst-5": 5,
"subj": 2,
"trec": 6,
"cr": 2,
"mpqa": 2
}
output_modes_mapping = {
"cola": "classification",
"mnli": "classification",
"mnli-mm": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
"qqp": "classification",
"qnli": "classification",
"rte": "classification",
"wnli": "classification",
"snli": "classification",
"mr": "classification",
"sst-5": "classification",
"subj": "classification",
"trec": "classification",
"cr": "classification",
"mpqa": "classification"
}
# Return a function that takes (task_name, preds, labels) as inputs
compute_metrics_mapping = {
"cola": glue_compute_metrics,
"mnli": glue_compute_metrics,
"mnli-mm": glue_compute_metrics,
"mrpc": glue_compute_metrics,
"sst-2": glue_compute_metrics,
"sts-b": glue_compute_metrics,
"qqp": glue_compute_metrics,
"qnli": glue_compute_metrics,
"rte": glue_compute_metrics,
"wnli": glue_compute_metrics,
"snli": text_classification_metrics,
"mr": text_classification_metrics,
"sst-5": text_classification_metrics,
"subj": text_classification_metrics,
"trec": text_classification_metrics,
"cr": text_classification_metrics,
"mpqa": text_classification_metrics,
}
# For regression task only: median
median_mapping = {
"sts-b": 2.5
}
bound_mapping = {
"sts-b": (0, 5)
}
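# Sketch of registering a hypothetical new task (all names below are
# placeholders, not part of the original code). A new task needs an entry in
# each mapping above, and the chosen processor must also recognize the task
# name; regression tasks additionally need median_mapping / bound_mapping entries:
#
#   processors_mapping["my-task"] = TextClassificationProcessor("my-task")
#   num_labels_mapping["my-task"] = 2
#   output_modes_mapping["my-task"] = "classification"
#   compute_metrics_mapping["my-task"] = text_classification_metrics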
########## The following part is copied from Transformers' trainer (3.4.0) ##########
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""
import collections
import inspect
import math
import os
import re
import shutil
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
import transformers
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from transformers.integrations import (
default_hp_search_backend,
is_comet_available,
is_optuna_available,
is_ray_available,
is_tensorboard_available,
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
)
from transformers.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_pt_utils import (
DistributedTensorGatherer,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_tpu_sampler,
nested_concat,
nested_detach,
nested_numpify,
nested_xla_mesh_reduce,
reissue_pt_warnings,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
HPSearchBackend,
PredictionOutput,
TrainOutput,
default_compute_objective,
default_hp_space,
set_seed,
)
from transformers.training_args import TrainingArguments
from transformers.utils import logging
from tqdm import tqdm, trange
_use_native_amp = False
_use_apex = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from transformers.utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
from transformers.file_utils import is_apex_available
if is_apex_available():
from apex import amp
_use_apex = True
else:
_use_native_amp = True
from torch.cuda.amp import autocast
if version.parse(torch.__version__) < version.parse("1.2"):
_use_ddp_no_sync = False
else:
_use_ddp_no_sync = True
if is_datasets_available():
import datasets
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
if is_tensorboard_available():
from transformers.integrations import TensorBoardCallback
DEFAULT_CALLBACKS.append(TensorBoardCallback)
if is_wandb_available():
from transformers.integrations import WandbCallback
DEFAULT_CALLBACKS.append(WandbCallback)
if is_comet_available():
from transformers.integrations import CometCallback
DEFAULT_CALLBACKS.append(CometCallback)
if is_optuna_available():
import optuna
if is_ray_available():
from ray import tune
logger = logging.get_logger(__name__)
########## The above part is copied from Transformers' trainer (3.4.0) ##########
def default_dev_objective(metrics):
"""
Objective used for picking the best model on development sets
"""
if "eval_mnli/acc" in metrics:
return metrics["eval_mnli/acc"]
elif "eval_mnli-mm/acc" in metrics:
return metrics["eval_mnli-mm/acc"]
elif "eval_f1" in metrics:
return metrics["eval_f1"]
elif "eval_mcc" in metrics:
return metrics["eval_mcc"]
elif "eval_pearson" in metrics:
return metrics["eval_pearson"]
elif "eval_acc" in metrics:
return metrics["eval_acc"]
raise Exception("No metric founded for {}".format(metrics))
class Trainer(transformers.Trainer):
"""
Adding some functions based on Transformers' Trainer class.
"""
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Based on Transformers' default implementation, with an added layer-fixing option: when fix_layers > 0,
the embeddings and the bottom fix_layers encoder layers are frozen and only the higher layers are fine-tuned
(e.g., fix_layers=6 freezes the embeddings and encoder layers 0-5).
"""
if self.optimizer is None:
params = {}
for n, p in self.model.named_parameters():
if self.args.fix_layers > 0:
if 'encoder.layer' in n:
try:
layer_num = int(n[n.find('encoder.layer') + 14:].split('.')[0])
except ValueError:
raise ValueError("Cannot parse the layer number from parameter name: {}".format(n))
if layer_num >= self.args.fix_layers:
print('yes', n)
params[n] = p
else:
print('no ', n)
elif 'embeddings' in n:
print('no ', n)
else:
print('yes', n)
params[n] = p
else:
params[n] = p
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
self.optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.args.learning_rate,
betas=(self.args.adam_beta1, self.args.adam_beta2),
eps=self.args.adam_epsilon,
)
if self.lr_scheduler is None:
self.lr_scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
def train(self, model_path=None, dev_objective=None):
"""
Main training entry point.
The training logic is largely borrowed from transformers.Trainer (version 3.0.2).
We additionally evaluate on the dev set during training and keep the checkpoint with the best dev objective.
"""
self.best_dir = None
self.objective = -float("inf")
self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective
# Data loading.
train_dataloader = self.get_train_dataloader()
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
if num_update_steps_per_epoch == 0:
num_update_steps_per_epoch = 1
if self.args.max_steps > 0:
t_total = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
else:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
self.create_optimizer_and_scheduler(num_training_steps=t_total)
optimizer = self.optimizer
scheduler = self.lr_scheduler
# Check if saved optimizer or scheduler states exist
if (
model_path is not None
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
):
# Load in optimizer and scheduler states
optimizer.load_state_dict(
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
)
scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
model = self.model
if self.args.fp16 and _use_apex:
if not transformers.is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
# Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if self.args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank],
output_device=self.args.local_rank,
find_unused_parameters=True,
)
# Train
if transformers.is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
else:
total_train_batch_size = (
self.args.train_batch_size
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
self.global_step = 0
self.epoch = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split("/")[0])
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
logger.info(" Starting fine-tuning.")
tr_loss = torch.tensor(0.0).to(self.args.device)
logging_loss_scalar = 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
)
for epoch in train_iterator:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
if transformers.is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
else:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
tr_loss += self.training_step(model, inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
len(epoch_iterator) <= self.args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
):
if self.args.fp16 and _use_native_amp:
self.scaler.unscale_(optimizer)
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.args.fp16:
norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
else:
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
if transformers.is_torch_tpu_available():
xm.optimizer_step(optimizer)
elif self.args.fp16 and _use_native_amp:
self.scaler.step(optimizer)
self.scaler.update()
else:
optimizer.step()
scheduler.step()
model.zero_grad()
self.global_step += 1
self.epoch = epoch + (step + 1) / len(epoch_iterator)
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
self.global_step == 1 and self.args.logging_first_step
):
logs = {}
tr_loss_scalar = tr_loss.item()
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
logs["norm"] = norm.item()
# backward compatibility for pytorch schedulers
logs["learning_rate"] = (
scheduler.get_last_lr()[0]
if version.parse(torch.__version__) >= version.parse("1.4")
else scheduler.get_lr()[0]
)
logging_loss_scalar = tr_loss_scalar
self.log(logs)
# ----------------------------------------------------------------------
# BEGIN CHANGES.
# ----------------------------------------------------------------------
metrics = None
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
output = self.evaluate()
metrics = output.metrics
objective = self.dev_objective(metrics)
if objective > self.objective:
logger.info("Best dev result: {}".format(objective))
self.objective = objective
self.save_model(self.args.output_dir)
# ----------------------------------------------------------------------
# END CHANGES.
# ----------------------------------------------------------------------
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
epoch_iterator.close()
break
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
train_iterator.close()
break
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
return TrainOutput(self.global_step, tr_loss / self.global_step), self.objective
"""
Difference compared to original implementation: return output instead of output.metrics (so there is also the logits)
"""
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> PredictionOutput:
"""
Run evaluation and return the full prediction output (not just the metrics).
The calling script will be responsible for providing a method to compute metrics, as they are
task-dependent (pass it to the init :obj:`compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
the :obj:`__len__` method.
Returns:
The full prediction output (predictions, label ids, and metrics), rather than just a metrics dictionary.
"""
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")
eval_dataloader = self.get_eval_dataloader(eval_dataset)
output = self.prediction_loop(eval_dataloader, description="Evaluation")
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
return output
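# Usage note for this customized Trainer: train() returns a
# (TrainOutput, best_dev_objective) tuple and saves the best checkpoint to
# args.output_dir whenever the dev objective improves, and evaluate() returns
# the full PredictionOutput (including logits) rather than only the metrics.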
import argparse
import pandas as pd
import json
import numpy as np
import torch
import os
from torch import device
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments, glue_compute_metrics
from transformers.data.metrics import simple_accuracy
from transformers.data.processors.glue import glue_processors
def get_glue_label(task, line):
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
line = line.strip().split('\t')
if task == 'CoLA':
return line[1]
elif task == 'MNLI':
return line[-1]
elif task == 'MRPC':
return line[0]
elif task == 'QNLI':
return line[-1]
elif task == 'QQP':
return line[-1]
elif task == 'RTE':
return line[-1]
elif task == 'SNLI':
return line[-1]
elif task == 'SST-2':
return line[-1]
elif task == 'STS-B':
return 0 if float(line[-1]) < 2.5 else 1
elif task == 'WNLI':
return line[-1]
else:
raise NotImplementedError
else:
raise NotImplementedError
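# Note: for STS-B the gold score is binarized at 2.5 (the scale midpoint, as in
# median_mapping), so the ensemble metrics below are computed against this 0/1
# proxy rather than the raw regression score.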
def get_labels(data_dir, k, seed, task, print_name):
if print_name in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']:
data = pd.read_csv(os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.csv'), header=None).values.tolist()
labels = np.zeros((len(data)), dtype=np.uint8)
for i, example in enumerate(data):
labels[i] = example[0]
return labels
elif print_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
lines = []
file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.tsv')
if task == 'mnli':
file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_matched.tsv')
elif task == 'mnli-mm':
file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_mismatched.tsv')
for line in open(file_name):
lines.append(line.strip())
if task != 'cola':
lines = lines[1:]
label_list = glue_processors[task]().get_labels()
label_map = {k: i for i, k in enumerate(label_list)}
if task == 'sts-b':
label_map = {0: 0, 1: 1}
label_ids = np.zeros((len(lines)))
for line_id, line in enumerate(lines):
label_ids[line_id] = label_map[get_glue_label(print_name, line)]
return label_ids
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--n_models", type=int, help="Number of models")
parser.add_argument("--k", type=int, default=16, help="Number of training instances per label")
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
# These options should usually be kept as their default values
parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Data directory")
parser.add_argument("--save_logit_dir", type=str, default="ensemble_predict_results", help="Directory to store the logit file.")
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
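# The condition string is evaluated as a Python dict literal, e.g. (values are hypothetical):
#   --condition "{'tag': 'exp', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"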
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
elif condition['task_name'] == 'mnli-mm':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_best = {}
# Gather all logs satisfying the conditions
for item in result_list:
ok = True
for cond in condition:
if cond == 'task_name' and condition['task_name'] == 'mnli-mm':
if cond not in item or item[cond] != 'mnli':
ok = False
break
else:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if 'model_id' not in item or 'array_id' not in item:
ok = False
if ok:
seed = int(item['data_dir'].split('-')[-1])
model_id = item['model_id']
array_id = item['array_id']
if model_id >= 0 and model_id < args.n_models:
if seed not in seed_result:
seed_result[seed] = {}
seed_best[seed] = {}
if model_id not in seed_result[seed]:
seed_result[seed][model_id] = []
seed_best[seed][model_id] = {args.key: -1e9}
seed_result[seed][model_id].append(item)
if item[args.key] > seed_best[seed][model_id][args.key]:
seed_best[seed][model_id] = item
final_result_dev = np.zeros((len(seed_result), args.n_models))
final_result_test = np.zeros((len(seed_result), args.n_models))
final_result_test2 = np.zeros((len(seed_result), args.n_models))
logit_file_list = {}
for seed in seed_result:
logit_file_list[seed] = []
# Get the results for each model and pick the best dev trial for each model/seed
for model_id in range(args.n_models):
for i, seed in enumerate(seed_result):
final_result_dev[i][model_id] = seed_best[seed][model_id][args.key]
final_result_test[i][model_id] = seed_best[seed][model_id][args.test_key]
if len(args.test_key2) > 0:
final_result_test2[i][model_id] = seed_best[seed][model_id][args.test_key2]
logit_file_list[seed].append("{}-{}-{}.npy".format(condition['task_name'], model_id, seed_best[seed][model_id]["array_id"]))
s = "Model %d | val: mean +- std: %.1f +- %.1f | test: mean +- std: %.1f (%.1f) (median %.1f)" % (model_id, final_result_dev[:, model_id].mean() * 100, final_result_dev[:, model_id].std() * 100, final_result_test[:, model_id].mean() * 100, final_result_test[:, model_id].std() * 100, np.median(final_result_test[:, model_id]) * 100)
if len(args.test_key2) > 0:
s += " / %.1f +- %.1f (median %.1f)" % (final_result_test2[:, model_id].mean() * 100, final_result_test2[:, model_id].std() * 100, np.median(final_result_test2[:, model_id]) * 100)
print(s)
# Map lower-case names to official names (data folder name)
data_dir_mapping = {
'cola': 'CoLA',
'mrpc': 'MRPC',
'qqp': 'QQP',
'sts-b': 'STS-B',
'sst-2': 'SST-2',
'snli': 'SNLI',
'mnli': 'MNLI',
'mnli-mm': 'MNLI',
'rte': 'RTE',
'ag_news': 'ag_news',
'yahoo_answers': 'yahoo_answers',
'yelp_review_full': 'yelp_review_full',
'sst-5': 'sst-5',
'mr': 'mr',
'cr': 'cr',
'mpqa': 'mpqa',
'subj': 'subj',
'trec': 'trec'
}
tokenizer = AutoTokenizer.from_pretrained('roberta-large')
ensemble_result = np.zeros((len(seed_result)))
ensemble_result2 = np.zeros((len(seed_result))) # for second metric
# Ensemble for each seed
for seed_id, seed in enumerate(seed_result):
labels = get_labels(args.data_dir, args.k, seed, condition['task_name'], data_dir_mapping[condition['task_name']])
# Logits
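# Average the saved test-set logits of each ensemble member's best trial for this seed.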
mean_logits = None
for fname in logit_file_list[seed]:
logits = np.load(os.path.join(args.save_logit_dir, fname))
if mean_logits is None:
mean_logits = logits
else:
mean_logits += logits
mean_logits /= len(logit_file_list[seed])
# Compute metrics
preds = mean_logits.argmax(-1)
if condition['task_name'] in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']:
metric = {"acc": simple_accuracy(preds, labels)}
else:
metric = glue_compute_metrics(condition['task_name'], preds, labels)
ensemble_result[seed_id] = metric[args.test_key.split('_')[-1]]
if len(args.test_key2) > 0:
ensemble_result2[seed_id] = metric[args.test_key2.split('_')[-1]]
s = "mean +- std: %.1f (%.1f) (median %.1f)" % (ensemble_result.mean() * 100, ensemble_result.std() * 100, np.median(ensemble_result) * 100)
if len(args.test_key2) > 0:
s += " / %.1f (%.1f) (median %.1f)" % (ensemble_result2.mean() * 100, ensemble_result2.std() * 100, np.median(ensemble_result2) * 100)
print(s)
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
# These options should be kept as their default values
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
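# --condition is a Python dict literal, e.g. "{'tag': 'my-tag', 'task_name': 'sst-2'}" (values here are illustrative); every logged result must match all of its entries.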
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_best = {}
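# Group results by few-shot split (data-split seed from data_dir plus training seed) and keep, per group, the trial with the best validation metric (args.key).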
for item in result_list:
ok = True
for cond in condition:
if isinstance(condition[cond], list):
if cond not in item or (item[cond] not in condition[cond]):
ok = False
break
else:
if cond not in item or (item[cond] != condition[cond]):
ok = False
break
if ok:
seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed'])
if seed not in seed_result:
seed_result[seed] = [item]
seed_best[seed] = item
else:
seed_result[seed].append(item)
if item[args.key] > seed_best[seed][args.key]:
seed_best[seed] = item
final_result_dev = np.zeros((len(seed_best)))
final_result_test = np.zeros((len(seed_best)))
final_result_test2 = np.zeros((len(seed_best)))
for i, seed in enumerate(seed_best):
final_result_dev[i] = seed_best[seed][args.key]
final_result_test[i] = seed_best[seed][args.test_key]
if len(args.test_key2) > 0:
final_result_test2[i] = seed_best[seed][args.test_key2]
print("%s: best dev (%.5f) test (%.5f) %s | total trials: %d" % (
seed,
seed_best[seed][args.key],
seed_best[seed][args.test_key],
"test2 (%.5f)" % (seed_best[seed][args.test_key2]) if len(args.test_key2) > 0 else "",
len(seed_result[seed])
))
s = ''
for k in ['per_device_train_batch_size', 'gradient_accumulation_steps', 'learning_rate', 'eval_steps', 'max_steps']:
s += '| {}: {} '.format(k, seed_best[seed][k])
print(' ' + s)
s = "mean +- std: %.1f (%.1f) (median %.1f)" % (final_result_test.mean() * 100, final_result_test.std() * 100, np.median(final_result_test) * 100)
if len(args.test_key2) > 0:
s += "second metric: %.1f (%.1f) (median %.1f)" % (final_result_test2.mean() * 100, final_result_test2.std() * 100, np.median(final_result_test2) * 100)
print(s)
if __name__ == '__main__':
main()
"""This script samples K examples randomly without replacement from the original data."""
import argparse
import os
import numpy as np
import pandas as pd
from pandas import DataFrame
def get_label(task, line):
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
line = line.strip().split('\t')
if task == 'CoLA':
return line[1]
elif task == 'MNLI':
return line[-1]
elif task == 'MRPC':
return line[0]
elif task == 'QNLI':
return line[-1]
elif task == 'QQP':
return line[-1]
elif task == 'RTE':
return line[-1]
elif task == 'SNLI':
return line[-1]
elif task == 'SST-2':
return line[-1]
elif task == 'STS-B':
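# STS-B is a regression task; binarize the score at 2.5 so examples can still be grouped by label for balanced sampling.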
return 0 if float(line[-1]) < 2.5 else 1
elif task == 'WNLI':
return line[-1]
else:
raise NotImplementedError
else:
return line[0]
def load_datasets(data_dir, tasks):
datasets = {}
for task in tasks:
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style (tsv)
dataset = {}
dirname = os.path.join(data_dir, task)
if task == "MNLI":
splits = ["train", "dev_matched", "dev_mismatched"]
else:
splits = ["train", "dev"]
for split in splits:
filename = os.path.join(dirname, f"{split}.tsv")
with open(filename, "r") as f:
lines = f.readlines()
dataset[split] = lines
datasets[task] = dataset
else:
# Other datasets (csv)
dataset = {}
dirname = os.path.join(data_dir, task)
splits = ["train", "test"]
for split in splits:
filename = os.path.join(dirname, f"{split}.csv")
dataset[split] = pd.read_csv(filename, header=None)
datasets[task] = dataset
return datasets
def split_header(task, lines):
"""
Splits the header line (if any) from the content lines. Only for GLUE tasks.
"""
if task in ["CoLA"]:
return [], lines
elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
return lines[0:1], lines[1:]
else:
raise ValueError("Unknown GLUE task.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--k", type=int, default=16,
help="Training examples for each class.")
parser.add_argument("--task", type=str, nargs="+",
default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
help="Task names")
parser.add_argument("--seed", type=int, nargs="+",
default=[100, 13, 21, 42, 87],
help="Random seeds")
parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data")
parser.add_argument("--output_dir", type=str, default="data", help="Output path")
parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)")
args = parser.parse_args()
args.output_dir = os.path.join(args.output_dir, args.mode)
k = args.k
print("K =", k)
datasets = load_datasets(args.data_dir, args.task)
for seed in args.seed:
print("Seed = %d" % (seed))
for task, dataset in datasets.items():
# Set random seed
np.random.seed(seed)
# Shuffle the training set
print("| Task = %s" % (task))
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
train_header, train_lines = split_header(task, dataset["train"])
np.random.shuffle(train_lines)
else:
# Other datasets
train_lines = dataset['train'].values.tolist()
np.random.shuffle(train_lines)
# Set up dir
task_dir = os.path.join(args.output_dir, task)
setting_dir = os.path.join(task_dir, f"{k}-{seed}")
os.makedirs(setting_dir, exist_ok=True)
# Write test splits
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
# GLUE style
# Use the original development set as the test set (the original test sets are not publicly available)
for split, lines in dataset.items():
if split.startswith("train"):
continue
split = split.replace('dev', 'test')
with open(os.path.join(setting_dir, f"{split}.tsv"), "w") as f:
for line in lines:
f.write(line)
else:
# Other datasets
# Use the original test sets
dataset['test'].to_csv(os.path.join(setting_dir, 'test.csv'), header=False, index=False)
# Get label list for balanced sampling
label_list = {}
for line in train_lines:
label = get_label(task, line)
if label not in label_list:
label_list[label] = [line]
else:
label_list[label].append(line)
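# Per class: the first k shuffled examples form the few-shot train set; the next k (or 10*k in k-shot-10x mode) form the few-shot dev set.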
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
with open(os.path.join(setting_dir, "train.tsv"), "w") as f:
for line in train_header:
f.write(line)
for label in label_list:
for line in label_list[label][:k]:
f.write(line)
name = "dev.tsv"
if task == 'MNLI':
name = "dev_matched.tsv"
with open(os.path.join(setting_dir, name), "w") as f:
for line in train_header:
f.write(line)
for label in label_list:
dev_rate = 11 if '10x' in args.mode else 2
for line in label_list[label][k:k*dev_rate]:
f.write(line)
else:
new_train = []
for label in label_list:
for line in label_list[label][:k]:
new_train.append(line)
new_train = DataFrame(new_train)
new_train.to_csv(os.path.join(setting_dir, 'train.csv'), header=False, index=False)
new_dev = []
for label in label_list:
dev_rate = 11 if '10x' in args.mode else 2
for line in label_list[label][k:k*dev_rate]:
new_dev.append(line)
new_dev = DataFrame(new_dev)
new_dev.to_csv(os.path.join(setting_dir, 'dev.csv'), header=False, index=False)
if __name__ == "__main__":
main()
"""Finetuning the library models for sequence classification on GLUE."""
import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import logging
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, TrainingArguments, set_seed
from src.label_search import find_labels
from src.dataset import FewShotDataset
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.trainer import Trainer
from src.processors import output_modes_mapping, num_labels_mapping
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
random_segment: bool = field(
default=False, metadata={"help": "Whether to randomly reinitialize the extra token type (segment) embeddings (BERT only); used by resize_token_type_embeddings below"}
)
@dataclass
class DynamicDataTrainingArguments(DataTrainingArguments):
"""
Arguments for dynamic training.
"""
# For prompting
template: str = field(
default=None,
metadata={"help": "Template"}
)
mapping: str = field(
default=None,
metadata={"help": "Label word mapping"}
)
debug_mode: bool = field(
default=False,
metadata={"help": "Debug mode"}
)
first_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
)
other_sent_limit: int = field(
default=None,
metadata={"help": "Limit the length of sentences other than the first sentence"}
)
use_full_length: bool = field(
default=None,
metadata={"help": "Use the full length (512)"}
)
truncate_head: bool = field(
default=False,
metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
)
use_space_word: bool = field(
default=True,
metadata={"help": "Use space words (e.g., Gpositive) instead of original words."}
)
use_seed_labels: bool = field(
default=False,
metadata={"help": "Regularize using seed labels"},
)
k_likely: int = field(
default=100,
metadata={"help": "Take the top-k most (conditionally) likely labels per class."}
)
k_neighbors: int = field(
default=50,
metadata={"help": "Re-rank by nearest neighbor, and take the top k."}
)
n_pairs: int = field(
default=32,
metadata={"help": "Number of label pairings to use."}
)
output_file: str = field(
default="out",
metadata={"help": "Output file"}
)
append_output_file: bool = field(
default=False,
)
write_template: bool = field(
default=False,
)
def main():
parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Fix prompt to be true.
data_args.prompt = True
data_args.num_sample = 1
data_args.template_list = None
data_args.gpt3_in_context_head = False
data_args.gpt3_in_context_tail = False
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
# Check save path
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Set seed
set_seed(training_args.seed)
try:
num_labels = num_labels_mapping[data_args.task_name]
output_mode = output_modes_mapping[data_args.task_name]
logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))
except KeyError:
raise ValueError("Task not found: %s" % (data_args.task_name))
# Create config
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
)
if config.model_type == 'roberta':
model_fn = RobertaForPromptFinetuning
elif config.model_type == 'bert':
model_fn = BertForPromptFinetuning
else:
raise NotImplementedError
special_tokens = []
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
additional_special_tokens=special_tokens,
cache_dir=model_args.cache_dir,
)
set_seed(training_args.seed)
model = model_fn.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
# For BERT, increase the size of the segment (token type) embeddings
if config.model_type == 'bert':
model.resize_token_embeddings(len(tokenizer))
resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
# Pass dataset and argument information to the model
model.model_args = model_args
model.data_args = data_args
model.tokenizer = tokenizer
model.return_full_softmax = True
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=None,
eval_dataset=None,
)
# First we compute zero-shot logits on all of the examples.
dataset = FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=False)
# Predict logits.
dataloader = trainer.get_eval_dataloader(dataset)
output = trainer.prediction_loop(dataloader, description="Evaluation")
logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions
labels = output.label_ids
# Assign words to labels.
if data_args.use_seed_labels:
if data_args.use_space_word:
seed_labels = {k: "Ġ" + v for k, v in eval(data_args.mapping).items()}
else:
seed_labels = eval(data_args.mapping)
seed_labels = [seed_labels[label] for label in dataset.get_labels()]
else:
seed_labels = None
vocab = list(tokenizer.get_vocab())
# Find best labels.
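# find_labels takes the k_likely most probable label words per class from the zero-shot logits, re-ranks candidates via nearest neighbors (k_neighbors), and returns the top n_pairs class-to-word assignments (see src/label_search.py).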
label_pairings = find_labels(
model=trainer.model,
train_logits=logits,
train_labels=labels,
seed_labels=seed_labels,
k_likely=data_args.k_likely,
k_neighbors=data_args.k_neighbors,
top_n=data_args.n_pairs,
vocab=vocab,
is_regression=config.num_labels == 1)
labels = dataset.get_labels()
if config.num_labels == 1:
labels = ['0', '1']
output_dirname = os.path.dirname(data_args.output_file)
if output_dirname:
os.makedirs(output_dirname, exist_ok=True)
if data_args.append_output_file:
mode = "a"
else:
mode = "w"
# Write to output.
with open(data_args.output_file, mode) as f:
for pairing in label_pairings:
words = [vocab[i][len("Ġ"):] for i in pairing]
mapping = {labels[i]: words[i] for i in range(len(labels))}
if data_args.write_template:
f.write(data_args.template + "\t")
f.write(json.dumps(mapping) + "\n")
if __name__ == "__main__":
main()
import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer
import argparse
import torch
import os
from tqdm import tqdm
import json
import pandas as pd
def get_text(template, input_text_tuple, label, tokenizer, mapping):
def enc(text):
return tokenizer.encode(text, add_special_tokens=False)
special_token_mapping = {'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id}
for i in range(10):
special_token_mapping["<extra_id_%d>" % (i)] = tokenizer._convert_token_to_id("<extra_id_%d>" % (i))
template_list = template.split('*')
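# A template is a '*'-separated list of commands, e.g. *cls**sent_0**<extra_id_0>**label**<extra_id_1>**sep+*; each command is converted to token ids below.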
input_ids = []
for part in template_list:
new_tokens = []
if part in special_token_mapping:
if part == 'cls' and 'T5' in type(tokenizer).__name__:
# T5 does not have cls token
continue
new_tokens.append(special_token_mapping[part])
elif part[:5] == 'label':
new_tokens += enc(' ' + mapping[label])
elif part[:5] == 'sent_':
sent_id = int(part.split('_')[1])
new_tokens += enc(input_text_tuple[sent_id])
elif part[:6] == '+sent_':
sent_id = int(part.split('_')[1])
new_tokens += enc(' ' + input_text_tuple[sent_id]) # add space
elif part[:6] == 'sent-_':
# Drop the last character (e.g., the trailing punctuation)
sent_id = int(part.split('_')[1])
new_tokens += enc(input_text_tuple[sent_id][:-1])
elif part[:7] == '+sentl_':
# Lower case the first token
sent_id = int(part.split('_')[1])
text = input_text_tuple[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(' ' + text)
elif part[:7] == '+sentu_':
# Upper case the first token
sent_id = int(part.split('_')[1])
text = input_text_tuple[sent_id]
text = text[:1].upper() + text[1:]
new_tokens += enc(' ' + text)
elif part[:6] == 'sentl_':
# Lower case the first token
sent_id = int(part.split('_')[1])
text = input_text_tuple[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(text)
elif part[:6] == 'sentu_':
# Upper case the first token
sent_id = int(part.split('_')[1])
text = input_text_tuple[sent_id]
text = text[:1].upper() + text[1:]
new_tokens += enc(text)
elif part[:7] == 'sentl-_':
# Lower case the first token and drop the last character
sent_id = int(part.split('_')[1])
text = input_text_tuple[sent_id]
text = text[:1].lower() + text[1:]
new_tokens += enc(text[:-1])
else:
part = part.replace('_', ' ') # spaces are not allowed inside a template command, so '_' stands in for a space
# Special case: the T5 tokenizer may prepend an extra space when encoding a single character
if len(part) == 1:
new_tokens.append(tokenizer._convert_token_to_id(part))
else:
new_tokens += enc(part)
input_ids += new_tokens
return input_ids
def generate(dataset, template, model, tokenizer, target_number, mapping, beam, label=None, length_limit=None, truncate=None):
"""
Generate templates based on given inputs
label: Only use instances with this label (deprecated)
length_limit: Minimum length to generate for each part (deprecated)
"""
input_texts = []
input_tensors = []
max_length = 0
# Process the inputs
for item in dataset:
if label is None or item['label'] == label:
input_text = get_text(template, item['text'], item['label'], tokenizer, mapping)
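# Keep at most 256 input tokens: 'head' drops tokens from the front (keeps the last 256), 'tail' drops tokens from the end (keeps the first 256).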
if truncate is not None:
if truncate == 'head':
input_text = input_text[-256:]
elif truncate == 'tail':
input_text = input_text[:256]
else:
raise NotImplementedError
input_ids = torch.tensor(input_text).long()
max_length = max(max_length, input_ids.size(-1))
input_tensors.append(input_ids)
# Concatenate inputs as a batch
input_ids = torch.zeros((len(input_tensors), max_length)).long()
attention_mask = torch.zeros((len(input_tensors), max_length)).long()
for i in range(len(input_tensors)):
input_ids[i, :input_tensors[i].size(-1)] = input_tensors[i]
attention_mask[i, :input_tensors[i].size(-1)] = 1
# Print some examples
print('####### example #######')
print(tokenizer.decode(input_ids[0]))
print(tokenizer.decode(input_ids[1]))
print(tokenizer.decode(input_ids[2]))
print('####### example #######\n')
#input_ids = input_ids.cuda()
#attention_mask = attention_mask.cuda()
assert len(input_tensors) > 0
# Maximum generate content length
max_length = 20
start_mask = tokenizer._convert_token_to_id('<extra_id_0>')
ori_decoder_input_ids = torch.zeros((input_ids.size(0), max_length)).long()
ori_decoder_input_ids[..., 0] = model.config.decoder_start_token_id
# decoder_input_ids: decoder inputs for next regressive generation
# ll: log likelihood
# output_id: which part of generated contents we are at
# output: generated content so far
# last_length (deprecated): how long we have generated for this part
current_output = [{'decoder_input_ids': ori_decoder_input_ids, 'll': 0, 'output_id': 1, 'output': [], 'last_length': -1}]
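# Beam search: at each step, score the next decoder token by averaging the T5 output distribution over all (mini-batched) inputs, then keep the top-'beam' partial templates by accumulated log likelihood.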
for i in tqdm(range(max_length - 2)):
new_current_output = []
for item in current_output:
if item['output_id'] > target_number:
# Enough contents
new_current_output.append(item)
continue
decoder_input_ids = item['decoder_input_ids']
# Forward
batch_size = 10
turn = input_ids.size(0) // batch_size
if input_ids.size(0) % batch_size != 0:
turn += 1
aggr_output = []
for t in range(turn):
start = t * batch_size
end = min((t + 1) * batch_size, input_ids.size(0))
with torch.no_grad():
#aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids.cuda()[start:end])[0])
aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids[start:end])[0])
aggr_output = torch.cat(aggr_output, 0)
# Gather results across all input sentences, and sort generated tokens by log likelihood
aggr_output = aggr_output.mean(0)
log_denominator = torch.logsumexp(aggr_output[i], -1).item()
ids = list(range(model.config.vocab_size))
ids.sort(key=lambda x: aggr_output[i][x].item(), reverse=True)
ids = ids[:beam+3]
for word_id in ids:
output_id = item['output_id']
if word_id == start_mask - output_id or word_id == tokenizer._convert_token_to_id('</s>'):
# Finish one part
if length_limit is not None and item['last_length'] < length_limit[output_id - 1]:
check = False
else:
check = True
output_id += 1
last_length = 0
else:
last_length = item['last_length'] + 1
check = True
output_text = item['output'] + [word_id]
ll = item['ll'] + aggr_output[i][word_id] - log_denominator
new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.size())
new_decoder_input_ids[:] = decoder_input_ids
new_decoder_input_ids[..., i + 1] = word_id
# Forbid single space token, "....", and ".........."
if word_id in [3, 19794, 22354]:
check = False
# Forbid continuous "."
if len(output_text) > 1 and output_text[-2] == 5 and output_text[-1] == 5:
check = False
if check:
# Add new results to beam search pool
new_item = {'decoder_input_ids': new_decoder_input_ids, 'll': ll, 'output_id': output_id, 'output': output_text, 'last_length': last_length}
new_current_output.append(new_item)
if len(new_current_output) == 0:
break
new_current_output.sort(key=lambda x: x['ll'], reverse=True)
new_current_output = new_current_output[:beam]
current_output = new_current_output
result = []
print("####### generated results #######")
for item in current_output:
generate_text = ''
for token in item['output']:
generate_text += tokenizer._convert_id_to_token(token)
print('--------------')
print('score:', item['ll'].item())
print('generated ids', item['output'])
print('generated text', generate_text)
result.append(generate_text)
print("####### generated results #######\n")
return result
def load_dataset(task, data_dir):
if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
lines = open(os.path.join(data_dir, 'train.tsv')).readlines()
if task != 'CoLA':
lines = lines[1:]
dataset = []
for line in lines:
line = line.strip().split('\t')
if task == 'CoLA':
dataset.append({'label': line[1], 'text': [line[-1]]})
elif task == 'MNLI':
dataset.append({'label': line[-1], 'text': [line[8], line[9]]})
elif task == 'MRPC':
dataset.append({'label': line[0], 'text': [line[-2], line[-1]]})
elif task == 'QNLI':
dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
elif task == 'QQP':
dataset.append({'label': line[-1], 'text': [line[3], line[4]]})
elif task == 'RTE':
dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
elif task == 'SNLI':
dataset.append({'label': line[-1], 'text': [line[7], line[8]]})
elif task == 'SST-2':
dataset.append({'label': line[-1], 'text': [line[0]]})
elif task == 'STS-B':
dataset.append({'label': '0' if float(line[-1]) < 2.5 else '1', 'text': [line[-3], line[-2]]})
elif task == 'WNLI':
dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
else:
raise NotImplementedError
else:
lines = pd.read_csv(os.path.join(data_dir, 'train.csv'), header=None).values.tolist()
dataset = []
for line in lines:
dataset.append({'label': line[0], 'text': [line[1]]})
return dataset
def search_template(model, tokenizer, task_name, k, seed, beam, output_dir, data_dir):
print('#', task_name, k, seed, beam)
dataset_path = os.path.join(data_dir, task_name, "{}-{}".format(k, seed))
dataset = load_dataset(task_name, dataset_path)
print('|', 'dataset examples')
print('|', dataset[0])
print('|', dataset[-1])
print()
# Manual label word mappings
map_of_mapping = {
'SST-2': {'0':'terrible','1':'great'},
'sst-5': {0:'terrible',1:'bad',2:'okay',3:'good',4:'great'},
'mr': {0:'terrible',1:'great'},
'cr': {0:'terrible',1:'great'},
'subj': {0:'subjective',1:'objective'},
'trec': {0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'},
'mpqa': {0:'terrible',1:'great'},
'CoLA': {'0':'incorrect','1':'correct'},
'MRPC': {'0':'No','1':'Yes'},
'QQP': {'0':'No','1':'Yes'},
'STS-B': {'0':'No','1':'Yes'},
'MNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'},
'SNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'},
'QNLI': {'not_entailment':'No','entailment':'Yes'},
'RTE': {'not_entailment':'No','entailment':'Yes'}
}
mapping = map_of_mapping[task_name]
print('|', 'mapping')
print('|', mapping)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, task_name), exist_ok=True)
f = open(os.path.join(output_dir, task_name, "{}-{}.txt".format(k, seed)), 'w')
if task_name in ['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'mpqa']:
# Single sentence tasks
# We take two kinds of templates: put [MASK] at the beginning or the end
template = "*cls**sentu_0**<extra_id_0>**label**<extra_id_1>**sep+*"
generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='head')[:beam//2]
print("####### generated templates #######")
for text in generate_text:
# Transform T5 outputs to our template format
text = text.replace('<extra_id_0>', '*cls**sent_0*')
text = text.replace('<extra_id_1>', '*mask*')
text = text.replace('<extra_id_2>', '*sep+*')
text = text.replace('</s>', '*sep+*')
text = text.replace('▁', '_')
print(text)
f.write(text + '\n')
print("####### generated templates #######\n")
template = "*cls*.*<extra_id_0>**label**<extra_id_1>**+sentu_0**sep+*"
generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='tail')[:beam//2]
print("####### generated templates #######")
for text in generate_text:
# Transform T5 outputs to our template format
text = text.replace('<extra_id_0>', '*cls*')
text = text.replace('<extra_id_1>', '*mask*')
text = text.replace('<extra_id_2>', '*+sent_0**sep+*')
text = text.replace('</s>', '*+sent_0**sep+*')
text = text.replace('▁', '_')
print(text)
f.write(text + '\n')
print("####### generated templates #######\n")
elif task_name in ['MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE']:
# Sentence pair tasks
# We always put [MASK] between the two sentences
template = "*cls**sent-_0**<extra_id_0>**label**<extra_id_1>**+sentl_1**sep+*"
generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None)
print("####### generated templates #######")
for text in generate_text:
# Transform T5 outputs to our template format
text = text.replace('<extra_id_0>', '*cls**sent-_0*')
text = text.replace('<extra_id_1>', '*mask*')
text = text.replace('<extra_id_2>', '*+sentl_1**sep+*')
text = text.replace('</s>', '*+sentl_1**sep+*')
text = text.replace('▁', '_')
print(text)
f.write(text + '\n')
print("####### generated templates #######\n")
else:
raise NotImplementedError
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--t5_model', type=str, default='t5-3b', help='T5 pre-trained model')
parser.add_argument('--seed', type=int, nargs='+', default=[42, 13, 21, 100, 87], help="Data split seeds")
parser.add_argument('--task_name', type=str, nargs='+', default=['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'], help="Task names")
parser.add_argument('--output_dir', type=str, help='Output directory')
parser.add_argument('--data_dir', type=str, default="data/k-shot", help="Data directory")
parser.add_argument('--beam', type=int, default=100, help="Beam search width")
parser.add_argument('--k', type=int, default=16, help="Number of training instances per label")
args = parser.parse_args()
model = T5ForConditionalGeneration.from_pretrained(args.t5_model)
tokenizer = T5Tokenizer.from_pretrained(args.t5_model)
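# T5 has no dedicated sep token by default; reuse the end-of-sequence token so 'sep'/'sep+' template commands map to a valid id in get_text().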
tokenizer.sep_token = '</s>'
#model = model.cuda()
model.eval()
for task_name in args.task_name:
for seed in args.seed:
search_template(model=model, tokenizer=tokenizer, task_name=task_name, k=args.k, seed=seed, beam=args.beam, output_dir=args.output_dir, data_dir=args.data_dir)
if __name__ == '__main__':
main()
from sentence_transformers import SentenceTransformer, util
import argparse
import os
import numpy as np
from tqdm import tqdm
import pandas as pd
def get_sentence(task, line):
if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
# Text classification tasks
if line[1] is None or pd.isna(line[1]):
return ''
else:
return line[1]
else:
# GLUE tasks
line = line.strip().split('\t')
if task == 'CoLA':
return line[-1]
elif task == 'MNLI':
return line[8] + ' ' + line[9]
elif task == 'MRPC':
return line[-2] + ' ' + line[-1]
elif task == 'QNLI':
return line[1] + ' ' + line[2]
elif task == 'QQP':
return line[3] + ' ' + line[4]
elif task == 'RTE':
return line[1] + ' ' + line[2]
elif task == 'SNLI':
return line[7] + ' ' + line[8]
elif task == 'SST-2':
return line[0]
elif task == 'STS-B':
return line[-3] + ' ' + line[-2]
elif task == 'WNLI':
return line[1] + ' ' + line[2]
else:
raise NotImplementedError
def split_header(task, lines):
"""Returns if the task file has a header or not."""
if task in ["CoLA"]:
return [], lines
elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
return lines[0:1], lines[1:]
else:
raise ValueError("Unknown GLUE task.")
def load_datasets(data_dir, task, do_test=False):
dataset = {}
if task == "MNLI":
splits = ["train", "dev_matched"]
if do_test:
splits += ['test_matched', 'test_mismatched']
else:
splits = ["train", "dev"]
if do_test:
splits.append('test')
for split in splits:
if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
filename = os.path.join(data_dir, f"{split}.csv")
dataset[split] = pd.read_csv(filename, header=None).values.tolist()
else:
filename = os.path.join(data_dir, f"{split}.tsv")
with open(filename, "r") as f:
lines = f.readlines()
header, content = split_header(task, lines)
dataset[split] = content
return dataset
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--do_test", action='store_true', help="Generate embeddings for test splits (test set is usually large, so we don't want to repeatedly generate embeddings for them)")
parser.add_argument("--sbert_model", type=str, default='roberta-large', help="Sentence BERT model name")
parser.add_argument("--k", type=int, help="Number of training instances per label", default=16)
parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Path to few-shot data")
parser.add_argument("--seed", type=int, nargs="+", default=[42, 13, 21, 87, 100], help="Seeds for data splits")
parser.add_argument("--task", type=str, nargs="+", default=["SST-2", "sst-5", "mr", "cr", "mpqa", "subj", "trec", "CoLA", "MRPC", "QQP", "STS-B", "MNLI", "SNLI", "QNLI", "RTE"], help="Tasks")
args = parser.parse_args()
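# Pre-compute Sentence-BERT embeddings for every example; the saved .npy files are used downstream to pick semantically similar demonstrations (the paper's demonstration filtering).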
model = SentenceTransformer('{}-nli-stsb-mean-tokens'.format(args.sbert_model))
model = model.cuda()
for task in args.task:
for seed in args.seed:
folder = os.path.join(args.data_dir, task, '{}-{}'.format(args.k, seed))
dataset = load_datasets(folder, task, do_test=args.do_test)
for split in dataset:
print('{}-{}-{}-{}'.format(task, args.k, seed, split))
lines = dataset[split]
embeddings = []
for line_id, line in tqdm(enumerate(lines)):
sent = get_sentence(task, line)
if line_id == 0:
print('|', sent)
emb = model.encode(sent)
embeddings.append(emb)
embeddings = np.stack(embeddings)
np.save(os.path.join(folder, "{}_sbert-{}.npy".format(split, args.sbert_model)), embeddings)
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--mapping_dir', type=str, help='Mapping directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_mapping_id = {} # avoid duplication
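# Collect, per seed, one logged result for every evaluated label-word mapping (deduplicated by mapping_id); mappings are then re-ranked by their dev score below.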
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_mapping_id[seed] = {item['mapping_id']: 1}
else:
if item['mapping_id'] not in seed_result_mapping_id[seed]:
seed_result[seed].append(item)
seed_result_mapping_id[seed][item['mapping_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all mappings
with open(os.path.join(args.mapping_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
mappings = []
for line in f:
mappings.append(line.strip())
# Write sorted mappings
fsort = open(os.path.join(args.mapping_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.mapping_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(mappings[item['mapping_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], mappings[item['mapping_id']]))
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--prompt_dir', type=str, help='Prompt directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_prompt_id = {} # avoid duplication
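# Same pattern as the mapping ranking above, but keyed by prompt_id: keep one result per evaluated prompt and sort prompts by dev score.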
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_prompt_id[seed] = {item['prompt_id']: 1}
else:
if item['prompt_id'] not in seed_result_prompt_id[seed]:
seed_result[seed].append(item)
seed_result_prompt_id[seed][item['prompt_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all prompts
with open(os.path.join(args.prompt_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
prompts = []
for line in f:
prompts.append(line.strip())
# Write sorted prompts
fsort = open(os.path.join(args.prompt_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.prompt_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(prompts[item['prompt_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], prompts[item['prompt_id']]))
if __name__ == '__main__':
main()
import argparse
import json
import numpy as np
import torch
from torch import device
import os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
parser.add_argument('--template_dir', type=str, help='Template directory')
# These options should be kept as their default values
parser.add_argument("--k", type=int, default=16)
parser.add_argument("--log", type=str, default="log", help="Log path.")
parser.add_argument("--key", type=str, default='', help="Validation metric name")
parser.add_argument("--test_key", type=str, default="", help="Test metric name")
parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
args = parser.parse_args()
condition = eval(args.condition)
if len(args.key) == 0:
if condition['task_name'] == 'cola':
args.key = 'cola_dev_eval_mcc'
args.test_key = 'cola_test_eval_mcc'
print_name = 'CoLA'
elif condition['task_name'] == 'mrpc/acc':
args.key = 'mrpc_dev_eval_acc'
args.test_key = 'mrpc_test_eval_acc'
args.test_key2 = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'mrpc/f1':
args.key = 'mrpc_dev_eval_f1'
args.test_key2 = 'mrpc_test_eval_acc'
args.test_key = 'mrpc_test_eval_f1'
condition['task_name'] = 'mrpc'
print_name = 'MRPC'
elif condition['task_name'] == 'qqp/acc':
args.key = 'qqp_dev_eval_acc'
args.test_key = 'qqp_test_eval_acc'
args.test_key2 = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'qqp/f1':
args.key = 'qqp_dev_eval_f1'
args.test_key2 = 'qqp_test_eval_acc'
args.test_key = 'qqp_test_eval_f1'
condition['task_name'] = 'qqp'
print_name = 'QQP'
elif condition['task_name'] == 'sts-b/pearson':
args.key = 'sts-b_dev_eval_pearson'
args.test_key = 'sts-b_test_eval_pearson'
args.test_key2 = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'sts-b/spearmanr':
args.key = 'sts-b_dev_eval_spearmanr'
args.test_key2 = 'sts-b_test_eval_pearson'
args.test_key = 'sts-b_test_eval_spearmanr'
condition['task_name'] = 'sts-b'
print_name = 'STS-B'
elif condition['task_name'] == 'qnli':
args.key = 'qnli_dev_eval_acc'
args.test_key = 'qnli_test_eval_acc'
print_name = 'QNLI'
elif condition['task_name'] == 'sst-2':
args.key = 'sst-2_dev_eval_acc'
args.test_key = 'sst-2_test_eval_acc'
print_name = 'SST-2'
elif condition['task_name'] == 'snli':
args.key = 'snli_dev_eval_acc'
args.test_key = 'snli_test_eval_acc'
print_name = 'SNLI'
elif condition['task_name'] == 'mnli':
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli_test_eval_mnli/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'mnli-mm':
condition['task_name'] = 'mnli'
args.key = 'mnli_dev_eval_mnli/acc'
args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
print_name = 'MNLI'
elif condition['task_name'] == 'rte':
args.key = 'rte_dev_eval_acc'
args.test_key = 'rte_test_eval_acc'
print_name = 'RTE'
elif condition['task_name'] == 'ag_news':
args.key = 'ag_news_dev_eval_acc'
args.test_key = 'ag_news_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yahoo_answers':
args.key = 'yahoo_answers_dev_eval_acc'
args.test_key = 'yahoo_answers_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'yelp_review_full':
args.key = 'yelp_review_full_dev_eval_acc'
args.test_key = 'yelp_review_full_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mr':
args.key = 'mr_dev_eval_acc'
args.test_key = 'mr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'sst-5':
args.key = 'sst-5_dev_eval_acc'
args.test_key = 'sst-5_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'subj':
args.key = 'subj_dev_eval_acc'
args.test_key = 'subj_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'trec':
args.key = 'trec_dev_eval_acc'
args.test_key = 'trec_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'cr':
args.key = 'cr_dev_eval_acc'
args.test_key = 'cr_test_eval_acc'
print_name = condition['task_name']
elif condition['task_name'] == 'mpqa':
args.key = 'mpqa_dev_eval_acc'
args.test_key = 'mpqa_test_eval_acc'
print_name = condition['task_name']
else:
raise NotImplementedError
with open(args.log) as f:
result_list = []
for line in f:
result_list.append(eval(line))
seed_result = {}
seed_result_template_id = {} # avoid duplication
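# Same pattern again for templates: keep one result per template_id and write the {k}-{seed}.sort.txt / .score.txt files ranked by the dev metric.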
for item in result_list:
ok = True
for cond in condition:
if cond not in item or item[cond] != condition[cond]:
ok = False
break
if ok:
seed = item['seed']
if seed not in seed_result:
seed_result[seed] = [item]
seed_result_template_id[seed] = {item['template_id']: 1}
else:
if item['template_id'] not in seed_result_template_id[seed]:
seed_result[seed].append(item)
seed_result_template_id[seed][item['template_id']] = 1
for seed in seed_result:
print("Seed %d has %d results" % (seed, len(seed_result[seed])))
# Load all templates
with open(os.path.join(args.template_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
templates = []
for line in f:
templates.append(line.strip())
# Write sorted templates
fsort = open(os.path.join(args.template_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
fscore = open(os.path.join(args.template_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
for item in seed_result[seed]:
fsort.write(templates[item['template_id']] + '\n')
fscore.write("%.5f %s\n" % (item[args.key], templates[item['template_id']]))
if __name__ == '__main__':
main()
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f857d1db-7d6d-4c76-af4f-d508a4027192",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"head: error reading 'data/x-stance/': Is a directory\n"
]
}
],
"source": [
"!head -2 data/x-stance/questions.en.json\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "b6421d22-ca47-4ecf-b667-8f19f4cb035a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"questions.de.jsonl questions.en.jsonl\tquestions.fr.jsonl questions.it.jsonl\n"
]
}
],
"source": [
"!ls data/x-stance"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1d288f21-0b89-4974-bded-d5ca9ff24f82",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "67b7a07f-26d9-4e6a-8473-2614f6b34887",
"metadata": {},
"outputs": [],
"source": [
"json_list = []\n",
"for line in open('data/x-stance/questions.it.jsonl'):\n",
" json_list.append(json.loads(line))\n",
"data = pd.DataFrame.from_dict(json_list)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "d505a983-7d08-4fa0-8281-99027e8edd4c",
"metadata": {},
"outputs": [],
"source": [
"k = 100\n",
"seed_list = [100,13,21,42,87]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "824ff804-347d-4913-a3f0-20bd38cb4159",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100-100 100-13 100-21 100-42 100-87 16-100 16-13\t16-21 16-42 16-87\n"
]
}
],
"source": [
"!ls /home/mist/projects/LM-BFF/data/k-shot/x-stance"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "5ddb0ab8-0dfb-44bc-99c3-69829c86d8a7",
"metadata": {},
"outputs": [],
"source": [
"output_dir = '/home/mist/projects/LM-BFF/data/k-shot/'"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8b8b604e-5767-4bdb-b620-2d4c464c595c",
"metadata": {},
"outputs": [],
"source": [
"task = 'x-stance'"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "59dd94e5-b067-4644-832f-bc0e45d3486f",
"metadata": {},
"outputs": [],
"source": [
"for seed in seed_list:\n",
" task_dir = os.path.join(output_dir, task)\n",
" setting_dir = os.path.join(task_dir, f\"{k}-{seed}\")\n",
" os.makedirs(setting_dir, exist_ok=True)\n",
" data.sample(random_state=seed,n=k)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "531acdd0-cbf6-430a-91e2-66923236908c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>text</th>\n",
" <th>topic</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>Finden Sie es grundsätzlich richtig, dass der ...</td>\n",
" <td>Welfare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4</td>\n",
" <td>Soll zusätzlich zur bestehenden Mutterschaftsv...</td>\n",
" <td>Welfare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6</td>\n",
" <td>Die Invalidenversicherung spricht bei nicht ob...</td>\n",
" <td>Welfare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>7</td>\n",
" <td>Würden Sie eine nationale Spitalplanung befürw...</td>\n",
" <td>Healthcare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9</td>\n",
" <td>Finden Sie es richtig, dass einzelne ärztliche...</td>\n",
" <td>Healthcare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>189</th>\n",
" <td>3464</td>\n",
" <td>Würden Sie eine Ausdehnung der rechtlichen Mög...</td>\n",
" <td>Security</td>\n",
" </tr>\n",
" <tr>\n",
" <th>190</th>\n",
" <td>3468</td>\n",
" <td>Soll die Schweiz Verhandlungen über den Beitri...</td>\n",
" <td>Foreign Policy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>191</th>\n",
" <td>3469</td>\n",
" <td>Soll der Bundesrat ein Freihandelsabkommen mit...</td>\n",
" <td>Foreign Policy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>192</th>\n",
" <td>3470</td>\n",
" <td>Eine Initiative fordert, dass die Haftungsrege...</td>\n",
" <td>Foreign Policy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>193</th>\n",
" <td>3471</td>\n",
" <td>Befürworten Sie die Kandidatur der Schweiz für...</td>\n",
" <td>Foreign Policy</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>194 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" id text topic\n",
"0 2 Finden Sie es grundsätzlich richtig, dass der ... Welfare\n",
"1 4 Soll zusätzlich zur bestehenden Mutterschaftsv... Welfare\n",
"2 6 Die Invalidenversicherung spricht bei nicht ob... Welfare\n",
"3 7 Würden Sie eine nationale Spitalplanung befürw... Healthcare\n",
"4 9 Finden Sie es richtig, dass einzelne ärztliche... Healthcare\n",
".. ... ... ...\n",
"189 3464 Würden Sie eine Ausdehnung der rechtlichen Mög... Security\n",
"190 3468 Soll die Schweiz Verhandlungen über den Beitri... Foreign Policy\n",
"191 3469 Soll der Bundesrat ein Freihandelsabkommen mit... Foreign Policy\n",
"192 3470 Eine Initiative fordert, dass die Haftungsrege... Foreign Policy\n",
"193 3471 Befürworten Sie die Kandidatur der Schweiz für... Foreign Policy\n",
"\n",
"[194 rows x 3 columns]"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "bcf450cf-cb23-462f-b070-3529c1dfa86d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Infrastructure & Environment 31\n",
"Economy 23\n",
"Security 20\n",
"Immigration 19\n",
"Society 17\n",
"Education 16\n",
"Foreign Policy 16\n",
"Finances 15\n",
"Welfare 15\n",
"Healthcare 11\n",
"Political System 9\n",
"Digitisation 2\n",
"Name: topic, dtype: int64"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['topic'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b07328e8-5fcf-4fad-9468-49ff91652ef9",
"metadata": {},
"outputs": [],
"source": [
"for seed in seed_list:\n",
" data.sample(random_state=seed,n=k)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "8d7aa681-bd14-4c48-881a-9130a9b88edc",
"metadata": {},
"outputs": [],
"source": [
"label_encoding = {\n",
"'Infrastructure & Environment': 0,\n",
"'Economy': 1 , \n",
"'Security': 2 , \n",
"'Immigration': 3 ,\n",
"'Society': 4 ,\n",
"'Education': 5 ,\n",
"'Foreign Policy': 6 ,\n",
"'Finances': 7 ,\n",
"'Welfare':8 ,\n",
"'Healthcare':9 ,\n",
"'Political System': 10 , \n",
"'Digitisation':11 }"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "3c2fad2e-77cf-4eff-9dda-5773f59f4402",
"metadata": {},
"outputs": [],
"source": [
"data['label'] = data['topic'].apply(lambda x:label_encoding[x])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "6e4c0a41-31a8-4620-bfed-c776a21e0c1c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(194, 4)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4fc29374-8495-43b7-829f-90c11bf8a974",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'pd' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-a66617debac6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data/x-stance/questions.en.jsonl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mjson_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m'+1'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'stars'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m>=\u001b[0m\u001b[0;36m3\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'-1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrac\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined"
]
}
],
"source": [
"\n",
"\n",
"data['label']=data.apply(lambda x:'+1' if x['stars']>=3 else '-1',axis=1)\n",
"data = data.sample(frac=1)\n",
"data['text']=data['text'].apply(lambda x:' '.join(x.replace('\\n','.').split(' ')[:500]))\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "551574a5-f461-4a6f-8c0a-97cf92747b74",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'data' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-c5d84736ba45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
]
}
],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3187648-9c75-423d-bc21-ba655e30df6f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
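The notebook above loads the x-stance questions, maps topics to labels, creates the `data/k-shot/x-stance/<k>-<seed>/` directories, and draws `data.sample(random_state=seed, n=k)` per seed, but the sampled frames are never written to disk. A hedged sketch of how those folders might be populated is below; the `train.csv` file name, the CSV format, and the alphabetical label encoding (the notebook hand-writes a frequency-ordered dict instead) are assumptions for illustration, not the repository's actual data-processor spec.

```python
# Sketch only: persist one k-shot sample per seed for the x-stance questions.
import json
import os

import pandas as pd

k = 100
seed_list = [100, 13, 21, 42, 87]
output_dir = "/home/mist/projects/LM-BFF/data/k-shot/"
task = "x-stance"

# Load the questions file used in the notebook and assign integer labels.
with open("data/x-stance/questions.it.jsonl") as f:
    data = pd.DataFrame.from_dict([json.loads(line) for line in f])
label_encoding = {t: i for i, t in enumerate(sorted(data["topic"].unique()))}  # assumption
data["label"] = data["topic"].map(label_encoding)

for seed in seed_list:
    setting_dir = os.path.join(output_dir, task, f"{k}-{seed}")
    os.makedirs(setting_dir, exist_ok=True)
    sampled = data.sample(random_state=seed, n=min(k, len(data)))
    # Hypothetical output name/format; adjust to whatever the dataset processor expects.
    sampled[["text", "label"]].to_csv(os.path.join(setting_dir, "train.csv"), index=False)
```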