From cf1a4f6527d9f35a3bb63fa33bcd6eee07dfccd0 Mon Sep 17 00:00:00 2001
From: xyw <yunwanx@foxmail.com>
Date: Thu, 28 Oct 2021 21:31:53 +0800
Subject: [PATCH] v1

---
 LM-BFF/.gitignore                    | 152 ++++++++++
 LM-BFF/LICENSE                       |  21 ++
 LM-BFF/README.md                     | 405 ++++++++++++++++++++++++++++
 LM-BFF/figs/lmbff.png                | Bin 0 -> 378387 bytes
 LM-BFF/nohup.out                     |   1 +
 LM-BFF/requirements.txt              |  37 +++
 LM-BFF/run.py                        | 628 +++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/dataset.py                | 658 ++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/label_search.py           | 148 ++++++++++
 LM-BFF/src/models.py                 | 195 +++++++++++++
 LM-BFF/src/processors.py             | 626 +++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/trainer.py                | 473 ++++++++++++++++++++++++++++++++
 LM-BFF/tools/ensemble.py             | 289 ++++++++++++++++++++++
 LM-BFF/tools/gather_result.py        | 160 +++++++++++
 LM-BFF/tools/generate_k_shot_data.py | 181 ++++++++++++
 LM-BFF/tools/generate_labels.py      | 284 +++++++++++++++++++++
 LM-BFF/tools/generate_template.py    | 375 +++++++++++++++++++++++++++
 LM-BFF/tools/get_sbert_embedding.py  | 105 ++++++++
 LM-BFF/tools/sort_mapping.py         | 173 ++++++++++++
 LM-BFF/tools/sort_prompt.py          | 173 ++++++++++++
 LM-BFF/tools/sort_template.py        | 173 ++++++++++++
 LM-BFF/未命名.ipynb                  | 432 ++++++++++++++++++++++++++++++
 22 files changed, 5689 insertions(+)
 create mode 100755 LM-BFF/.gitignore
 create mode 100755 LM-BFF/LICENSE
 create mode 100755 LM-BFF/README.md
 create mode 100755 LM-BFF/figs/lmbff.png
 create mode 100755 LM-BFF/nohup.out
 create mode 100755 LM-BFF/requirements.txt
 create mode 100755 LM-BFF/run.py
 create mode 100755 LM-BFF/src/dataset.py
 create mode 100755 LM-BFF/src/label_search.py
 create mode 100755 LM-BFF/src/models.py
 create mode 100755 LM-BFF/src/processors.py
 create mode 100755 LM-BFF/src/trainer.py
 create mode 100755 LM-BFF/tools/ensemble.py
 create mode 100755 LM-BFF/tools/gather_result.py
 create mode 100755 LM-BFF/tools/generate_k_shot_data.py
 create mode 100755 LM-BFF/tools/generate_labels.py
 create mode 100755 LM-BFF/tools/generate_template.py
 create mode 100755 LM-BFF/tools/get_sbert_embedding.py
 create mode 100755 LM-BFF/tools/sort_mapping.py
 create mode 100755 LM-BFF/tools/sort_prompt.py
 create mode 100755 LM-BFF/tools/sort_template.py
 create mode 100755 LM-BFF/未命名.ipynb

diff --git a/LM-BFF/.gitignore b/LM-BFF/.gitignore
new file mode 100755
index 0000000..4cf2bba
--- /dev/null
+++ b/LM-BFF/.gitignore
@@ -0,0 +1,152 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g.
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +*.sh +.venv +.vscode +data +!data/k-shot/checksum +log* +runs +result +wandb +ensemble_predict_results +auto* +my* +slurm diff --git a/LM-BFF/LICENSE b/LM-BFF/LICENSE new file mode 100755 index 0000000..61d8849 --- /dev/null +++ b/LM-BFF/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Princeton Natural Language Processing + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LM-BFF/README.md b/LM-BFF/README.md new file mode 100755 index 0000000..0853418 --- /dev/null +++ b/LM-BFF/README.md @@ -0,0 +1,405 @@ +# LM-BFF (**B**etter **F**ew-shot **F**ine-tuning of **L**anguage **M**odels) + +This is the implementation of the paper [Making Pre-trained Language Models Better Few-shot Learners](https://arxiv.org/pdf/2012.15723.pdf). LM-BFF is short for **b**etter **f**ew-shot **f**ine-tuning of **l**anguage **m**odels. + +## Quick links + +* [Overview](#overview) +* [Requirements](#requirements) +* [Prepare the data](#prepare-the-data) +* [Run the model](#run-lm-bff) + * [Quick start](#quick-start) + * [Experiments with multiple runs](#experiments-with-multiple-runs) + * [Using demonstrations with filtering](#using-demonstrations-with-filtering) + * [Automatically searched prompt](#automatically-searched-prompt) + * [Ensemble](#ensemble-model) + * [Zero-shot experiments](#zero-shot-experiments) + * [How to design your own templates](#how-to-design-your-own-templates) +* [Citation](#citation) + + +## Overview + + + +In this work we present LM-BFF, a suite of simple and complementary techniques for fine-tuning pre-trained language models on a small number of training examples. Our approach includes: + +1. Prompt-based fine-tuning together with a novel pipeline for automating prompt generation. +2. A refined strategy for incorporating demonstrations into context. + +You can find more details of this work in our [paper](https://arxiv.org/pdf/2012.15723.pdf). 
+
+## Requirements
+
+To run our code, please install all the dependency packages by using the following command:
+
+```
+pip install -r requirements.txt
+```
+
+**NOTE**: Different versions of packages (like `pytorch`, `transformers`, etc.) may lead to different results from the paper. However, the trend should still hold no matter which versions you use.
+
+## Prepare the data
+
+We pack the original datasets (SST-2, SST-5, MR, CR, MPQA, Subj, TREC, CoLA, MNLI, SNLI, QNLI, RTE, MRPC, QQP, STS-B) [here](https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar). Please download it and extract the files to `./data/original`, or run the following commands:
+
+```bash
+cd data
+bash download_dataset.sh
+```
+
+Then use the following command (in the root directory) to generate the few-shot data we need:
+
+```bash
+python tools/generate_k_shot_data.py
+```
+
+See `tools/generate_k_shot_data.py` for more options. For the results in the paper, we use the default options: we take `K=16` and use 5 different seeds (13, 21, 42, 87, 100). The few-shot data will be generated in `data/k-shot`. In the directory of each dataset, there will be folders named `$K-$SEED` indicating different dataset samples. You can use the following command to check whether the generated data are exactly the same as ours:
+
+```bash
+cd data/k-shot
+md5sum -c checksum
+```
+
+**NOTE**: During training, the model will generate/load cache files in the data folder. If your data have changed, make sure to clean all the cache files (starting with "cache").
+
+## Run LM-BFF
+
+### Quick start
+Our code is built on [transformers](https://github.com/huggingface/transformers) and we use its `3.4.0` version. Other versions of `transformers` might cause unexpected errors.
+
+Before running any experiments, create the result folder by `mkdir result` to save checkpoints. Then you can run our code with the following example:
+
+```bash
+python run.py \
+    --task_name SST-2 \
+    --data_dir data/k-shot/SST-2/16-42 \
+    --overwrite_output_dir \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --evaluate_during_training \
+    --model_name_or_path roberta-large \
+    --few_shot_type prompt-demo \
+    --num_k 16 \
+    --max_steps 1000 \
+    --eval_steps 100 \
+    --per_device_train_batch_size 2 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 0 \
+    --output_dir result/tmp \
+    --seed 42 \
+    --template "*cls**sent_0*_It_was*mask*.*sep+*" \
+    --mapping "{'0':'terrible','1':'great'}" \
+    --num_sample 16
+```
+
+Most arguments are inherited from `transformers` and are easy to understand. We further explain some of LM-BFF's arguments:
+
+* `few_shot_type`: There are three modes:
+  * `finetune`: Standard fine-tuning.
+  * `prompt`: Prompt-based fine-tuning.
+  * `prompt-demo`: Prompt-based fine-tuning with demonstrations.
+* `num_k`: Number of training instances for each class. We take `num_k`=16 in our paper. This argument is mainly used for indexing logs afterwards (the actual number of training examples is determined by the data split you use).
+* `template`: Template for prompt-based fine-tuning. We will introduce the template format later.
+* `mapping`: Label word mapping for prompt-based fine-tuning. It is the string form of a dictionary mapping label names to label words. **NOTE**: For RoBERTa, the model will automatically add a space before the word. See the paper appendix for details.
+* `num_sample`: When using demonstrations during inference, the number of demonstration samples for each input query. Say `num_sample`=16: we then sample 16 different sets of demonstrations for one input, run the forward passes separately, and average the logits over all 16 samples as the final prediction.
+
+Also, this codebase supports BERT-series and RoBERTa-series pre-trained models from Huggingface's `transformers`. You can check [Huggingface's website](https://huggingface.co/models) for available models and pass models with "bert" or "roberta" in their names to `--model_name_or_path`. Some examples are `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, and `roberta-large`.
+
+To easily run our experiments, you can also use `run_experiment.sh` (this command runs prompt-based fine-tuning with demonstrations, no filtering, and the manual prompt):
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh
+```
+
+We have already defined the templates and label word mappings in it, so you only need to manipulate a few hyper-parameters and `TAG` (you can use whatever tag you want; it just makes finding results easier). See `run_experiment.sh` for more options for these environment variables. Besides, you can add extra arguments:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--output_dir result/exp --max_seq_length 512"
+```
+
+### Experiments with multiple runs
+
+To carry out experiments with multiple data splits, following the evaluation protocol detailed in §3.3 of [our paper](https://arxiv.org/pdf/2012.15723.pdf) (grid search for each seed, then aggregate the results over 5 different seeds), you can use the following scripts:
+
+```bash
+for seed in 13 21 42 87 100
+do
+    for bs in 2 4 8
+    do
+        for lr in 1e-5 2e-5 5e-5
+        do
+            TAG=exp \
+            TYPE=prompt-demo \
+            TASK=SST-2 \
+            BS=$bs \
+            LR=$lr \
+            SEED=$seed \
+            MODEL=roberta-large \
+            bash run_experiment.sh
+        done
+    done
+done
+```
+
+All the results will be stored in `./log`. To gather all the results, run the following command:
+
+```bash
+python tools/gather_result.py --condition "{'tag': 'exp', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
+```
+
+Then the program will find all the trials that satisfy the condition in `./log`, and print the mean/std of the final results. Note that the task names are all lower-cased, and if the task has more than one metric, you need to specify the major metric (used for picking the best validation trial) in the name (e.g., `mnli`, `mnli-mm`, `mrpc/acc`, `mrpc/f1`, `qqp/acc`, `qqp/f1`, `sts-b/pearson`, `sts-b/spearman`).
+
+### Using demonstrations with filtering
+
+To use the filtering mechanism when using demonstrations, we need to first generate [Sentence-BERT](https://github.com/UKPLab/sentence-transformers) embeddings. To generate embeddings for the datasets in our paper, you can directly run
+
+```bash
+bash tools/get_sbert_embedding.sh roberta-large
+```
+
+`roberta-large` can also be replaced by `bert-base`, `bert-large`, `roberta-base` and `distilbert-base` (see [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) for details). See `tools/get_sbert_embedding.sh` and `tools/get_sbert_embedding.py` if you want to add more datasets.
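+
+For intuition, the filtering step keeps only the demonstrations most similar to the input query. Below is a minimal sketch of the idea (the actual logic lives in `src/dataset.py`; `filter_candidates` is a hypothetical helper for illustration, and the embeddings are assumed to be unit-normalized):
+
+```python
+import numpy as np
+
+def filter_candidates(query_emb, support_embs, filter_rate=0.5):
+    """Keep the top `filter_rate` fraction of support examples,
+    ranked by cosine similarity to the query embedding."""
+    sims = support_embs @ query_emb                # cosine similarity for unit vectors
+    k = max(1, int(len(support_embs) * filter_rate))
+    return np.argsort(-sims)[:k]                   # indices of the most similar examples
+```
+
+The default `--demo_filter_rate` is 0.5, i.e., demonstrations are sampled from the top-50% most similar support examples.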
+
+After generating the embeddings (saved as numpy files in the data folders), we can run the following command to do prompt-based fine-tuning with demonstrations and filtering:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--demo_filter --demo_filter_model sbert-roberta-large"
+```
+
+### Automatically searched prompt
+
+We provide our automatic search results in `auto_template` and `auto_label_mapping`. There are three types of files:
+
+* `SST-2/16-42.txt`: Initial search results for the SST-2 dataset, K=16 and SEED=42.
+* `SST-2/16-42.sort.txt`: The initial results after prompt-based fine-tuning, sorted by dev set performance.
+* `SST-2/16-42.score.txt`: Same as above, but with the dev set scores.
+
+To use the best automatic template (`auto-T` in the paper), use the following command:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--template_path auto_template/SST-2/16-42.sort.txt --template_id 0"
+```
+
+You can also use the _i_-th automatic result by specifying a different `template_id`.
+
+Similarly, to use automatic labels (`auto-L` in the paper), use the following command:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--mapping_path auto_label_mapping/SST-2/16-42.sort.txt --mapping_id 0"
+```
+
+**NOTE**: Make sure to use the automatic search results corresponding to each data split seed.
+
+**Our final results (LM-BFF) use prompt-based fine-tuning with demonstrations, filtering, and automatic templates, for example**:
+
+```bash
+for seed in 13 21 42 87 100
+do
+    for bs in 2 4 8
+    do
+        for lr in 1e-5 2e-5 5e-5
+        do
+            TAG=LM-BFF \
+            TYPE=prompt-demo \
+            TASK=SST-2 \
+            BS=$bs \
+            LR=$lr \
+            SEED=$seed \
+            MODEL=roberta-large \
+            bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id 0 --demo_filter --demo_filter_model sbert-roberta-large"
+        done
+    done
+done
+
+python tools/gather_result.py --condition "{'tag': 'LM-BFF', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
+```
+
+#### Search for automatic templates
+
+If you want to try automatically generating templates yourself, here are the instructions. Note that it is an extremely long process :)
+
+To get automatic templates, we first generate template candidates using T5:
+
+```bash
+python tools/generate_template.py \
+    --output_dir my_auto_template \
+    --task_name SST-2 \
+    --seed 13 21 42 87 100 \
+    --t5_model t5-3b \
+    --beam 100
+```
+
+Here, `--t5_model` specifies the pre-trained T5 checkpoint to use and `--beam` specifies the beam search width. Note that the `t5-3b` model takes approximately 15GB of GPU memory; if your GPU cannot support it, you can try smaller T5 models (e.g., `t5-base`).
+
+Then we do prompt-based fine-tuning with all the templates:
+
+```bash
+for template_id in {0..99}
+do
+    for seed in 13 21 42 87 100
+    do
+        # To save time, we fix these hyper-parameters
+        bs=8
+        lr=1e-5
+
+        # Since we only use dev performance here, use --no_predict to skip testing
+        TAG=exp-template \
+        TYPE=prompt \
+        TASK=SST-2 \
+        BS=$bs \
+        LR=$lr \
+        SEED=$seed \
+        MODEL=roberta-large \
+        bash run_experiment.sh "--template_path my_auto_template/SST-2/16-$seed.txt --template_id $template_id --no_predict"
+    done
+done
+```
+
+... and sort them based on dev set performance:
+
+```bash
+python tools/sort_template.py --condition "{'tag': 'exp-template', 'task_name': 'sst-2'}" --template_dir my_auto_template
+```
+
+The sorted results will be saved in `my_auto_template`, in the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
+
+#### Search for automatic label word mappings
+
+Similar to the automatic template search, we first generate candidate label word mappings by running:
+
+```bash
+bash tools/run_generate_labels.sh
+```
+
+You can modify the options in `tools/run_generate_labels.sh` to run this for different datasets or save mappings to different directories. After running the generation, the candidate label mappings will be saved in `my_auto_label_mapping/manual_template`.
+
+Then we do prompt-based fine-tuning with all the mappings:
+
+```bash
+for mapping_id in {0..99}
+do
+    for seed in 13 21 42 87 100
+    do
+        # To save time, we fix these hyper-parameters
+        bs=8
+        lr=1e-5
+
+        # Since we only use dev performance here, use --no_predict to skip testing
+        TAG=exp-mapping \
+        TYPE=prompt \
+        TASK=SST-2 \
+        BS=$bs \
+        LR=$lr \
+        SEED=$seed \
+        MODEL=roberta-large \
+        bash run_experiment.sh "--mapping_path my_auto_label_mapping/manual_template/SST-2/16-$seed.txt --mapping_id $mapping_id --no_predict"
+    done
+done
+```
+
+... and sort them based on dev set performance:
+
+```bash
+python tools/sort_mapping.py --condition "{'tag': 'exp-mapping', 'task_name': 'sst-2'}" --mapping_dir my_auto_label_mapping/manual_template
+```
+
+The sorted results will be saved in `my_auto_label_mapping/manual_template`, in the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
+
+**Auto T + L**: We can also do a joint search of templates and label word mappings, following these steps:
+
+1. First, do the automatic template search following [Search for automatic templates](#search-for-automatic-templates).
+2. The following steps are similar to the automatic label mapping search, except for a few arguments. When running `tools/run_generate_labels.sh`, change `LOAD_TEMPLATES` to `true` in it, and the template + mapping candidates will be written to `my_auto_label_mapping/auto_template`.
+3. For the subsequent fine-tuning, change `--mapping_path` and `--mapping_id` to `--prompt_path` and `--prompt_id`.
+4. In the end, for re-ranking all the prompts, change `tools/sort_mapping.py` to `tools/sort_prompt.py` to get the final lists.
+
+### Ensemble model
+
+First we need to train models with different templates:
+
+```bash
+mkdir ensemble_predict_results
+for template_id in {0..19} # Use top 20 templates
+do
+    array_id=0
+    for seed in 13 21 42 87 100
+    do
+        for bs in 2 4 8
+        do
+            for lr in 1e-5 2e-5 5e-5
+            do
+                TAG=exp-ensemble \
+                TYPE=prompt-demo \
+                TASK=SST-2 \
+                BS=$bs \
+                LR=$lr \
+                SEED=$seed \
+                MODEL=roberta-large \
+                bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id $template_id --model_id $template_id --array_id $array_id --save_logit --save_logit_dir ensemble_predict_results"
+
+                array_id=$(expr $array_id + 1)
+            done
+        done
+    done
+done
+```
+
+Looks a little complicated? It's actually pretty easy to understand: `--model_id` and `--array_id` are used to distinguish different runs, and `--save_logit` tells the program to save the prediction results for the ensemble.
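+
+For intuition, the ensemble step combines the saved logits across models (e.g., by averaging). Below is a minimal sketch of the idea (the real implementation is `tools/ensemble.py`; the file names are hypothetical examples following the `$TASK-$MODEL_ID-$ARRAY_ID.npy` pattern mentioned above):
+
+```python
+import numpy as np
+
+# Hypothetical logit files saved by two runs with --save_logit
+logit_files = ["ensemble_predict_results/SST-2-0-0.npy",
+               "ensemble_predict_results/SST-2-1-0.npy"]
+
+# Average the per-example logits over models, then take the argmax
+logits = np.mean([np.load(f) for f in logit_files], axis=0)
+preds = logits.argmax(axis=-1)  # ensemble prediction per test example
+```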
+
+After finishing the experiments, use the following command to get the ensemble results:
+
+```bash
+python tools/ensemble.py --condition "{'tag': 'exp-ensemble', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}" --n_models 20
+```
+
+where `--n_models` specifies how many models you want to use for the ensemble (it should be the same as the number of templates you used in the experiments).
+
+### Zero-shot experiments
+
+It's easy to run zero-shot experiments: just add the `--no_train` argument:
+
+```bash
+TAG=zero-shot TYPE=prompt TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train"
+```
+
+To do "GPT-3 style" in-context learning:
+
+```bash
+TAG=gpt3-in-context TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train --num_sample 1 --gpt3_in_context_head --gpt3_in_context_num 32 --truncate_head --use_full_length"
+```
+
+### How to design your own templates
+
+Here are two template examples:
+
+For SST-2: `*cls**sent_0*_It_was*mask*.*sep+*` => `[CLS] {S0} It was [MASK]. [SEP]`
+
+For MNLI: `*cls**sent-_0*?*mask*,*+sentl_1**sep+*` => `[CLS] {S0}? [MASK], {S1} [SEP]`
+
+The template is composed of special tokens, variables (surrounded by `*`), and plain text (e.g., `It_was`, where space is replaced by `_`). The special tokens and variables are:
+
+* `*cls*`, `*sep*`, `*sep+*` and `*mask*`: Special tokens for CLS, SEP and MASK (different for different pre-trained models and tokenizers). `*sep+*` means the contents before and after this token have different segment embeddings (only for BERT).
+* `*sent_i*`: The i-th sentence.
+* `*sent-_i*`: The i-th sentence, discarding the last character.
+* `*sentl_i*`: The i-th sentence, lower-casing the first letter.
+* `*sentl-_i*`: The i-th sentence, discarding the last character and lower-casing the first letter.
+* `*+sent_i*`: The i-th sentence, adding an extra space at the beginning.
+* `*+sentl_i*`: The i-th sentence, adding an extra space at the beginning and lower-casing the first letter.
+
+
+## Bugs or questions?
+
+If you have any questions related to the code or the paper, feel free to email Tianyu (`tianyug@cs.princeton.edu`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and faster!
+ +## Citation + +Please cite our paper if you use LM-BFF in your work: + +```bibtex +@inproceedings{gao2021making, + title={Making Pre-trained Language Models Better Few-shot Learners}, + author={Gao, Tianyu and Fisch, Adam and Chen, Danqi}, + booktitle={Association for Computational Linguistics (ACL)}, + year={2021} +} +``` diff --git a/LM-BFF/figs/lmbff.png b/LM-BFF/figs/lmbff.png new file mode 100755 index 0000000..04dac6e Binary files /dev/null and b/LM-BFF/figs/lmbff.png differ diff --git a/LM-BFF/nohup.out b/LM-BFF/nohup.out new file mode 100755 index 0000000..0e744ed --- /dev/null +++ b/LM-BFF/nohup.out @@ -0,0 +1 @@ + 0%| | 0/18 [00:00<?, ?it/s] 6%|▌ | 1/18 [02:41<45:41, 161.24s/it] \ No newline at end of file diff --git a/LM-BFF/requirements.txt b/LM-BFF/requirements.txt new file mode 100755 index 0000000..e91061e --- /dev/null +++ b/LM-BFF/requirements.txt @@ -0,0 +1,37 @@ +certifi==2020.12.5 +chardet==4.0.0 +click==7.1.2 +dataclasses +filelock==3.0.12 +flake8==3.8.4 +future==0.18.2 +idna==2.10 +importlib-metadata==3.3.0 +joblib==1.0.0 +mccabe==0.6.1 +nltk==3.5 +numpy==1.19.4 +packaging==20.8 +pandas==1.1.5 +protobuf==3.14.0 +pycodestyle==2.6.0 +pyflakes==2.2.0 +pyparsing==2.4.7 +python-dateutil==2.8.1 +pytz==2020.5 +regex==2020.11.13 +requests==2.25.1 +sacremoses==0.0.43 +scikit-learn==0.24.0 +scipy==1.5.4 +sentence-transformers==0.4.0 +sentencepiece==0.1.94 +six==1.15.0 +threadpoolctl==2.1.0 +tokenizers==0.9.2 +torch==1.6.0 +tqdm==4.48.2 +transformers==3.4.0 +typing-extensions==3.7.4.3 +urllib3>=1.26.4 +zipp==3.4.0 diff --git a/LM-BFF/run.py b/LM-BFF/run.py new file mode 100755 index 0000000..baa73cf --- /dev/null +++ b/LM-BFF/run.py @@ -0,0 +1,628 @@ +"""Finetuning the library models for sequence classification on GLUE.""" + +import dataclasses +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional +import torch + +import numpy as np + +import transformers +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction +from transformers import GlueDataTrainingArguments as DataTrainingArguments +from transformers import HfArgumentParser, TrainingArguments, set_seed + +from src.dataset import FewShotDataset +from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings +from src.trainer import Trainer +from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, bound_mapping + +from filelock import FileLock +from datetime import datetime + +from copy import deepcopy +from tqdm import tqdm +import json + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + # Few-shot type + # - finetune: standard fine-tuning + # - prompt: prompt-based fine-tuning + # - prompt-demo: prompt-based fine-tuning with demonstrations + few_shot_type: str = field( + default='prompt-demo', + metadata={"help": "Few-shot learning model type. Choice: finetune, prompt, prompt-demo"} + ) + + # Only for BERT-type model + random_segment: bool = field( + default=False, + metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."} + ) + +@dataclass +class DynamicDataTrainingArguments(DataTrainingArguments): + """ + Arguments for dynamic training. + """ + num_k: Optional[int] = field( + default=16, + metadata={"help": "Number of training instances per class"} + ) + + num_sample: Optional[int] = field( + default=16, + metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"} + ) + + num_demo: Optional[int] = field( + default=1, + metadata={"help": "Number of demonstrations from each class"} + ) + + auto_demo: bool = field( + default=True, + metadata={"help": "Automatically generate template for using demonstrations"} + ) + + # For prompting + template: str = field( + default=None, + metadata={"help": "Template"} + ) + + mapping: str = field( + default=None, + metadata={"help": "Label word mapping"} + ) + + template_path: str = field( + default=None, + metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"} + ) + + mapping_path: str = field( + default=None, + metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. Do not set this when prompt_path is used"} + ) + + prompt_path: str = field( + default=None, + metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"} + ) + + template_id: int = field( + default=None, + metadata={"help": "Template id if using template_path"} + ) + + mapping_id: int = field( + default=None, + metadata={"help": "Mapping id if using template_path"} + ) + + prompt_id: int = field( + default=None, + metadata={"help": "Prompt id if using prompt_path"} + ) + + top_n_template: int = field( + default=None, + metadata={"help": "Use top-n template in the template path"} + ) + + # For logging + tag: str = field( + default='', + metadata={"help": "Set the tag and find the result easier in the log."} + ) + + # For filtering when using demonstrations + demo_filter: bool = field( + default=False, + metadata={"help": "Only use similar instances in demonstrations"} + ) + + demo_filter_rate: float = field( + default=0.5, + metadata={"help": "Only use top-x\% similar instances in demonstrations"} + ) + + demo_filter_model: str = field( + default=None, + metadata={"help": "Model name for demonstration filter embeddings. 
Will load embeddings based on the model name."} + ) + + debug_mode: bool = field( + default=False, + metadata={"help": "Debug mode"} + ) + + # For max length + double_demo: bool = field( + default=False, + metadata={"help": "Use double length for using demonstrations"} + ) + + first_sent_limit: int = field( + default=None, + metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"} + ) + + other_sent_limit: int = field( + default=None, + metadata={"help": "Limit the length of sentences other than the first sentence"} + ) + + use_full_length: bool = field( + default=None, + metadata={"help": "Use the full length (512)"} + ) + + # GPT-3's in-context learning + gpt3_in_context_head: bool = field( + default=False, + metadata={"help": "GPT-3's in-context learning (context at the beginning)"} + ) + + gpt3_in_context_tail: bool = field( + default=False, + metadata={"help": "GPT-3's in-context learning (context at the end)"} + ) + + gpt3_in_context_num: int = field( + default=32, + metadata={"help": "Number of context examples"} + ) + + truncate_head: bool = field( + default=False, + metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."} + ) + + # Do not set up the following fields. They are set up automatically. + prompt: bool = field( + default=False, + metadata={"help": "Whether to use prompt-based fine-tuning"} + ) + template_list: list = field( + default=None, + metadata={"help": "(DO NOT List of templates (only initialized after the program starts."} + ) + + +@dataclass +class DynamicTrainingArguments(TrainingArguments): + # For ensemble + array_id: int = field( + default=-1, + metadata={"help": "Array ID (contains seed and hyper-paramter search) to idenfity the model"} + ) + + model_id: int = field( + default=-1, + metadata={"help": "Model ID (contains template information) to identify the model"} + ) + + save_logit: bool = field( + default=False, + metadata={"help": "Save test file logit with name $TASK-$MODEL_ID-$ARRAY_ID.npy"} + ) + + save_logit_dir: str = field( + default=None, + metadata={"help": "Where to save the prediction result"} + ) + + # Regularization + fix_layers: int = field( + default=0, + metadata={"help": "Fix bottom-n layers when optimizing"} + ) + + # Training + save_at_last: bool = field( + default=False, + metadata={"help": "Instead of saving the best (dev performance) checkpoint, save the last checkpoint"} + ) + + # Turn off train/test + no_train: bool = field( + default=False, + metadata={"help": "No training"} + ) + no_predict: bool = field( + default=False, + metadata={"help": "No test"} + ) + + +def main(): + parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. 
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if 'prompt' in model_args.few_shot_type: + data_args.prompt = True + + if training_args.no_train: + training_args.do_train = False + if training_args.no_predict: + training_args.do_predict = False + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + + # Load prompt/template/mapping file + if data_args.prompt: + if data_args.prompt_path is not None: + assert data_args.prompt_id is not None + prompt_list = [] + with open(data_args.prompt_path) as f: + for line in f: + line = line.strip() + template, mapping = line.split('\t') + prompt_list.append((template, mapping)) + + data_args.template, data_args.mapping = prompt_list[data_args.prompt_id] + logger.info("Specify load the %d-th prompt: %s | %s" % (data_args.prompt_id, data_args.template, data_args.mapping)) + else: + if data_args.template_path is not None: + with open(data_args.template_path) as f: + data_args.template_list = [] + for line in f: + line = line.strip() + if len(line) > 0: + data_args.template_list.append(line) + + # Load top-n templates + if data_args.top_n_template is not None: + data_args.template_list = data_args.template_list[:data_args.top_n_template] + logger.info("Load top-%d templates from %s" % (len(data_args.template_list), data_args.template_path)) + + # ... or load i-th template + if data_args.template_id is not None: + data_args.template = data_args.template_list[data_args.template_id] + data_args.template_list = None + logger.info("Specify load the %d-th template: %s" % (data_args.template_id, data_args.template)) + + if data_args.mapping_path is not None: + assert data_args.mapping_id is not None # Only can use one label word mapping + with open(data_args.mapping_path) as f: + mapping_list = [] + for line in f: + line = line.strip() + mapping_list.append(line) + + data_args.mapping = mapping_list[data_args.mapping_id] + logger.info("Specify using the %d-th mapping: %s" % (data_args.mapping_id, data_args.mapping)) + + # Check save path + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError(f"Output directory ({training_args.output_dir}) already exists.") + + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + try: + num_labels = num_labels_mapping[data_args.task_name] + output_mode = output_modes_mapping[data_args.task_name] + logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode)) + except KeyError: + raise ValueError("Task not found: %s" % (data_args.task_name)) + + # Automatically generate template for using demonstrations + if data_args.auto_demo and model_args.few_shot_type == 'prompt-demo': + # GPT-3's in-context learning + if data_args.gpt3_in_context_head or data_args.gpt3_in_context_tail: + logger.info("Automatically convert the template 
to GPT-3's in-context learning.") + assert data_args.template_list is None + + old_template = data_args.template + new_template = old_template + '' + old_template = old_template.replace('*cls*', '') + # Single sentence or sentence pair? + sent_num = 1 + if "_1" in old_template: + sent_num = 2 + for instance_id in range(data_args.gpt3_in_context_num): + sub_template = old_template + '' + # Replace sent_id + for sent_id in range(sent_num): + sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * instance_id + sent_id)) + # Replace mask + sub_template = sub_template.replace("*mask*", "*labelx_{}*".format(instance_id)) + if data_args.gpt3_in_context_tail: + new_template = new_template + sub_template # Put context at the end + else: + new_template = sub_template + new_template # Put context at the beginning + logger.info("| {} => {}".format(data_args.template, new_template)) + data_args.template = new_template + else: + logger.info("Automatically convert the template to using demonstrations.") + if data_args.template_list is not None: + for i in range(len(data_args.template_list)): + old_template = data_args.template_list[i] + new_template = old_template + '' + old_template = old_template.replace('*cls*', '') + # Single sentence or sentence pair? + sent_num = 1 + if "_1" in old_template: + sent_num = 2 + for label_id in range(num_labels): + sub_template = old_template + '' + # Replace sent id + for sent_id in range(sent_num): + sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * label_id + sent_id)) + # Replace mask + sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id)) + new_template = new_template + sub_template + logger.info("| {} => {}".format(data_args.template_list[i], new_template)) + data_args.template_list[i] = new_template + else: + old_template = data_args.template + new_template = old_template + '' + old_template = old_template.replace('*cls*', '') + # Single sentence or sentence pair? + sent_num = 1 + if "_1" in old_template: + sent_num = 2 + for label_id in range(num_labels): + sub_template = old_template + '' + # Replace sent id + for sent_id in range(sent_num): + sub_template = sub_template.replace("_{}".format(sent_id), "_{}".format(sent_num + sent_num * label_id + sent_id)) + # Replace mask + sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id)) + new_template = new_template + sub_template + logger.info("| {} => {}".format(data_args.template, new_template)) + data_args.template = new_template + + # Create config + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + ) + + if 'prompt' in model_args.few_shot_type: + if config.model_type == 'roberta': + model_fn = RobertaForPromptFinetuning + elif config.model_type == 'bert': + model_fn = BertForPromptFinetuning + else: + raise NotImplementedError + elif model_args.few_shot_type == 'finetune': + model_fn = AutoModelForSequenceClassification + else: + raise NotImplementedError + special_tokens = [] + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + additional_special_tokens=special_tokens, + cache_dir=model_args.cache_dir, + ) + + # Get our special datasets. 
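+    # FewShotDataset builds the prompt inputs and, when use_demo is True,
+    # samples demonstrations to attach to each example ("demo" in --few_shot_type).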
+ train_dataset = ( + FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=("demo" in model_args.few_shot_type)) + ) + eval_dataset = ( + FewShotDataset(data_args, tokenizer=tokenizer, mode="dev", use_demo=("demo" in model_args.few_shot_type)) + if training_args.do_eval + else None + ) + test_dataset = ( + FewShotDataset(data_args, tokenizer=tokenizer, mode="test", use_demo=("demo" in model_args.few_shot_type)) + if training_args.do_predict + else None + ) + + set_seed(training_args.seed) + + model = model_fn.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # For BERT, increase the size of the segment (token type) embeddings + if config.model_type == 'bert': + model.resize_token_embeddings(len(tokenizer)) + resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment) + + # Pass dataset and argument information to the model + if data_args.prompt: + model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda() + if output_modes_mapping[data_args.task_name] == 'regression': + # lower / upper bounds + model.lb, model.ub = bound_mapping[data_args.task_name] + model.model_args = model_args + model.data_args = data_args + model.tokenizer = tokenizer + + # Build metric + def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: + def compute_metrics_fn(p: EvalPrediction): + # Note: the eval dataloader is sequential, so the examples are in order. + # We average the logits over each sample for using demonstrations. + predictions = p.predictions + num_logits = predictions.shape[-1] + logits = predictions.reshape([eval_dataset.num_sample, -1, num_logits]) + logits = logits.mean(axis=0) + + if num_logits == 1: + preds = np.squeeze(logits) + else: + preds = np.argmax(logits, axis=1) + + # Just for sanity, assert label ids are the same. 
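+            # (Each test example appears num_sample times in the dataloader,
+            # once per sampled demonstration set, so its label repeats across rows.)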
+ label_ids = p.label_ids.reshape([eval_dataset.num_sample, -1]) + label_ids_avg = label_ids.mean(axis=0) + label_ids_avg = label_ids_avg.astype(p.label_ids.dtype) + assert (label_ids_avg - label_ids[0]).mean() < 1e-2 + label_ids = label_ids[0] + + return compute_metrics_mapping[task_name](task_name, preds, label_ids) + + return compute_metrics_fn + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=build_compute_metrics_fn(data_args.task_name) + ) + + # Training + if training_args.do_train: + trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None) + # Use the early stop, so do not save the model in the end (unless specify save_at_last) + if training_args.save_at_last: + trainer.save_model(training_args.output_dir) + + if trainer.is_world_master(): + tokenizer.save_pretrained(training_args.output_dir) + torch.save(model_args, os.path.join(training_args.output_dir, "model_args.bin")) + torch.save(data_args, os.path.join(training_args.output_dir, "data_args.bin")) + + # Reload the best checkpoint (for eval) + model = model_fn.from_pretrained(training_args.output_dir) + model = model.to(training_args.device) + trainer.model = model + if data_args.prompt: + model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda() + if output_modes_mapping[data_args.task_name] == 'regression': + # lower / upper bounds + model.lb, model.ub = bound_mapping[data_args.task_name] + model.model_args = model_args + model.data_args = data_args + model.tokenizer = tokenizer + + # Evaluation + final_result = { + 'time': str(datetime.today()), + } + + eval_results = {} + if training_args.do_eval: + logger.info("*** Validate ***") + + eval_datasets = [eval_dataset] + + for eval_dataset in eval_datasets: + trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name) + output = trainer.evaluate(eval_dataset=eval_dataset) + eval_result = output.metrics + + output_eval_file = os.path.join( + training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" + ) + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) + for key, value in eval_result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + final_result[eval_dataset.args.task_name + '_dev_' + key] = value + eval_results.update(eval_result) + + test_results = {} + if training_args.do_predict: + logging.info("*** Test ***") + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") + test_datasets.append( + FewShotDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", use_demo=('demo' in model_args.few_shot_type)) + ) + + for test_dataset in test_datasets: + trainer.compute_metrics = build_compute_metrics_fn(test_dataset.args.task_name) + output = trainer.evaluate(eval_dataset=test_dataset) + test_result = output.metrics + + output_test_file = os.path.join( + training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" + ) + if trainer.is_world_master(): + with open(output_test_file, "w") as writer: + logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) + for key, value in test_result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) + 
+                        final_result[test_dataset.args.task_name + '_test_' + key] = value
+
+            if training_args.save_logit:
+                predictions = output.predictions
+                num_logits = predictions.shape[-1]
+                logits = predictions.reshape([test_dataset.num_sample, -1, num_logits]).mean(axis=0)
+                np.save(os.path.join(training_args.save_logit_dir, "{}-{}-{}.npy".format(test_dataset.task_name, training_args.model_id, training_args.array_id)), logits)
+
+            test_results.update(test_result)
+
+    with FileLock('log.lock'):
+        with open('log', 'a') as f:
+            final_result.update(vars(model_args))
+            final_result.update(vars(training_args))
+            final_result.update(vars(data_args))
+            if 'evaluation_strategy' in final_result:
+                final_result.pop('evaluation_strategy')
+            f.write(str(final_result) + '\n')
+
+    return eval_results
+
+if __name__ == "__main__":
+    main()
diff --git a/LM-BFF/src/dataset.py b/LM-BFF/src/dataset.py
new file mode 100755
index 0000000..4617608
--- /dev/null
+++ b/LM-BFF/src/dataset.py
@@ -0,0 +1,658 @@
+"""Dataset utils for different data settings for GLUE."""
+
+import os
+import copy
+import logging
+import torch
+import numpy as np
+import time
+from filelock import FileLock
+import json
+import itertools
+import random
+import transformers
+from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, median_mapping
+from transformers.data.processors.utils import InputFeatures
+from transformers import DataProcessor, InputExample
+import dataclasses
+from dataclasses import dataclass
+from typing import List, Optional, Union
+from sentence_transformers import SentenceTransformer, util
+from copy import deepcopy
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+@dataclass(frozen=True)
+class OurInputFeatures(InputFeatures):
+    """
+    Inherits from Transformers' InputFeatures.
+    """
+
+    input_ids: List[int]
+    attention_mask: Optional[List[int]] = None
+    token_type_ids: Optional[List[int]] = None
+    label: Optional[Union[int, float]] = None
+    mask_pos: Optional[List[int]] = None # Position of the mask token
+    label_word_list: Optional[List[int]] = None # Label word mapping (dynamic)
+
+    def to_json_string(self):
+        """Serializes this instance to a JSON string."""
+        return json.dumps(dataclasses.asdict(self)) + "\n"
+
+def input_example_to_string(example, sep_token):
+    if example.text_b is None:
+        return example.text_a
+    else:
+        # Warning: very simple hack here
+        return example.text_a + ' ' + sep_token + ' ' + example.text_b
+
+def input_example_to_tuple(example):
+    if example.text_b is None:
+        if pd.isna(example.text_a) or example.text_a is None:
+            logger.warn("Empty input")
+            return ['']
+        else:
+            return [example.text_a]
+    else:
+        return [example.text_a, example.text_b]
+
+def tokenize_multipart_input(
+    input_text_list,
+    max_length,
+    tokenizer,
+    task_name=None,
+    prompt=False,
+    template=None,
+    label_word_list=None,
+    first_sent_limit=None,
+    other_sent_limit=None,
+    gpt3=False,
+    truncate_head=False,
+    support_labels=None,
+):
+    def enc(text):
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    input_ids = []
+    attention_mask = []
+    token_type_ids = [] # Only for BERT
+    mask_pos = None # Position of the mask token
+
+    if prompt:
+        """
+        Concatenate all sentences and prompts based on the provided template.
+        Template example: '*cls*It was*mask*.*sent_0**sep+**label_0*:*sent_1**sep+**label_1*:*sent_2**sep+*'
+        Each *xx* denotes a variable:
+            *cls*: cls_token
+            *mask*: mask_token
+            *sep*: sep_token
+            *sep+*: sep_token, and also +1 on the segment id
+            *sent_i*: sentence i (input_text_list[i])
+            *sent-_i*: same as above, but drop the last token
+            *sentl_i*: same as above, but lower-case the first word
+            *sentl-_i*: same as above, but lower-case the first word and drop the last token
+            *sentu_i*: same as above, but upper-case the first word
+            *+sent_i*: same as above, but add a space before the sentence
+            *+sentl_i*: same as above, but add a space before the sentence and lower-case the first word
+            *+sentu_i*: same as above, but add a space before the sentence and upper-case the first word
+            *label_i*: label_word_list[i]
+            *labelx_i*: the label word of demonstration i (requires support_labels); only used for GPT-3-style in-context learning
+
+        Use "_" in place of a space inside literal text; it is converted back to a space below.
+        Pay attention to spaces: do NOT leave a space before a variable, or an extra space token will be produced.
+        """
+        assert template is not None
+
+        special_token_mapping = {
+            'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id,
+        }
+        template_list = template.split('*') # Get the variable list in the template
+        segment_id = 0 # Current segment id; incremented after each *sep+*
+
+        for part_id, part in enumerate(template_list):
+            new_tokens = []
+            segment_plus_1_flag = False
+            if part in special_token_mapping:
+                if part == 'cls' and 'T5' in type(tokenizer).__name__:
+                    # T5 does not have a cls token
+                    continue
+                new_tokens.append(special_token_mapping[part])
+                if part == 'sep+':
+                    segment_plus_1_flag = True
+            elif part[:6] == 'label_':
+                # label_word_list entries already carry the leading space, so do not add another one.
+                label_id = int(part.split('_')[1])
+                label_word = label_word_list[label_id]
+                new_tokens.append(label_word)
+            elif part[:7] == 'labelx_':
+                instance_id = int(part.split('_')[1])
+                label_id = support_labels[instance_id]
+                label_word = label_word_list[label_id]
+                new_tokens.append(label_word)
+            elif part[:5] == 'sent_':
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(input_text_list[sent_id])
+            elif part[:6] == '+sent_':
+                # Add a leading space
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(' ' + input_text_list[sent_id])
+            elif part[:6] == 'sent-_':
+                # Drop the last token
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(input_text_list[sent_id][:-1])
+            elif part[:6] == 'sentl_':
+                # Lower-case the first token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(text)
+            elif part[:7] == '+sentl_':
+                # Lower-case the first token and add a leading space
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(' ' + text)
+            elif part[:7] == 'sentl-_':
+                # Lower-case the first token and drop the last token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(text[:-1])
+            elif part[:6] == 'sentu_':
+                # Upper-case the first token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].upper() + text[1:]
+                new_tokens += enc(text)
+            elif part[:7] == '+sentu_':
+                # Upper-case the first token and add a leading space
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].upper() + text[1:]
+                new_tokens += enc(' ' + text)
+            else:
+                # Plain natural-language text in the prompt
+                part = part.replace('_', ' ')
+                # Handle a special case: the T5 tokenizer may add an extra space to a lone character
+                if len(part) == 1:
+                    new_tokens.append(tokenizer._convert_token_to_id(part))
+                else:
+                    new_tokens += enc(part)
+
+            if part[:4] == 'sent' or part[1:5] == 'sent':
+                # If this part is a sentence, enforce the per-sentence length limit
+                sent_id = int(part.split('_')[1])
+                if sent_id == 0:
+                    if first_sent_limit is not None:
+                        new_tokens = new_tokens[:first_sent_limit]
+                else:
+                    if other_sent_limit is not None:
+                        new_tokens = new_tokens[:other_sent_limit]
+
+            input_ids += new_tokens
+            attention_mask += [1 for i in range(len(new_tokens))]
+            token_type_ids += [segment_id for i in range(len(new_tokens))]
+
+            if segment_plus_1_flag:
+                segment_id += 1
+    else:
+        input_ids = [tokenizer.cls_token_id]
+        attention_mask = [1]
+        token_type_ids = [0]
+
+        for sent_id, input_text in enumerate(input_text_list):
+            if input_text is None:
+                # This example does not have text_b
+                continue
+            if pd.isna(input_text):
+                # Empty input
+                input_text = ''
+            input_tokens = enc(input_text) + [tokenizer.sep_token_id]
+            input_ids += input_tokens
+            attention_mask += [1 for i in range(len(input_tokens))]
+            token_type_ids += [sent_id for i in range(len(input_tokens))]
+
+        if 'T5' in type(tokenizer).__name__: # T5 does not have a CLS token
+            input_ids = input_ids[1:]
+            attention_mask = attention_mask[1:]
+            token_type_ids = token_type_ids[1:]
+
+    # Padding
+    if first_sent_limit is not None and len(input_ids) > max_length:
+        # Even with per-sentence limits, the total length can still exceed max_length; warn when it does
+        logger.warning("Input exceeds max_length limit: {}".format(tokenizer.decode(input_ids)))
+
+    while len(input_ids) < max_length:
+        input_ids.append(tokenizer.pad_token_id)
+        attention_mask.append(0)
+        token_type_ids.append(0)
+
+    # Truncation
+    if len(input_ids) > max_length:
+        if truncate_head:
+            input_ids = input_ids[-max_length:]
+            attention_mask = attention_mask[-max_length:]
+            token_type_ids = token_type_ids[-max_length:]
+        else:
+            # The default is to truncate the tail
+            input_ids = input_ids[:max_length]
+            attention_mask = attention_mask[:max_length]
+            token_type_ids = token_type_ids[:max_length]
+
+    # Find the mask token
+    if prompt:
+        mask_pos = [input_ids.index(tokenizer.mask_token_id)]
+        # Make sure the masked position is within max_length
+        assert mask_pos[0] < max_length
+
+    result = {'input_ids': input_ids, 'attention_mask': attention_mask}
+    if 'BERT' in type(tokenizer).__name__:
+        # Only provide token type ids for BERT
+        result['token_type_ids'] = token_type_ids
+
+    if prompt:
+        result['mask_pos'] = mask_pos
+
+    return result
+
+
+
+class FewShotDataset(torch.utils.data.Dataset):
+    """Few-shot dataset."""
+
+    def __init__(self, args, tokenizer, cache_dir=None, mode="train", use_demo=False):
+        self.args = args
+        self.task_name = args.task_name
+        self.processor = processors_mapping[args.task_name]
+        self.tokenizer = tokenizer
+        self.mode = mode
+
+        # use_demo=True means demonstrations are concatenated to every input
+        self.use_demo = use_demo
+        if self.use_demo:
+            logger.info("Use demonstrations")
+        assert mode in ["train", "dev", "test"]
+
+        # Get the label list and (for prompts) the label word list
+        self.label_list = self.processor.get_labels()
+        self.num_labels = len(self.label_list)
+        if args.prompt:
+            assert args.mapping is not None
+            self.label_to_word = eval(args.mapping)
+
+            for key in self.label_to_word:
+                # For RoBERTa/BART/T5, tokenization is space-sensitive, so we use space+word as the label word.
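+                # (Illustrative example, assuming the RoBERTa BPE tokenizer and the SST-2 mapping from the
+                # README, {'0':'terrible','1':'great'}: tokenizer.tokenize(' great') -> ['Ġgreat'], a single
+                # token, so it is a valid label word; a word that split into several subword tokens would
+                # break the single-*mask* assumption, which the assert below guards against.)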
+ if self.label_to_word[key][0] not in ['<', '[', '.', ',']: + # Make sure space+word is in the vocabulary + assert len(tokenizer.tokenize(' ' + self.label_to_word[key])) == 1 + self.label_to_word[key] = tokenizer._convert_token_to_id(tokenizer.tokenize(' ' + self.label_to_word[key])[0]) + else: + self.label_to_word[key] = tokenizer._convert_token_to_id(self.label_to_word[key]) + logger.info("Label {} to word {} ({})".format(key, tokenizer._convert_id_to_token(self.label_to_word[key]), self.label_to_word[key])) + + if len(self.label_list) > 1: + self.label_word_list = [self.label_to_word[label] for label in self.label_list] + else: + # Regression task + # '0' represents low polarity and '1' represents high polarity. + self.label_word_list = [self.label_to_word[label] for label in ['0', '1']] + else: + self.label_to_word = None + self.label_word_list = None + + # Multiple sampling: when using demonstrations, we sample different combinations of demonstrations during + # inference and aggregate the results by averaging the logits. The number of different samples is num_sample. + if (mode == "train") or not self.use_demo: + # We do not do multiple sampling when not using demonstrations or when it's the training mode + self.num_sample = 1 + else: + self.num_sample = args.num_sample + + # If we use multiple templates, we also need to do multiple sampling during inference. + if args.prompt and args.template_list is not None: + logger.info("There are %d templates. Multiply num_sample by %d" % (len(args.template_list), len(args.template_list))) + self.num_sample *= len(args.template_list) + + logger.info("Total num_sample for mode %s: %d" % (mode, self.num_sample)) + + # Load cache + # Cache name distinguishes mode, task name, tokenizer, and length. So if you change anything beyond these elements, make sure to clear your cache. + cached_features_file = os.path.join( + cache_dir if cache_dir is not None else args.data_dir, + "cached_{}_{}_{}_{}".format( + mode, + tokenizer.__class__.__name__, + str(args.max_seq_length), + args.task_name, + ), + ) + + logger.info(f"Creating/loading examples from dataset file at {args.data_dir}") + + lock_path = cached_features_file + ".lock" + with FileLock(lock_path): + + if os.path.exists(cached_features_file) and not args.overwrite_cache: + start = time.time() + self.support_examples, self.query_examples = torch.load(cached_features_file) + logger.info( + f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start + ) + else: + logger.info(f"Creating features from dataset file at {args.data_dir}") + + # The support examples are sourced from the training set. + self.support_examples = self.processor.get_train_examples(args.data_dir) + + if mode == "dev": + self.query_examples = self.processor.get_dev_examples(args.data_dir) + elif mode == "test": + self.query_examples = self.processor.get_test_examples(args.data_dir) + else: + self.query_examples = self.support_examples + + start = time.time() + torch.save([self.support_examples, self.query_examples], cached_features_file) + # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 
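+                # (torch.save here pickles the full InputExample lists, so the cost grows with the split
+                # size: the k-shot training files are tiny, but caching full-length dev/test splits can
+                # take noticeably longer.)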
+ logger.info( + "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start + ) + + # For filtering in using demonstrations, load pre-calculated embeddings + if self.use_demo and args.demo_filter: + split_name = '' + if mode == 'train': + split_name = 'train' + elif mode == 'dev': + if args.task_name == 'mnli': + split_name = 'dev_matched' + elif args.task_name == 'mnli-mm': + split_name = 'dev_mismatched' + else: + split_name = 'dev' + elif mode == 'test': + if args.task_name == 'mnli': + split_name = 'test_matched' + elif args.task_name == 'mnli-mm': + split_name = 'test_mismatched' + else: + split_name = 'test' + else: + raise NotImplementedError + + self.support_emb = np.load(os.path.join(args.data_dir, "train_{}.npy".format(args.demo_filter_model))) + self.query_emb = np.load(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model))) + logger.info("Load embeddings (for demonstration filtering) from {}".format(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model)))) + + assert len(self.support_emb) == len(self.support_examples) + assert len(self.query_emb) == len(self.query_examples) + + # Size is expanded by num_sample + self.size = len(self.query_examples) * self.num_sample + + # Prepare examples (especially for using demonstrations) + support_indices = list(range(len(self.support_examples))) + self.example_idx = [] + for sample_idx in range(self.num_sample): + for query_idx in range(len(self.query_examples)): + # If training, exclude the current example. Else keep all. + if self.use_demo and args.demo_filter: + # Demonstration filtering + candidate = [support_idx for support_idx in support_indices + if support_idx != query_idx or mode != "train"] + sim_score = [] + for support_idx in candidate: + sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx]))) + sim_score.sort(key=lambda x: x[1], reverse=True) + if self.num_labels == 1: + # Regression task + limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate) + count_each_label = {'0': 0, '1': 0} + context_indices = [] + + if args.debug_mode: + print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug + for support_idx, score in sim_score: + if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label: + count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1 + context_indices.append(support_idx) + if args.debug_mode: + print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug + else: + limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate) + count_each_label = {label: 0 for label in self.label_list} + context_indices = [] + + if args.debug_mode: + print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug + for support_idx, score in sim_score: + if count_each_label[self.support_examples[support_idx].label] < limit_each_label: + count_each_label[self.support_examples[support_idx].label] += 1 + context_indices.append(support_idx) + if args.debug_mode: + print(" %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug + else: + # Using demonstrations without filtering + context_indices = 
[support_idx for support_idx in support_indices
+                                       if support_idx != query_idx or mode != "train"]
+
+                # context_indices is subsampled again in select_context at feature-building time.
+                self.example_idx.append((query_idx, context_indices, sample_idx))
+
+        # If not training, pre-process all features now; otherwise, build them on the fly.
+        if mode != "train":
+            self.features = []
+            for feature_idx, (query_idx, context_indices, bootstrap_idx) in enumerate(self.example_idx):
+                # The input (query) example
+                example = self.query_examples[query_idx]
+                # The demonstrations
+                supports = self.select_context([self.support_examples[j] for j in context_indices])
+
+                if args.template_list is not None:
+                    template = args.template_list[bootstrap_idx % len(args.template_list)] # Cycle through the templates in order
+                else:
+                    template = args.template
+
+                self.features.append(self.convert_fn(
+                    example=example,
+                    supports=supports,
+                    use_demo=self.use_demo,
+                    label_list=self.label_list,
+                    prompt=args.prompt,
+                    template=template,
+                    label_word_list=self.label_word_list,
+                    verbose=(feature_idx == 0),
+                ))
+        else:
+            self.features = None
+
+    def select_context(self, context_examples):
+        """
+        Select demonstrations from the provided examples.
+        """
+        max_demo_per_label = 1
+        counts = {k: 0 for k in self.label_list}
+        if len(self.label_list) == 1:
+            # Regression
+            counts = {'0': 0, '1': 0}
+        selection = []
+
+        if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
+            # For GPT-3-style in-context learning, sample gpt3_in_context_num demonstrations at random.
+            order = np.random.permutation(len(context_examples))
+            for i in range(min(self.args.gpt3_in_context_num, len(order))):
+                selection.append(context_examples[order[i]])
+        else:
+            # Our sampling strategy: take at most max_demo_per_label demonstrations per label, in random order.
+            order = np.random.permutation(len(context_examples))
+
+            for i in order:
+                label = context_examples[i].label
+                if len(self.label_list) == 1:
+                    # Regression: map the score to a binary pseudo-label using the task median
+                    label = '0' if float(label) <= median_mapping[self.args.task_name] else '1'
+                if counts[label] < max_demo_per_label:
+                    selection.append(context_examples[i])
+                    counts[label] += 1
+                if sum(counts.values()) == len(counts) * max_demo_per_label:
+                    break
+
+        assert len(selection) > 0
+
+        return selection
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, i):
+        if self.features is None:
+            query_idx, context_indices, bootstrap_idx = self.example_idx[i]
+            # The input (query) example
+            example = self.query_examples[query_idx]
+            # The demonstrations
+            supports = self.select_context([self.support_examples[j] for j in context_indices])
+
+            if self.args.template_list is not None:
+                template = self.args.template_list[bootstrap_idx % len(self.args.template_list)]
+            else:
+                template = self.args.template
+
+            features = self.convert_fn(
+                example=example,
+                supports=supports,
+                use_demo=self.use_demo,
+                label_list=self.label_list,
+                prompt=self.args.prompt,
+                template=template,
+                label_word_list=self.label_word_list,
+                verbose=False,
+            )
+        else:
+            features = self.features[i]
+
+        return features
+
+    def get_labels(self):
+        return self.label_list
+
+
+    def convert_fn(
+        self,
+        example,
+        supports,
+        use_demo=False,
+        label_list=None,
+        prompt=False,
+        template=None,
+        label_word_list=None,
+        verbose=False
+    ):
+        """
+        Convert one (example, demonstrations) pair into a single processed OurInputFeatures.
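+        (When use_demo is set, the selected demonstrations are appended after the query sentences
+        and encoded together with it as one input sequence.)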
+ """ + max_length = self.args.max_seq_length + + # Prepare labels + label_map = {label: i for i, label in enumerate(label_list)} # Mapping the label names to label ids + if len(label_list) == 1: + # Regression + label_map = {'0': 0, '1': 1} + + # Get example's label id (for training/inference) + if example.label is None: + example_label = None + elif len(label_list) == 1: + # Regerssion + example_label = float(example.label) + else: + example_label = label_map[example.label] + + # Prepare other features + if not use_demo: + # No using demonstrations + inputs = tokenize_multipart_input( + input_text_list=input_example_to_tuple(example), + max_length=max_length, + tokenizer=self.tokenizer, + task_name=self.args.task_name, + prompt=prompt, + template=template, + label_word_list=label_word_list, + first_sent_limit=self.args.first_sent_limit, + other_sent_limit=self.args.other_sent_limit, + ) + features = OurInputFeatures(**inputs, label=example_label) + + else: + # Using demonstrations + + # Max length + if self.args.double_demo: + # When using demonstrations, double the maximum length + # Note that in this case, args.max_seq_length is the maximum length for a single sentence + max_length = max_length * 2 + if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail: + # When using GPT-3's in-context learning, take the maximum tokenization length of the model (512) + max_length = 512 + + # All input sentences, including the query and the demonstrations, are put into augmented_examples, + # and are numbered based on the order (starting from 0). For single sentence tasks, the input (query) + # is the sentence 0; for sentence-pair tasks, the input (query) is the sentence 0 and 1. Note that for GPT-3's + # in-context learning, the input (query) might be at the end instead of the beginning (gpt3_in_context_head) + augmented_example = [] + query_text = input_example_to_tuple(example) # Input sentence list for query + support_by_label = [[] for i in range(len(label_map))] + + if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail: + support_labels = [] + augmented_example = query_text + for support_example in supports: + augmented_example += input_example_to_tuple(support_example) + current_label = support_example.label + if len(label_list) == 1: + current_label = '0' if float(current_label) <= median_mapping[self.args.task_name] else '1' # Regression + support_labels.append(label_map[current_label]) + else: + # Group support examples by label + for label_name, label_id in label_map.items(): + if len(label_list) == 1: + # Regression + for support_example in filter(lambda s: ('0' if float(s.label) <= median_mapping[self.args.task_name] else '1') == label_name, supports): + support_by_label[label_id] += input_example_to_tuple(support_example) + else: + for support_example in filter(lambda s: s.label == label_name, supports): + support_by_label[label_id] += input_example_to_tuple(support_example) + + augmented_example = query_text + for label_id in range(len(label_map)): + augmented_example += support_by_label[label_id] + + # Tokenization (based on the template) + inputs = tokenize_multipart_input( + input_text_list=augmented_example, + max_length=max_length, + tokenizer=self.tokenizer, + task_name=self.args.task_name, + prompt=prompt, + template=template, + label_word_list=label_word_list, + first_sent_limit=self.args.first_sent_limit, + other_sent_limit=self.args.other_sent_limit, + truncate_head=self.args.truncate_head, + gpt3=self.args.gpt3_in_context_head or 
self.args.gpt3_in_context_tail, + support_labels=None if not (self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail) else support_labels + ) + features = OurInputFeatures(**inputs, label=example_label) + + if verbose: + logger.info("*** Example ***") + logger.info("guid: %s" % (example.guid)) + logger.info("features: %s" % features) + logger.info("text: %s" % self.tokenizer.decode(features.input_ids)) + + return features + + + diff --git a/LM-BFF/src/label_search.py b/LM-BFF/src/label_search.py new file mode 100755 index 0000000..e83abd3 --- /dev/null +++ b/LM-BFF/src/label_search.py @@ -0,0 +1,148 @@ +"""Automatic label search helpers.""" + +import itertools +import torch +import tqdm +import multiprocessing +import numpy as np +import scipy.spatial as spatial +import scipy.special as special +import scipy.stats as stats +import logging + +logger = logging.getLogger(__name__) + + +def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False): + """Pre-select likely words based on conditional likelihood.""" + indices = [] + if is_regression: + median = np.median(train_labels) + train_labels = (train_labels > median).astype(np.int) + num_labels = np.max(train_labels) + 1 + for idx in range(num_labels): + label_logits = train_logits[train_labels == idx] + scores = label_logits.mean(axis=0) + kept = [] + for i in np.argsort(-scores): + text = vocab[i] + if not text.startswith("Ġ"): + continue + kept.append(i) + indices.append(kept[:k_likely]) + return indices + + +def select_neighbors(distances, k_neighbors, valid): + """Select k nearest neighbors based on distance (filtered to be within the 'valid' set).""" + indices = np.argsort(distances) + neighbors = [] + for i in indices: + if i not in valid: + continue + neighbors.append(i) + if k_neighbors > 0: + return neighbors[:k_neighbors] + return neighbors + + +def init(train_logits, train_labels): + global logits, labels + logits = train_logits + labels = train_labels + + +def eval_pairing_acc(pairing): + global logits, labels + label_logits = np.take(logits, pairing, axis=-1) + preds = np.argmax(label_logits, axis=-1) + correct = np.sum(preds == labels) + return correct / len(labels) + + +def eval_pairing_corr(pairing): + global logits, labels + if pairing[0] == pairing[1]: + return -1 + label_logits = np.take(logits, pairing, axis=-1) + label_probs = special.softmax(label_logits, axis=-1)[:, 1] + pearson_corr = stats.pearsonr(label_probs, labels)[0] + return pearson_corr + + +def find_labels( + model, + train_logits, + train_labels, + seed_labels=None, + k_likely=1000, + k_neighbors=None, + top_n=-1, + vocab=None, + is_regression=False, +): + # Get top indices based on conditional likelihood using the LM. + likely_indices = select_likely_words( + train_logits=train_logits, + train_labels=train_labels, + k_likely=k_likely, + vocab=vocab, + is_regression=is_regression) + + logger.info("Top labels (conditional) per class:") + for i, inds in enumerate(likely_indices): + logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]])) + + # Convert to sets. + valid_indices = [set(inds) for inds in likely_indices] + + # If specified, further re-rank according to nearest neighbors of seed labels. + # Otherwise, keep ranking as is (based on conditional likelihood only). 
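+    # (Shape sketch, for orientation: train_logits is a [num_examples, vocab_size] array of logits at the
+    # mask position, and likely_indices[c] holds the k_likely vocabulary ids with the highest mean logit
+    # for class c, restricted to word-initial ('Ġ'-prefixed) tokens.)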
+ if seed_labels: + assert(vocab is not None) + seed_ids = [vocab.index(l) for l in seed_labels] + vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy() + seed_vecs = np.take(vocab_vecs, seed_ids, axis=0) + + # [num_labels, vocab_size] + label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine") + + # Establish label candidates (as k nearest neighbors). + label_candidates = [] + logger.info("Re-ranked by nearest neighbors:") + for i, distances in enumerate(label_distances): + label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i])) + logger.info("\t| Label: %s", seed_labels[i]) + logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]])) + else: + label_candidates = likely_indices + + # Brute-force search all valid pairings. + pairings = list(itertools.product(*label_candidates)) + + if is_regression: + eval_pairing = eval_pairing_corr + metric = "corr" + else: + eval_pairing = eval_pairing_acc + metric = "acc" + + # Score each pairing. + pairing_scores = [] + with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers: + with tqdm.tqdm(total=len(pairings)) as pbar: + chunksize = max(10, int(len(pairings) / 1000)) + for score in workers.imap(eval_pairing, pairings, chunksize=chunksize): + pairing_scores.append(score) + pbar.update() + + # Take top-n. + best_idx = np.argsort(-np.array(pairing_scores))[:top_n] + best_scores = [pairing_scores[i] for i in best_idx] + best_pairings = [pairings[i] for i in best_idx] + + logger.info("Automatically searched pairings:") + for i, indices in enumerate(best_pairings): + logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i]) + + return best_pairings diff --git a/LM-BFF/src/models.py b/LM-BFF/src/models.py new file mode 100755 index 0000000..bfdfcc3 --- /dev/null +++ b/LM-BFF/src/models.py @@ -0,0 +1,195 @@ +"""Custom models for few-shot learning specific operations.""" + +import torch +import torch.nn as nn +import transformers +from transformers.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead +from transformers.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead +from transformers.modeling_outputs import SequenceClassifierOutput + +import logging +logger = logging.getLogger(__name__) + +def resize_token_type_embeddings(model, new_num_types: int, random_segment: bool): + """ + Resize the segment (token type) embeddings for BERT + """ + if hasattr(model, 'bert'): + old_token_type_embeddings = model.bert.embeddings.token_type_embeddings + else: + raise NotImplementedError + new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1)) + if not random_segment: + new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data + + model.config.type_vocab_size = new_num_types + if hasattr(model, 'bert'): + model.bert.embeddings.token_type_embeddings = new_token_type_embeddings + else: + raise NotImplementedError + + +class BertForPromptFinetuning(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config) + self.init_weights() + + # These attributes should be assigned once the model is initialized + self.model_args = None + self.data_args = None + self.label_word_list 
= None + + # For regression + self.lb = None + self.ub = None + + # For label search. + self.return_full_softmax = None + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + mask_pos=None, + labels=None, + ): + batch_size = input_ids.size(0) + + if mask_pos is not None: + mask_pos = mask_pos.squeeze() + + # Encode everything + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids + ) + + # Get <mask> token representation + sequence_output, pooled_output = outputs[:2] + sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] + + # Logits over vocabulary tokens + prediction_mask_scores = self.cls(sequence_mask_output) + + # Exit early and only return mask logits. + if self.return_full_softmax: + if labels is not None: + return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores + return prediction_mask_scores + + # Return logits for each label + logits = [] + for label_id in range(len(self.label_word_list)): + logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) + logits = torch.cat(logits, -1) + + # Regression task + if self.config.num_labels == 1: + logsoftmax = nn.LogSoftmax(-1) + logits = logsoftmax(logits) # Log prob of right polarity + + loss = None + if labels is not None: + if self.num_labels == 1: + # Regression task + loss_fct = nn.KLDivLoss(log_target=True) + labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) + loss = loss_fct(logits.view(-1, 2), labels) + else: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + output = (logits,) + if self.num_labels == 1: + # Regression output + output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) + return ((loss,) + output) if loss is not None else output + + + +class RobertaForPromptFinetuning(BertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roberta = RobertaModel(config) + self.classifier = RobertaClassificationHead(config) + self.lm_head = RobertaLMHead(config) + self.init_weights() + + # These attributes should be assigned once the model is initialized + self.model_args = None + self.data_args = None + self.label_word_list = None + + # For regression + self.lb = None + self.ub = None + + # For auto label search. + self.return_full_softmax = None + + def forward( + self, + input_ids=None, + attention_mask=None, + mask_pos=None, + labels=None, + ): + batch_size = input_ids.size(0) + + if mask_pos is not None: + mask_pos = mask_pos.squeeze() + + # Encode everything + outputs = self.roberta( + input_ids, + attention_mask=attention_mask + ) + + # Get <mask> token representation + sequence_output, pooled_output = outputs[:2] + sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos] + + # Logits over vocabulary tokens + prediction_mask_scores = self.lm_head(sequence_mask_output) + + # Exit early and only return mask logits. 
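+        # (return_full_softmax is switched on by the automatic label search in src/label_search.py, which
+        # needs the full [batch_size, vocab_size] mask logits rather than the per-label-word logits below.)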
+ if self.return_full_softmax: + if labels is not None: + return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores + return prediction_mask_scores + + # Return logits for each label + logits = [] + for label_id in range(len(self.label_word_list)): + logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1)) + logits = torch.cat(logits, -1) + + # Regression task + if self.config.num_labels == 1: + logsoftmax = nn.LogSoftmax(-1) + logits = logsoftmax(logits) # Log prob of right polarity + + loss = None + if labels is not None: + if self.num_labels == 1: + # Regression task + loss_fct = nn.KLDivLoss(log_target=True) + labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1) + loss = loss_fct(logits.view(-1, 2), labels) + else: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + output = (logits,) + if self.num_labels == 1: + # Regression output + output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,) + return ((loss,) + output) if loss is not None else output diff --git a/LM-BFF/src/processors.py b/LM-BFF/src/processors.py new file mode 100755 index 0000000..b0da9a5 --- /dev/null +++ b/LM-BFF/src/processors.py @@ -0,0 +1,626 @@ +"""Dataset utils for different data settings for GLUE.""" + +import os +import copy +import logging +import torch +import numpy as np +import time +from filelock import FileLock +import json +import itertools +import random +import transformers +from transformers.data.processors.utils import InputFeatures +from transformers import DataProcessor, InputExample +from transformers.data.processors.glue import * +from transformers.data.metrics import glue_compute_metrics +import dataclasses +from dataclasses import dataclass, asdict +from typing import List, Optional, Union +from sentence_transformers import SentenceTransformer, util +from copy import deepcopy +import pandas as pd +import logging + +logger = logging.getLogger(__name__) + +class MrpcProcessor(DataProcessor): + """Processor for the MRPC data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence1"].numpy().decode("utf-8"), + tensor_dict["sentence2"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, i) + text_a = line[3] + text_b = line[4] + label = line[0] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class MnliProcessor(DataProcessor): + """Processor for the MultiNLI data set 
(GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["premise"].numpy().decode("utf-8"), + tensor_dict["hypothesis"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched") + + def get_labels(self): + """See base class.""" + return ["contradiction", "entailment", "neutral"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[8] + text_b = line[9] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class MnliMismatchedProcessor(MnliProcessor): + """Processor for the MultiNLI Mismatched data set (GLUE version).""" + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched") + + +class SnliProcessor(DataProcessor): + """Processor for the MultiNLI data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["premise"].numpy().decode("utf-8"), + tensor_dict["hypothesis"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["contradiction", "entailment", "neutral"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[7] + text_b = line[8] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class ColaProcessor(DataProcessor): + """Processor for the CoLA data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence"].numpy().decode("utf-8"), + None, + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return 
self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + test_mode = set_type == "test" + text_index = 3 + examples = [] + for (i, line) in enumerate(lines): + guid = "%s-%s" % (set_type, i) + text_a = line[text_index] + label = line[1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) + return examples + + +class Sst2Processor(DataProcessor): + """Processor for the SST-2 data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence"].numpy().decode("utf-8"), + None, + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + text_index = 0 + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, i) + text_a = line[text_index] + label = line[1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) + return examples + + +class StsbProcessor(DataProcessor): + """Processor for the STS-B data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence1"].numpy().decode("utf-8"), + tensor_dict["sentence2"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return [None] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[7] + text_b = line[8] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QqpProcessor(DataProcessor): + """Processor for the QQP data set (GLUE version).""" + + def 
get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["question1"].numpy().decode("utf-8"), + tensor_dict["question2"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + test_mode = set_type == "test" + q1_index = 3 + q2_index = 4 + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + try: + text_a = line[q1_index] + text_b = line[q2_index] + label = line[5] + except IndexError: + continue + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class QnliProcessor(DataProcessor): + """Processor for the QNLI data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["question"].numpy().decode("utf-8"), + tensor_dict["sentence"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class RteProcessor(DataProcessor): + """Processor for the RTE data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence1"].numpy().decode("utf-8"), + tensor_dict["sentence2"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return 
["entailment", "not_entailment"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + + +class WnliProcessor(DataProcessor): + """Processor for the WNLI data set (GLUE version).""" + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence1"].numpy().decode("utf-8"), + tensor_dict["sentence2"].numpy().decode("utf-8"), + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") + + def get_labels(self): + """See base class.""" + return ["0", "1"] + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + if i == 0: + continue + guid = "%s-%s" % (set_type, line[0]) + text_a = line[1] + text_b = line[2] + label = line[-1] + examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) + return examples + +class TextClassificationProcessor(DataProcessor): + """ + Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa). + """ + + def __init__(self, task_name): + self.task_name = task_name + + def get_example_from_tensor_dict(self, tensor_dict): + """See base class.""" + return InputExample( + tensor_dict["idx"].numpy(), + tensor_dict["sentence"].numpy().decode("utf-8"), + None, + str(tensor_dict["label"].numpy()), + ) + + def get_train_examples(self, data_dir): + """See base class.""" + return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"), header=None).values.tolist(), "train") + + def get_dev_examples(self, data_dir): + """See base class.""" + return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"), header=None).values.tolist(), "dev") + + def get_test_examples(self, data_dir): + """See base class.""" + return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"), header=None).values.tolist(), "test") + + def get_labels(self): + """See base class.""" + if self.task_name == "mr": + return list(range(2)) + elif self.task_name == "sst-5": + return list(range(5)) + elif self.task_name == "subj": + return list(range(2)) + elif self.task_name == "trec": + return list(range(6)) + elif self.task_name == "cr": + return list(range(2)) + elif self.task_name == "mpqa": + return list(range(2)) + else: + raise Exception("task_name not supported.") + + def _create_examples(self, lines, set_type): + """Creates examples for the training, dev and test sets.""" + examples = [] + for (i, line) in enumerate(lines): + guid = "%s-%s" % (set_type, i) + if self.task_name == "ag_news": + examples.append(InputExample(guid=guid, text_a=line[1] + '. 
' + line[2], short_text=line[1] + ".", label=line[0])) + elif self.task_name == "yelp_review_full": + examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0])) + elif self.task_name == "yahoo_answers": + text = line[1] + if not pd.isna(line[2]): + text += ' ' + line[2] + if not pd.isna(line[3]): + text += ' ' + line[3] + examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0])) + elif self.task_name in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']: + examples.append(InputExample(guid=guid, text_a=line[1], label=line[0])) + else: + raise Exception("Task_name not supported.") + + return examples + +def text_classification_metrics(task_name, preds, labels): + return {"acc": (preds == labels).mean()} + +# Add your task to the following mappings + +processors_mapping = { + "cola": ColaProcessor(), + "mnli": MnliProcessor(), + "mnli-mm": MnliMismatchedProcessor(), + "mrpc": MrpcProcessor(), + "sst-2": Sst2Processor(), + "sts-b": StsbProcessor(), + "qqp": QqpProcessor(), + "qnli": QnliProcessor(), + "rte": RteProcessor(), + "wnli": WnliProcessor(), + "snli": SnliProcessor(), + "mr": TextClassificationProcessor("mr"), + "sst-5": TextClassificationProcessor("sst-5"), + "subj": TextClassificationProcessor("subj"), + "trec": TextClassificationProcessor("trec"), + "cr": TextClassificationProcessor("cr"), + "mpqa": TextClassificationProcessor("mpqa") +} + +num_labels_mapping = { + "cola": 2, + "mnli": 3, + "mrpc": 2, + "sst-2": 2, + "sts-b": 1, + "qqp": 2, + "qnli": 2, + "rte": 2, + "wnli": 2, + "snli": 3, + "mr": 2, + "sst-5": 5, + "subj": 2, + "trec": 6, + "cr": 2, + "mpqa": 2 +} + +output_modes_mapping = { + "cola": "classification", + "mnli": "classification", + "mnli-mm": "classification", + "mrpc": "classification", + "sst-2": "classification", + "sts-b": "regression", + "qqp": "classification", + "qnli": "classification", + "rte": "classification", + "wnli": "classification", + "snli": "classification", + "mr": "classification", + "sst-5": "classification", + "subj": "classification", + "trec": "classification", + "cr": "classification", + "mpqa": "classification" +} + +# Return a function that takes (task_name, preds, labels) as inputs +compute_metrics_mapping = { + "cola": glue_compute_metrics, + "mnli": glue_compute_metrics, + "mnli-mm": glue_compute_metrics, + "mrpc": glue_compute_metrics, + "sst-2": glue_compute_metrics, + "sts-b": glue_compute_metrics, + "qqp": glue_compute_metrics, + "qnli": glue_compute_metrics, + "rte": glue_compute_metrics, + "wnli": glue_compute_metrics, + "snli": text_classification_metrics, + "mr": text_classification_metrics, + "sst-5": text_classification_metrics, + "subj": text_classification_metrics, + "trec": text_classification_metrics, + "cr": text_classification_metrics, + "mpqa": text_classification_metrics, +} + +# For regression task only: median +median_mapping = { + "sts-b": 2.5 +} + +bound_mapping = { + "sts-b": (0, 5) +} diff --git a/LM-BFF/src/trainer.py b/LM-BFF/src/trainer.py new file mode 100755 index 0000000..b22a527 --- /dev/null +++ b/LM-BFF/src/trainer.py @@ -0,0 +1,473 @@ +########## The following part is copied from Transformers' trainer (3.4.0) ########## + +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import collections +import inspect +import math +import os +import re +import shutil +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from packaging import version +from torch import nn +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler, SequentialSampler + +import transformers +from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from transformers.file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available +from transformers.integrations import ( + default_hp_search_backend, + is_comet_available, + is_optuna_available, + is_ray_available, + is_tensorboard_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, +) +from transformers.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING +from transformers.modeling_utils import PreTrainedModel +from transformers.optimization import AdamW, get_linear_schedule_with_warmup +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.trainer_pt_utils import ( + DistributedTensorGatherer, + SequentialDistributedSampler, + distributed_broadcast_scalars, + distributed_concat, + get_tpu_sampler, + nested_concat, + nested_detach, + nested_numpify, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalPrediction, + HPSearchBackend, + PredictionOutput, + TrainOutput, + default_compute_objective, + default_hp_space, + set_seed, +) +from transformers.training_args import TrainingArguments +from transformers.utils import logging + +from tqdm import tqdm, trange + +_use_native_amp = False +_use_apex = False + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from transformers.utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex +if version.parse(torch.__version__) < version.parse("1.6"): + from transformers.file_utils import is_apex_available + + if is_apex_available(): + from apex import amp + _use_apex = True +else: + _use_native_amp = True + from torch.cuda.amp import autocast + +if version.parse(torch.__version__) < version.parse("1.2"): + _use_ddp_no_sync = False +else: + _use_ddp_no_sync = True + +if is_datasets_available(): + import datasets + +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import 
torch_xla.distributed.parallel_loader as pl
+
+if is_tensorboard_available():
+    from transformers.integrations import TensorBoardCallback
+
+    DEFAULT_CALLBACKS.append(TensorBoardCallback)
+
+
+if is_wandb_available():
+    from transformers.integrations import WandbCallback
+
+    DEFAULT_CALLBACKS.append(WandbCallback)
+
+if is_comet_available():
+    from transformers.integrations import CometCallback
+
+    DEFAULT_CALLBACKS.append(CometCallback)
+
+if is_optuna_available():
+    import optuna
+
+if is_ray_available():
+    from ray import tune
+
+logger = logging.get_logger(__name__)
+
+########## The above part is copied from Transformers' trainer (3.4.0) ##########
+
+def default_dev_objective(metrics):
+    """
+    Objective used to pick the best model on the development set
+    """
+    if "eval_mnli/acc" in metrics:
+        return metrics["eval_mnli/acc"]
+    elif "eval_mnli-mm/acc" in metrics:
+        return metrics["eval_mnli-mm/acc"]
+    elif "eval_f1" in metrics:
+        return metrics["eval_f1"]
+    elif "eval_mcc" in metrics:
+        return metrics["eval_mcc"]
+    elif "eval_pearson" in metrics:
+        return metrics["eval_pearson"]
+    elif "eval_acc" in metrics:
+        return metrics["eval_acc"]
+
+    raise Exception("No metric found for {}".format(metrics))
+
+class Trainer(transformers.Trainer):
+    """
+    Adds a few functions on top of Transformers' Trainer class.
+    """
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """
+        Based on Transformers' default implementation, with an added layer-fixing option: the parameters
+        of the bottom fix_layers transformer layers (and the embeddings) are frozen, and only the top
+        layers are fine-tuned.
+        """
+        if self.optimizer is None:
+            params = {}
+            for n, p in self.model.named_parameters():
+                if self.args.fix_layers > 0:
+                    if 'encoder.layer' in n:
+                        try:
+                            layer_num = int(n[n.find('encoder.layer') + 14:].split('.')[0])
+                        except ValueError:
+                            raise ValueError("Could not parse the layer number from parameter name: {}".format(n))
+                        if layer_num >= self.args.fix_layers:
+                            print('yes', n)
+                            params[n] = p
+                        else:
+                            print('no ', n)
+                    elif 'embeddings' in n:
+                        print('no ', n)
+                    else:
+                        print('yes', n)
+                        params[n] = p
+                else:
+                    params[n] = p
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
+            self.optimizer = AdamW(
+                optimizer_grouped_parameters,
+                lr=self.args.learning_rate,
+                betas=(self.args.adam_beta1, self.args.adam_beta2),
+                eps=self.args.adam_epsilon,
+            )
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+            )
+
+    def train(self, model_path=None, dev_objective=None):
+        """
+        Main training entry point.
+
+        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
+        Adds evaluation during training: the checkpoint with the best development-set objective is kept.
+        """
+        self.best_dir = None
+        self.objective = -float("inf")
+        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective
+
+        # Data loading.
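+        # (t_total below counts optimizer updates: batches per epoch divided by
+        # gradient_accumulation_steps, times the number of epochs, unless max_steps overrides it.)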
+ train_dataloader = self.get_train_dataloader() + num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps + if num_update_steps_per_epoch == 0: + num_update_steps_per_epoch = 1 + if self.args.max_steps > 0: + t_total = self.args.max_steps + num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( + self.args.max_steps % num_update_steps_per_epoch > 0 + ) + else: + t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) + num_train_epochs = self.args.num_train_epochs + + self.create_optimizer_and_scheduler(num_training_steps=t_total) + optimizer = self.optimizer + scheduler = self.lr_scheduler + + # Check if saved optimizer or scheduler states exist + if ( + model_path is not None + and os.path.isfile(os.path.join(model_path, "optimizer.pt")) + and os.path.isfile(os.path.join(model_path, "scheduler.pt")) + ): + # Load in optimizer and scheduler states + optimizer.load_state_dict( + torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device) + ) + scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) + + model = self.model + + if self.args.fp16 and _use_apex: + if not transformers.is_apex_available(): + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) + if self.args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if self.args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[self.args.local_rank], + output_device=self.args.local_rank, + find_unused_parameters=True, + ) + + # Train + if transformers.is_torch_tpu_available(): + total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() + else: + total_train_batch_size = ( + self.args.train_batch_size + * self.args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) + ) + logger.info("***** Running training *****") + logger.info(" Num examples = %d", self.num_examples(train_dataloader)) + logger.info(" Num Epochs = %d", num_train_epochs) + logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size) + logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", total_train_batch_size) + logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + self.global_step = 0 + self.epoch = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + # Check if continuing training from a checkpoint + if model_path is not None: + # set global_step to global_step of last saved checkpoint from model path + try: + self.global_step = int(model_path.split("-")[-1].split("/")[0]) + epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) + steps_trained_in_current_epoch = self.global_step % ( + len(train_dataloader) // self.args.gradient_accumulation_steps + ) + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(" Continuing training from epoch %d", epochs_trained) + logger.info(" Continuing training from global step %d", self.global_step) + logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + except ValueError: + self.global_step = 0 + logger.info(" Starting fine-tuning.") + + tr_loss = torch.tensor(0.0).to(self.args.device) + logging_loss_scalar = 0.0 + model.zero_grad() + train_iterator = trange( + epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master() + ) + for epoch in train_iterator: + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + + if transformers.is_torch_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( + self.args.device + ) + epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master()) + else: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) + + # Reset the past mems state at the beginning of each epoch if necessary. 
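+            # (`past_index` is only >= 0 for models that reuse past hidden states
+            # across steps, e.g. XLNet; `_past` is cleared here so no state leaks
+            # across epochs.)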
+            if self.args.past_index >= 0:
+                self._past = None
+
+            for step, inputs in enumerate(epoch_iterator):
+
+                # Skip past any already trained steps if resuming training
+                if steps_trained_in_current_epoch > 0:
+                    steps_trained_in_current_epoch -= 1
+                    continue
+
+                tr_loss += self.training_step(model, inputs)
+
+                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
+                    # last step in epoch but step is always smaller than gradient_accumulation_steps
+                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
+                    and (step + 1) == len(epoch_iterator)
+                ):
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.unscale_(optimizer)
+                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    elif self.args.fp16:
+                        norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
+                    else:
+                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+
+                    if transformers.is_torch_tpu_available():
+                        xm.optimizer_step(optimizer)
+                    elif self.args.fp16 and _use_native_amp:
+                        self.scaler.step(optimizer)
+                        self.scaler.update()
+                    else:
+                        optimizer.step()
+
+                    scheduler.step()
+                    model.zero_grad()
+                    self.global_step += 1
+                    self.epoch = epoch + (step + 1) / len(epoch_iterator)
+
+                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
+                        self.global_step == 1 and self.args.logging_first_step
+                    ):
+                        logs = {}
+                        tr_loss_scalar = tr_loss.item()
+                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
+                        logs["norm"] = norm.item()
+                        # backward compatibility for pytorch schedulers
+                        logs["learning_rate"] = (
+                            scheduler.get_last_lr()[0]
+                            if version.parse(torch.__version__) >= version.parse("1.4")
+                            else scheduler.get_lr()[0]
+                        )
+                        logging_loss_scalar = tr_loss_scalar
+
+                        self.log(logs)
+
+                    # ----------------------------------------------------------------------
+                    # BEGIN CHANGES.
+                    # ----------------------------------------------------------------------
+
+                    metrics = None
+                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
+                        output = self.evaluate()
+                        metrics = output.metrics
+                        objective = self.dev_objective(metrics)
+                        if objective > self.objective:
+                            logger.info("Best dev result: {}".format(objective))
+                            self.objective = objective
+                            self.save_model(self.args.output_dir)
+
+                    # ----------------------------------------------------------------------
+                    # END CHANGES.
+                    # ----------------------------------------------------------------------
+
+
+                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+                    epoch_iterator.close()
+                    break
+            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+                train_iterator.close()
+                break
+            if self.args.tpu_metrics_debug or self.args.debug:
+                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+                xm.master_print(met.metrics_report())
+
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
+        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+        return TrainOutput(self.global_step, tr_loss / self.global_step), self.objective
+
+    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
+        """
+        Run evaluation and return metrics.
+
+        Difference compared to the original implementation: the full prediction
+        output is returned instead of output.metrics, so the logits are also available.
+ + The calling script will be responsible for providing a method to compute metrics, as they are + task-dependent (pass it to the init :obj:`compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (:obj:`Dataset`, `optional`): + Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, + columns not accepted by the ``model.forward()`` method are automatically removed. It must implement + the :obj:`__len__` method. + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. + """ + if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): + raise ValueError("eval_dataset must implement __len__") + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + output = self.prediction_loop(eval_dataloader, description="Evaluation") + + self.log(output.metrics) + + if self.args.tpu_metrics_debug or self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + return output diff --git a/LM-BFF/tools/ensemble.py b/LM-BFF/tools/ensemble.py new file mode 100755 index 0000000..8f914e1 --- /dev/null +++ b/LM-BFF/tools/ensemble.py @@ -0,0 +1,289 @@ +import argparse +import pandas as pd +import json +import numpy as np +import torch +import os +from torch import device +from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset +from transformers import GlueDataTrainingArguments, glue_compute_metrics +from transformers.data.metrics import simple_accuracy +from transformers.data.processors.glue import glue_processors + +def get_glue_label(task, line): + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + line = line.strip().split('\t') + if task == 'CoLA': + return line[1] + elif task == 'MNLI': + return line[-1] + elif task == 'MRPC': + return line[0] + elif task == 'QNLI': + return line[-1] + elif task == 'QQP': + return line[-1] + elif task == 'RTE': + return line[-1] + elif task == 'SNLI': + return line[-1] + elif task == 'SST-2': + return line[-1] + elif task == 'STS-B': + return 0 if float(line[-1]) < 2.5 else 1 + elif task == 'WNLI': + return line[-1] + else: + raise NotImplementedError + else: + raise NotImplementedError + +def get_labels(data_dir, k, seed, task, print_name): + if print_name in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']: + data = pd.read_csv(os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.csv'), header=None).values.tolist() + labels = np.zeros((len(data)), dtype=np.uint8) + for i, example in enumerate(data): + labels[i] = example[0] + elif print_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + lines = [] + file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.tsv') + if task == 'mnli': + file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_matched.tsv') + elif task == 'mnli-mm': + file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_mismatched.tsv') + + for line in open(file_name): + lines.append(line.strip()) + + if task != 'cola': + lines = lines[1:] + label_list = glue_processors[task]().get_labels() + label_map = {k: i for i, k in enumerate(label_list)} + if task == 'sts-b': + label_map = {0: 0, 1: 1} + label_ids = np.zeros((len(lines))) + for line_id, 
line in enumerate(lines): + label_ids[line_id] = label_map[get_glue_label(print_name, line)] + return label_ids + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--n_models", type=int, help="Number of models") + parser.add_argument("--k", type=int, default=16, help="Number of training instances per label") + parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") + + # These options should usually be kept as their default values + parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Data directory") + parser.add_argument("--save_logit_dir", type=str, default="ensemble_predict_results", help="Directory to store the logit file.") + parser.add_argument("--log", type=str, default="log", help="Log path.") + parser.add_argument("--key", type=str, default='', help="Validation metric name") + parser.add_argument("--test_key", type=str, default="", help="Test metric name") + parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") + + args = parser.parse_args() + + condition = eval(args.condition) + + if len(args.key) == 0: + if condition['task_name'] == 'cola': + args.key = 'cola_dev_eval_mcc' + args.test_key = 'cola_test_eval_mcc' + elif condition['task_name'] == 'mrpc/acc': + args.key = 'mrpc_dev_eval_acc' + args.test_key = 'mrpc_test_eval_acc' + args.test_key2 = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + elif condition['task_name'] == 'mrpc/f1': + args.key = 'mrpc_dev_eval_f1' + args.test_key2 = 'mrpc_test_eval_acc' + args.test_key = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + elif condition['task_name'] == 'qqp/acc': + args.key = 'qqp_dev_eval_acc' + args.test_key = 'qqp_test_eval_acc' + args.test_key2 = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + elif condition['task_name'] == 'qqp/f1': + args.key = 'qqp_dev_eval_f1' + args.test_key2 = 'qqp_test_eval_acc' + args.test_key = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + elif condition['task_name'] == 'sts-b/pearson': + args.key = 'sts-b_dev_eval_pearson' + args.test_key = 'sts-b_test_eval_pearson' + args.test_key2 = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + elif condition['task_name'] == 'sts-b/spearmanr': + args.key = 'sts-b_dev_eval_spearmanr' + args.test_key2 = 'sts-b_test_eval_pearson' + args.test_key = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + elif condition['task_name'] == 'qnli': + args.key = 'qnli_dev_eval_acc' + args.test_key = 'qnli_test_eval_acc' + elif condition['task_name'] == 'sst-2': + args.key = 'sst-2_dev_eval_acc' + args.test_key = 'sst-2_test_eval_acc' + elif condition['task_name'] == 'snli': + args.key = 'snli_dev_eval_acc' + args.test_key = 'snli_test_eval_acc' + elif condition['task_name'] == 'mnli': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli_test_eval_mnli/acc' + elif condition['task_name'] == 'mnli-mm': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' + elif condition['task_name'] == 'rte': + args.key = 'rte_dev_eval_acc' + args.test_key = 'rte_test_eval_acc' + elif condition['task_name'] == 'ag_news': + args.key = 'ag_news_dev_eval_acc' + args.test_key = 'ag_news_test_eval_acc' + elif condition['task_name'] == 'yahoo_answers': + args.key = 'yahoo_answers_dev_eval_acc' + args.test_key = 'yahoo_answers_test_eval_acc' + elif condition['task_name'] == 'yelp_review_full': + args.key = 
'yelp_review_full_dev_eval_acc' + args.test_key = 'yelp_review_full_test_eval_acc' + elif condition['task_name'] == 'mr': + args.key = 'mr_dev_eval_acc' + args.test_key = 'mr_test_eval_acc' + elif condition['task_name'] == 'sst-5': + args.key = 'sst-5_dev_eval_acc' + args.test_key = 'sst-5_test_eval_acc' + elif condition['task_name'] == 'subj': + args.key = 'subj_dev_eval_acc' + args.test_key = 'subj_test_eval_acc' + elif condition['task_name'] == 'trec': + args.key = 'trec_dev_eval_acc' + args.test_key = 'trec_test_eval_acc' + elif condition['task_name'] == 'cr': + args.key = 'cr_dev_eval_acc' + args.test_key = 'cr_test_eval_acc' + elif condition['task_name'] == 'mpqa': + args.key = 'mpqa_dev_eval_acc' + args.test_key = 'mpqa_test_eval_acc' + else: + raise NotImplementedError + + with open(args.log) as f: + result_list = [] + for line in f: + result_list.append(eval(line)) + + seed_result = {} + seed_best = {} + + # Gather all logs satisfying the conditions + for item in result_list: + ok = True + for cond in condition: + if cond == 'task_name' and condition['task_name'] == 'mnli-mm': + if cond not in item or item[cond] != 'mnli': + ok = False + break + else: + if cond not in item or item[cond] != condition[cond]: + ok = False + break + if 'model_id' not in item or 'array_id' not in item: + ok = False + + if ok: + seed = int(item['data_dir'].split('-')[-1]) + model_id = item['model_id'] + array_id = item['array_id'] + + if model_id >= 0 and model_id < args.n_models: + if seed not in seed_result: + seed_result[seed] = {} + seed_best[seed] = {} + if model_id not in seed_result[seed]: + seed_result[seed][model_id] = [] + seed_best[seed][model_id] = {args.key: -1e9} + + seed_result[seed][model_id].append(item) + if item[args.key] > seed_best[seed][model_id][args.key]: + seed_best[seed][model_id] = item + + final_result_dev = np.zeros((len(seed_result), args.n_models)) + final_result_test = np.zeros((len(seed_result), args.n_models)) + final_result_test2 = np.zeros((len(seed_result), args.n_models)) + + logit_file_list = {} + for seed in seed_result: + logit_file_list[seed] = [] + + # Get the results for each model and pick the best dev trial for each model/seed + for model_id in range(args.n_models): + for i, seed in enumerate(seed_result): + final_result_dev[i][model_id] = seed_best[seed][model_id][args.key] + final_result_test[i][model_id] = seed_best[seed][model_id][args.test_key] + if len(args.test_key2) > 0: + final_result_test2[i][model_id] = seed_best[seed][model_id][args.test_key2] + + logit_file_list[seed].append("{}-{}-{}.npy".format(condition['task_name'], model_id, seed_best[seed][model_id]["array_id"])) + + s = "Model %d | val: mean +- std: %.1f +- %.1f | test: mean +- std: %.1f (%.1f) (median %.1f)" % (model_id, final_result_dev[:, model_id].mean() * 100, final_result_dev[:, model_id].std() * 100, final_result_test[:, model_id].mean() * 100, final_result_test[:, model_id].std() * 100, np.median(final_result_test[:, model_id]) * 100) + if len(args.test_key2) > 0: + s += " / %.1f +- %.1f (median %.1f)" % (final_result_test2[:, model_id].mean() * 100, final_result_test2[:, model_id].std() * 100, np.median(final_result_test2[:, model_id]) * 100) + print(s) + + # Map lower-case names to official names (data folder name) + data_dir_mapping = { + 'cola': 'CoLA', + 'mrpc': 'MRPC', + 'qqp': 'QQP', + 'sts-b': 'STS-B', + 'sst-2': 'SST-2', + 'snli': 'SNLI', + 'mnli': 'MNLI', + 'mnli-mm': 'MNLI', + 'rte': 'RTE', + 'ag_news': 'ag_news', + 'yahoo_answers': 'yahoo_answers', + 
'yelp_review_full': 'yelp_review_full', + 'sst-5': 'sst-5', + 'mr': 'mr', + 'cr': 'cr', + 'mpqa': 'mpqa', + 'subj': 'subj', + 'trec': 'trec' + } + + tokenizer = AutoTokenizer.from_pretrained('roberta-large') + ensemble_result = np.zeros((len(seed_result))) + ensemble_result2 = np.zeros((len(seed_result))) # for second metric + + # Ensemble for each seed + for seed_id, seed in enumerate(seed_result): + labels = get_labels(args.data_dir, args.k, seed, condition['task_name'], data_dir_mapping[condition['task_name']]) + + # Logits + mean_logits = None + for fname in logit_file_list[seed]: + logits = np.load(os.path.join(args.save_logit_dir, fname)) + if mean_logits is None: + mean_logits = logits + else: + mean_logits += logits + mean_logits /= len(logit_file_list[seed]) + + # Compute metrics + preds = mean_logits.argmax(-1) + if condition['task_name'] in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']: + metric = {"acc": simple_accuracy(preds, labels)} + else: + metric = glue_compute_metrics(condition['task_name'], preds, labels) + + ensemble_result[seed_id] = metric[args.test_key.split('_')[-1]] + if len(args.test_key2) > 0: + ensemble_result2[seed_id] = metric[args.test_key2.split('_')[-1]] + + s = "mean +- std: %.1f (%.1f) (median %.1f)" % (ensemble_result.mean() * 100, ensemble_result.std() * 100, np.median(ensemble_result) * 100) + if len(args.test_key2) > 0: + s += " / %.1f (%.1f) (median %.1f)" % (ensemble_result2.mean() * 100, ensemble_result2.std() * 100, np.median(ensemble_result2) * 100) + print(s) + +if __name__ == '__main__': + main() diff --git a/LM-BFF/tools/gather_result.py b/LM-BFF/tools/gather_result.py new file mode 100755 index 0000000..83c947d --- /dev/null +++ b/LM-BFF/tools/gather_result.py @@ -0,0 +1,160 @@ +import argparse +import json +import numpy as np +import torch +from torch import device + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") + + # These options should be kept as their default values + parser.add_argument("--log", type=str, default="log", help="Log path.") + parser.add_argument("--key", type=str, default='', help="Validation metric name") + parser.add_argument("--test_key", type=str, default="", help="Test metric name") + parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") + + args = parser.parse_args() + + condition = eval(args.condition) + + if len(args.key) == 0: + if condition['task_name'] == 'cola': + args.key = 'cola_dev_eval_mcc' + args.test_key = 'cola_test_eval_mcc' + elif condition['task_name'] == 'mrpc/acc': + args.key = 'mrpc_dev_eval_acc' + args.test_key = 'mrpc_test_eval_acc' + args.test_key2 = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + elif condition['task_name'] == 'mrpc/f1': + args.key = 'mrpc_dev_eval_f1' + args.test_key2 = 'mrpc_test_eval_acc' + args.test_key = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + elif condition['task_name'] == 'qqp/acc': + args.key = 'qqp_dev_eval_acc' + args.test_key = 'qqp_test_eval_acc' + args.test_key2 = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + elif condition['task_name'] == 'qqp/f1': + args.key = 'qqp_dev_eval_f1' + args.test_key2 = 'qqp_test_eval_acc' + args.test_key = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + elif condition['task_name'] == 'sts-b/pearson': + args.key = 'sts-b_dev_eval_pearson' + args.test_key = 'sts-b_test_eval_pearson' + 
args.test_key2 = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + elif condition['task_name'] == 'sts-b/spearmanr': + args.key = 'sts-b_dev_eval_spearmanr' + args.test_key2 = 'sts-b_test_eval_pearson' + args.test_key = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + elif condition['task_name'] == 'qnli': + args.key = 'qnli_dev_eval_acc' + args.test_key = 'qnli_test_eval_acc' + elif condition['task_name'] == 'sst-2': + args.key = 'sst-2_dev_eval_acc' + args.test_key = 'sst-2_test_eval_acc' + elif condition['task_name'] == 'snli': + args.key = 'snli_dev_eval_acc' + args.test_key = 'snli_test_eval_acc' + elif condition['task_name'] == 'mnli': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli_test_eval_mnli/acc' + elif condition['task_name'] == 'mnli-mm': + condition['task_name'] = 'mnli' + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' + elif condition['task_name'] == 'rte': + args.key = 'rte_dev_eval_acc' + args.test_key = 'rte_test_eval_acc' + elif condition['task_name'] == 'ag_news': + args.key = 'ag_news_dev_eval_acc' + args.test_key = 'ag_news_test_eval_acc' + elif condition['task_name'] == 'yahoo_answers': + args.key = 'yahoo_answers_dev_eval_acc' + args.test_key = 'yahoo_answers_test_eval_acc' + elif condition['task_name'] == 'yelp_review_full': + args.key = 'yelp_review_full_dev_eval_acc' + args.test_key = 'yelp_review_full_test_eval_acc' + elif condition['task_name'] == 'mr': + args.key = 'mr_dev_eval_acc' + args.test_key = 'mr_test_eval_acc' + elif condition['task_name'] == 'sst-5': + args.key = 'sst-5_dev_eval_acc' + args.test_key = 'sst-5_test_eval_acc' + elif condition['task_name'] == 'subj': + args.key = 'subj_dev_eval_acc' + args.test_key = 'subj_test_eval_acc' + elif condition['task_name'] == 'trec': + args.key = 'trec_dev_eval_acc' + args.test_key = 'trec_test_eval_acc' + elif condition['task_name'] == 'cr': + args.key = 'cr_dev_eval_acc' + args.test_key = 'cr_test_eval_acc' + elif condition['task_name'] == 'mpqa': + args.key = 'mpqa_dev_eval_acc' + args.test_key = 'mpqa_test_eval_acc' + else: + raise NotImplementedError + + with open(args.log) as f: + result_list = [] + for line in f: + result_list.append(eval(line)) + + seed_result = {} + seed_best = {} + + for item in result_list: + ok = True + for cond in condition: + if isinstance(condition[cond], list): + if cond not in item or (item[cond] not in condition[cond]): + ok = False + break + else: + if cond not in item or (item[cond] != condition[cond]): + ok = False + break + + if ok: + seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed']) + if seed not in seed_result: + seed_result[seed] = [item] + seed_best[seed] = item + else: + seed_result[seed].append(item) + if item[args.key] > seed_best[seed][args.key]: + seed_best[seed] = item + + final_result_dev = np.zeros((len(seed_best))) + final_result_test = np.zeros((len(seed_best))) + final_result_test2 = np.zeros((len(seed_best))) + for i, seed in enumerate(seed_best): + final_result_dev[i] = seed_best[seed][args.key] + final_result_test[i] = seed_best[seed][args.test_key] + if len(args.test_key2) > 0: + final_result_test2[i] = seed_best[seed][args.test_key2] + print("%s: best dev (%.5f) test (%.5f) %s | total trials: %d" % ( + seed, + seed_best[seed][args.key], + seed_best[seed][args.test_key], + "test2 (%.5f)" % (seed_best[seed][args.test_key2]) if len(args.test_key2) > 0 else "", + len(seed_result[seed]) + )) + s = '' + for k in ['per_device_train_batch_size', 
'gradient_accumulation_steps', 'learning_rate', 'eval_steps', 'max_steps']: + s += '| {}: {} '.format(k, seed_best[seed][k]) + print(' ' + s) + + s = "mean +- std: %.1f (%.1f) (median %.1f)" % (final_result_test.mean() * 100, final_result_test.std() * 100, np.median(final_result_test) * 100) + if len(args.test_key2) > 0: + s += "second metric: %.1f (%.1f) (median %.1f)" % (final_result_test2.mean() * 100, final_result_test2.std() * 100, np.median(final_result_test2) * 100) + print(s) + +if __name__ == '__main__': + main() diff --git a/LM-BFF/tools/generate_k_shot_data.py b/LM-BFF/tools/generate_k_shot_data.py new file mode 100755 index 0000000..b7f4222 --- /dev/null +++ b/LM-BFF/tools/generate_k_shot_data.py @@ -0,0 +1,181 @@ +"""This script samples K examples randomly without replacement from the original data.""" + +import argparse +import os +import numpy as np +import pandas as pd +from pandas import DataFrame + +def get_label(task, line): + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + # GLUE style + line = line.strip().split('\t') + if task == 'CoLA': + return line[1] + elif task == 'MNLI': + return line[-1] + elif task == 'MRPC': + return line[0] + elif task == 'QNLI': + return line[-1] + elif task == 'QQP': + return line[-1] + elif task == 'RTE': + return line[-1] + elif task == 'SNLI': + return line[-1] + elif task == 'SST-2': + return line[-1] + elif task == 'STS-B': + return 0 if float(line[-1]) < 2.5 else 1 + elif task == 'WNLI': + return line[-1] + else: + raise NotImplementedError + else: + return line[0] + +def load_datasets(data_dir, tasks): + datasets = {} + for task in tasks: + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + # GLUE style (tsv) + dataset = {} + dirname = os.path.join(data_dir, task) + if task == "MNLI": + splits = ["train", "dev_matched", "dev_mismatched"] + else: + splits = ["train", "dev"] + for split in splits: + filename = os.path.join(dirname, f"{split}.tsv") + with open(filename, "r") as f: + lines = f.readlines() + dataset[split] = lines + datasets[task] = dataset + else: + # Other datasets (csv) + dataset = {} + dirname = os.path.join(data_dir, task) + splits = ["train", "test"] + for split in splits: + filename = os.path.join(dirname, f"{split}.csv") + dataset[split] = pd.read_csv(filename, header=None) + datasets[task] = dataset + return datasets + +def split_header(task, lines): + """ + Returns if the task file has a header or not. Only for GLUE tasks. 
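+    Returns a (header_lines, data_lines) tuple; the header is empty for CoLA.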
+ """ + if task in ["CoLA"]: + return [], lines + elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]: + return lines[0:1], lines[1:] + else: + raise ValueError("Unknown GLUE task.") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--k", type=int, default=16, + help="Training examples for each class.") + parser.add_argument("--task", type=str, nargs="+", + default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'], + help="Task names") + parser.add_argument("--seed", type=int, nargs="+", + default=[100, 13, 21, 42, 87], + help="Random seeds") + + parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data") + parser.add_argument("--output_dir", type=str, default="data", help="Output path") + parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)") + + args = parser.parse_args() + args.output_dir = os.path.join(args.output_dir, args.mode) + + k = args.k + print("K =", k) + datasets = load_datasets(args.data_dir, args.task) + + for seed in args.seed: + print("Seed = %d" % (seed)) + for task, dataset in datasets.items(): + # Set random seed + np.random.seed(seed) + + # Shuffle the training set + print("| Task = %s" % (task)) + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + # GLUE style + train_header, train_lines = split_header(task, dataset["train"]) + np.random.shuffle(train_lines) + else: + # Other datasets + train_lines = dataset['train'].values.tolist() + np.random.shuffle(train_lines) + + # Set up dir + task_dir = os.path.join(args.output_dir, task) + setting_dir = os.path.join(task_dir, f"{k}-{seed}") + os.makedirs(setting_dir, exist_ok=True) + + # Write test splits + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + # GLUE style + # Use the original development set as the test set (the original test sets are not publicly available) + for split, lines in dataset.items(): + if split.startswith("train"): + continue + split = split.replace('dev', 'test') + with open(os.path.join(setting_dir, f"{split}.tsv"), "w") as f: + for line in lines: + f.write(line) + else: + # Other datasets + # Use the original test sets + dataset['test'].to_csv(os.path.join(setting_dir, 'test.csv'), header=False, index=False) + + # Get label list for balanced sampling + label_list = {} + for line in train_lines: + label = get_label(task, line) + if label not in label_list: + label_list[label] = [line] + else: + label_list[label].append(line) + + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + with open(os.path.join(setting_dir, "train.tsv"), "w") as f: + for line in train_header: + f.write(line) + for label in label_list: + for line in label_list[label][:k]: + f.write(line) + name = "dev.tsv" + if task == 'MNLI': + name = "dev_matched.tsv" + with open(os.path.join(setting_dir, name), "w") as f: + for line in train_header: + f.write(line) + for label in label_list: + dev_rate = 11 if '10x' in args.mode else 2 + for line in label_list[label][k:k*dev_rate]: + f.write(line) + else: + new_train = [] + for label in label_list: + for line in label_list[label][:k]: + new_train.append(line) + new_train = DataFrame(new_train) + new_train.to_csv(os.path.join(setting_dir, 'train.csv'), header=False, index=False) + + new_dev = [] + for label in 
label_list: + dev_rate = 11 if '10x' in args.mode else 2 + for line in label_list[label][k:k*dev_rate]: + new_dev.append(line) + new_dev = DataFrame(new_dev) + new_dev.to_csv(os.path.join(setting_dir, 'dev.csv'), header=False, index=False) + + +if __name__ == "__main__": + main() diff --git a/LM-BFF/tools/generate_labels.py b/LM-BFF/tools/generate_labels.py new file mode 100755 index 0000000..3acc47f --- /dev/null +++ b/LM-BFF/tools/generate_labels.py @@ -0,0 +1,284 @@ +"""Finetuning the library models for sequence classification on GLUE.""" + +import os, sys, inspect +currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) +parentdir = os.path.dirname(currentdir) +sys.path.insert(0, parentdir) + +import logging +import json + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import AutoConfig, AutoTokenizer +from transformers import GlueDataTrainingArguments as DataTrainingArguments +from transformers import HfArgumentParser, TrainingArguments, set_seed + +from src.label_search import find_labels +from src.dataset import FewShotDataset +from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings +from src.trainer import Trainer +from src.processors import output_modes_mapping, num_labels_mapping + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + + +@dataclass +class DynamicDataTrainingArguments(DataTrainingArguments): + """ + Arguments for dynamic training. 
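+    (Prompt/template options, plus the hyper-parameters of the automatic label search below.)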
+ """ + # For prompting + template: str = field( + default=None, + metadata={"help": "Template"} + ) + + mapping: str = field( + default=None, + metadata={"help": "Label word mapping"} + ) + + debug_mode: bool = field( + default=False, + metadata={"help": "Debug mode"} + ) + + first_sent_limit: int = field( + default=None, + metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"} + ) + + other_sent_limit: int = field( + default=None, + metadata={"help": "Limit the length of sentences other than the first sentence"} + ) + + use_full_length: bool = field( + default=None, + metadata={"help": "Use the full length (512)"} + ) + + truncate_head: bool = field( + default=False, + metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."} + ) + + use_space_word: bool = field( + default=True, + metadata={"help": "Use space words (e.g., Gpositive) instead of original words."} + ) + + use_seed_labels: bool = field( + default=False, + metadata={"help": "Regularize using seed labels"}, + ) + + k_likely: int = field( + default=100, + metadata={"help": "Take the top-k most (conditionally) likely labels per class."} + ) + + k_neighbors: int = field( + default=50, + metadata={"help": "Re-rank by nearest neighbor, and take the top k."} + ) + + n_pairs: int = field( + default=32, + metadata={"help": "Number of label pairings to use."} + ) + + output_file: str = field( + default="out", + metadata={"help": "Output file"} + ) + + append_output_file: bool = field( + default=False, + ) + + write_template: bool = field( + default=False, + ) + + +def main(): + parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, TrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Fix prompt to be true. 
+ data_args.prompt = True + data_args.num_sample = 1 + data_args.template_list = None + data_args.gpt3_in_context_head = False + data_args.gpt3_in_context_tail = False + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, + ) + + # Check save path + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError(f"Output directory ({training_args.output_dir}) already exists.") + + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + bool(training_args.local_rank != -1), + training_args.fp16, + ) + logger.info("Training/evaluation parameters %s", training_args) + + # Set seed + set_seed(training_args.seed) + + try: + num_labels = num_labels_mapping[data_args.task_name] + output_mode = output_modes_mapping[data_args.task_name] + logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode)) + except KeyError: + raise ValueError("Task not found: %s" % (data_args.task_name)) + + # Create config + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + ) + + if config.model_type == 'roberta': + model_fn = RobertaForPromptFinetuning + elif config.model_type == 'bert': + model_fn = BertForPromptFinetuning + else: + raise NotImplementedError + special_tokens = [] + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + additional_special_tokens=special_tokens, + cache_dir=model_args.cache_dir, + ) + + set_seed(training_args.seed) + + model = model_fn.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + ) + + # For BERT, increase the size of the segment (token type) embeddings + if config.model_type == 'bert': + model.resize_token_embeddings(len(tokenizer)) + resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment) + + # Pass dataset and argument information to the model + model.model_args = model_args + model.data_args = data_args + model.tokenizer = tokenizer + model.return_full_softmax = True + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=None, + eval_dataset=None, + ) + + # First we compute zero-shot logits on all of the examples. + dataset = FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=False) + + # Predict logits. + dataloader = trainer.get_eval_dataloader(dataset) + output = trainer.prediction_loop(dataloader, description="Evaluation") + logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions + labels = output.label_ids + + # Assign words to labels. 
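+    # ("Ġ" is the byte-level BPE marker for a leading space in RoBERTa's
+    # vocabulary, so "Ġgreat" is the vocabulary entry for " great".)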
+    if data_args.use_seed_labels:
+        if data_args.use_space_word:
+            seed_labels = {k: "Ġ" + v for k, v in eval(data_args.mapping).items()}
+        else:
+            seed_labels = eval(data_args.mapping)
+        seed_labels = [seed_labels[label] for label in dataset.get_labels()]
+    else:
+        seed_labels = None
+
+    vocab = list(tokenizer.get_vocab())
+
+    # Find best labels.
+    label_pairings = find_labels(
+        model=trainer.model,
+        train_logits=logits,
+        train_labels=labels,
+        seed_labels=seed_labels,
+        k_likely=data_args.k_likely,
+        k_neighbors=data_args.k_neighbors,
+        top_n=data_args.n_pairs,
+        vocab=vocab,
+        is_regression=config.num_labels == 1)
+
+    labels = dataset.get_labels()
+    if config.num_labels == 1:
+        labels = ['0', '1']
+
+    os.makedirs(os.path.dirname(data_args.output_file), exist_ok=True)
+    if data_args.append_output_file:
+        mode = "a"
+    else:
+        mode = "w"
+
+    # Write to output.
+    with open(data_args.output_file, mode) as f:
+        for pairing in label_pairings:
+            words = [vocab[i][len("Ġ"):] for i in pairing]
+            mapping = {labels[i]: words[i] for i in range(len(labels))}
+            if data_args.write_template:
+                f.write(data_args.template + "\t")
+            f.write(json.dumps(mapping) + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/LM-BFF/tools/generate_template.py b/LM-BFF/tools/generate_template.py
new file mode 100755
index 0000000..1c5f615
--- /dev/null
+++ b/LM-BFF/tools/generate_template.py
@@ -0,0 +1,375 @@
+import transformers
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+import argparse
+import torch
+import os
+from tqdm import tqdm
+import json
+import pandas as pd
+
+def get_text(template, input_text_tuple, label, tokenizer, mapping):
+    def enc(text):
+        return tokenizer.encode(text, add_special_tokens=False)
+    special_token_mapping = {'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id}
+    for i in range(10):
+        special_token_mapping["<extra_id_%d>" % (i)] = tokenizer._convert_token_to_id("<extra_id_%d>" % (i))
+    template_list = template.split('*')
+    input_ids = []
+    for part in template_list:
+        new_tokens = []
+        if part in special_token_mapping:
+            if part == 'cls' and 'T5' in type(tokenizer).__name__:
+                # T5 does not have a cls token
+                continue
+            new_tokens.append(special_token_mapping[part])
+        elif part[:5] == 'label':
+            new_tokens += enc(' ' + mapping[label])
+        elif part[:5] == 'sent_':
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(input_text_tuple[sent_id])
+        elif part[:6] == '+sent_':
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(' ' + input_text_tuple[sent_id])  # add space
+        elif part[:6] == 'sent-_':
+            # Delete the last token
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(input_text_tuple[sent_id][:-1])
+        elif part[:7] == '+sentl_':
+            # Lower case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(' ' + text)
+        elif part[:7] == '+sentu_':
+            # Upper case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].upper() + text[1:]
+            new_tokens += enc(' ' + text)
+        elif part[:6] == 'sentl_':
+            # Lower case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(text)
+        elif part[:6] == 'sentu_':
+            # Upper case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].upper() + text[1:]
+            new_tokens += enc(text)
+        elif part[:7] == 'sentl-_':
+            # Lower case the first token and delete the last token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(text[:-1])
+        else:
+            part = part.replace('_', ' ')  # template commands cannot contain spaces, so '_' stands for a space
+            # handle the special case where the T5 tokenizer might add an extra space
+            if len(part) == 1:
+                new_tokens.append(tokenizer._convert_token_to_id(part))
+            else:
+                new_tokens += enc(part)
+
+        input_ids += new_tokens
+    return input_ids
+
+def generate(dataset, template, model, tokenizer, target_number, mapping, beam, label=None, length_limit=None, truncate=None):
+    """
+    Generate templates based on the given inputs.
+
+    label: Only use instances with this label (deprecated)
+    length_limit: At least generate content as long as length_limit (deprecated)
+    """
+    input_texts = []
+    input_tensors = []
+    max_length = 0
+
+    # Process the inputs
+    for item in dataset:
+        if label is None or item['label'] == label:
+            input_text = get_text(template, item['text'], item['label'], tokenizer, mapping)
+            if truncate is not None:
+                if truncate == 'head':
+                    input_text = input_text[-256:]
+                elif truncate == 'tail':
+                    input_text = input_text[:256]
+                else:
+                    raise NotImplementedError
+            input_ids = torch.tensor(input_text).long()
+            max_length = max(max_length, input_ids.size(-1))
+            input_tensors.append(input_ids)
+
+    # Concatenate inputs as a batch
+    input_ids = torch.zeros((len(input_tensors), max_length)).long()
+    attention_mask = torch.zeros((len(input_tensors), max_length)).long()
+    for i in range(len(input_tensors)):
+        input_ids[i, :input_tensors[i].size(-1)] = input_tensors[i]
+        attention_mask[i, :input_tensors[i].size(-1)] = 1
+
+    # Print some examples
+    print('####### example #######')
+    print(tokenizer.decode(input_ids[0]))
+    print(tokenizer.decode(input_ids[1]))
+    print(tokenizer.decode(input_ids[2]))
+    print('####### example #######\n')
+
+    #input_ids = input_ids.cuda()
+    #attention_mask = attention_mask.cuda()
+    assert len(input_tensors) > 0
+
+    # Maximum length of the generated content
+    max_length = 20
+
+    start_mask = tokenizer._convert_token_to_id('<extra_id_0>')
+    ori_decoder_input_ids = torch.zeros((input_ids.size(0), max_length)).long()
+    ori_decoder_input_ids[..., 0] = model.config.decoder_start_token_id
+
+    # decoder_input_ids: decoder inputs for the next step of autoregressive generation
+    # ll: log likelihood
+    # output_id: which part of the generated content we are at
+    # output: generated content so far
+    # last_length (deprecated): how long we have generated for this part
+    current_output = [{'decoder_input_ids': ori_decoder_input_ids, 'll': 0, 'output_id': 1, 'output': [], 'last_length': -1}]
+    for i in tqdm(range(max_length - 2)):
+        new_current_output = []
+        for item in current_output:
+            if item['output_id'] > target_number:
+                # Enough contents
+                new_current_output.append(item)
+                continue
+            decoder_input_ids = item['decoder_input_ids']
+
+            # Forward
+            batch_size = 10
+            turn = input_ids.size(0) // batch_size
+            if input_ids.size(0) % batch_size != 0:
+                turn += 1
+            aggr_output = []
+            for t in range(turn):
+                start = t * batch_size
+                end = min((t + 1) * batch_size, input_ids.size(0))
+
+                with torch.no_grad():
+                    #aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids.cuda()[start:end])[0])
+                    aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids[start:end])[0])
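+
+            # (The T5 forward pass is chunked into mini-batches of `batch_size`
+            # inputs to bound memory; the per-chunk logits are concatenated below.)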
aggr_output = torch.cat(aggr_output, 0) + + # Gather results across all input sentences, and sort generated tokens by log likelihood + aggr_output = aggr_output.mean(0) + log_denominator = torch.logsumexp(aggr_output[i], -1).item() + ids = list(range(model.config.vocab_size)) + ids.sort(key=lambda x: aggr_output[i][x].item(), reverse=True) + ids = ids[:beam+3] + + for word_id in ids: + output_id = item['output_id'] + + if word_id == start_mask - output_id or word_id == tokenizer._convert_token_to_id('</s>'): + # Finish one part + if length_limit is not None and item['last_length'] < length_limit[output_id - 1]: + check = False + else: + check = True + output_id += 1 + last_length = 0 + else: + last_length = item['last_length'] + 1 + check = True + + output_text = item['output'] + [word_id] + ll = item['ll'] + aggr_output[i][word_id] - log_denominator + new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.size()) + new_decoder_input_ids[:] = decoder_input_ids + new_decoder_input_ids[..., i + 1] = word_id + + # Forbid single space token, "....", and ".........." + if word_id in [3, 19794, 22354]: + check = False + + # Forbid continuous "." + if len(output_text) > 1 and output_text[-2] == 5 and output_text[-1] == 5: + check = False + + if check: + # Add new results to beam search pool + new_item = {'decoder_input_ids': new_decoder_input_ids, 'll': ll, 'output_id': output_id, 'output': output_text, 'last_length': last_length} + new_current_output.append(new_item) + + if len(new_current_output) == 0: + break + + new_current_output.sort(key=lambda x: x['ll'], reverse=True) + new_current_output = new_current_output[:beam] + current_output = new_current_output + + result = [] + print("####### generated results #######") + for item in current_output: + generate_text = '' + for token in item['output']: + generate_text += tokenizer._convert_id_to_token(token) + print('--------------') + print('score:', item['ll'].item()) + print('generated ids', item['output']) + print('generated text', generate_text) + result.append(generate_text) + print("####### generated results #######\n") + + return result + +def load_dataset(task, data_dir): + if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]: + lines = open(os.path.join(data_dir, 'train.tsv')).readlines() + if task != 'CoLA': + lines = lines[1:] + + dataset = [] + for line in lines: + line = line.strip().split('\t') + if task == 'CoLA': + dataset.append({'label': line[1], 'text': [line[-1]]}) + elif task == 'MNLI': + dataset.append({'label': line[-1], 'text': [line[8], line[9]]}) + elif task == 'MRPC': + dataset.append({'label': line[0], 'text': [line[-2], line[-1]]}) + elif task == 'QNLI': + dataset.append({'label': line[-1], 'text': [line[1], line[2]]}) + elif task == 'QQP': + dataset.append({'label': line[-1], 'text': [line[3], line[4]]}) + elif task == 'RTE': + dataset.append({'label': line[-1], 'text': [line[1], line[2]]}) + elif task == 'SNLI': + dataset.append({'label': line[-1], 'text': [line[7], line[8]]}) + elif task == 'SST-2': + dataset.append({'label': line[-1], 'text': [line[0]]}) + elif task == 'STS-B': + dataset.append({'label': '0' if float(line[-1]) < 2.5 else '1', 'text': [line[-3], line[-2]]}) + elif task == 'WNLI': + dataset.append({'label': line[-1], 'text': [line[1], line[2]]}) + else: + raise NotImplementedError + else: + lines = pd.read_csv(os.path.join(data_dir, 'train.csv')).values.tolist() + dataset = [] + for line in lines: + dataset.append({'label': line[0], 'text': 
[line[1]]}) + + return dataset + +def search_template(model, tokenizer, task_name, k, seed, beam, output_dir, data_dir): + print('#', task_name, k, seed, beam) + dataset_path = os.path.join(data_dir, task_name, "{}-{}".format(k, seed)) + dataset = load_dataset(task_name, dataset_path) + print('|', 'dataset examples') + print('|', dataset[0]) + print('|', dataset[-1]) + print() + + # Manual label word mappings + map_of_mapping = { + 'SST-2': {'0':'terrible','1':'great'}, + 'sst-5': {0:'terrible',1:'bad',2:'okay',3:'good',4:'great'}, + 'mr': {0:'terrible',1:'great'}, + 'cr': {0:'terrible',1:'great'}, + 'subj': {0:'subjective',1:'objective'}, + 'trec': {0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}, + 'mpqa': {0:'terrible',1:'great'}, + 'CoLA': {'0':'incorrect','1':'correct'}, + 'MRPC': {'0':'No','1':'Yes'}, + 'QQP': {'0':'No','1':'Yes'}, + 'STS-B': {'0':'No','1':'Yes'}, + 'MNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'}, + 'SNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'}, + 'QNLI': {'not_entailment':'No','entailment':'Yes'}, + 'RTE': {'not_entailment':'No','entailment':'Yes'} + } + + mapping = map_of_mapping[task_name] + print('|', 'mapping') + print('|', mapping) + + os.makedirs(output_dir, exist_ok=True) + os.makedirs(os.path.join(output_dir, task_name), exist_ok=True) + f = open(os.path.join(output_dir, task_name, "{}-{}.txt".format(k, seed)), 'w') + + if task_name in ['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'mpqa']: + # Single sentence tasks + # We take two kinds of templates: put [MASK] at the beginning or the end + template = "*cls**sentu_0**<extra_id_0>**label**<extra_id_1>**sep+*" + generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='head')[:beam//2] + + print("####### generated templates #######") + for text in generate_text: + # Transform T5 outputs to our template format + text = text.replace('<extra_id_0>', '*cls**sent_0*') + text = text.replace('<extra_id_1>', '*mask*') + text = text.replace('<extra_id_2>', '*sep+*') + text = text.replace('</s>', '*sep+*') + text = text.replace('▁', '_') + print(text) + f.write(text + '\n') + print("####### generated templates #######\n") + + template = "*cls*.*<extra_id_0>**label**<extra_id_1>**+sentu_0**sep+*" + generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='tail')[:beam//2] + print("####### generated templates #######") + for text in generate_text: + # Transform T5 outputs to our template format + text = text.replace('<extra_id_0>', '*cls*') + text = text.replace('<extra_id_1>', '*mask*') + text = text.replace('<extra_id_2>', '*+sent_0**sep+*') + text = text.replace('</s>', '*+sent_0**sep+*') + text = text.replace('▁', '_') + print(text) + f.write(text + '\n') + print("####### generated templates #######\n") + + elif task_name in ['MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE']: + # Sentence pair tasks + # We always put [MASK] between the two sentences + template = "*cls**sent-_0**<extra_id_0>**label**<extra_id_1>**+sentl_1**sep+*" + generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None) + print("####### generated templates #######") + for text in generate_text: + # Transform T5 outputs to our template format + text = text.replace('<extra_id_0>', '*cls**sent-_0*') + text = text.replace('<extra_id_1>', '*mask*') + text = 
text.replace('<extra_id_2>', '*+sentl_1**sep+*')
+            text = text.replace('</s>', '*+sentl_1**sep+*')
+            text = text.replace('▁', '_')
+            print(text)
+            f.write(text + '\n')
+        print("####### generated templates #######\n")
+    else:
+        raise NotImplementedError
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--t5_model', type=str, default='t5-3b', help='T5 pre-trained model')
+    parser.add_argument('--seed', type=int, nargs='+', default=[42, 13, 21, 100, 87], help="Data split seeds")
+    parser.add_argument('--task_name', type=str, nargs='+', default=['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'], help="Task names")
+    parser.add_argument('--output_dir', type=str, help='Output directory')
+
+    parser.add_argument('--data_dir', type=str, default="data/k-shot", help="Data directory")
+    parser.add_argument('--beam', type=int, default=100, help="Beam search width")
+    parser.add_argument('--k', type=int, default=16, help="Number of training instances per label")
+
+    args = parser.parse_args()
+
+    model = T5ForConditionalGeneration.from_pretrained(args.t5_model)
+    tokenizer = T5Tokenizer.from_pretrained(args.t5_model)
+    tokenizer.sep_token = '</s>'
+
+    #model = model.cuda()
+    model.eval()
+
+    for task_name in args.task_name:
+        for seed in args.seed:
+            search_template(model=model, tokenizer=tokenizer, task_name=task_name, k=args.k, seed=seed, beam=args.beam, output_dir=args.output_dir, data_dir=args.data_dir)

+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/get_sbert_embedding.py b/LM-BFF/tools/get_sbert_embedding.py
new file mode 100755
index 0000000..806243c
--- /dev/null
+++ b/LM-BFF/tools/get_sbert_embedding.py
@@ -0,0 +1,105 @@
+from sentence_transformers import SentenceTransformer, util
+import argparse
+import os
+import numpy as np
+from tqdm import tqdm
+import pandas as pd
+
+def get_sentence(task, line):
+    if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
+        # Text classification tasks
+        if line[1] is None or pd.isna(line[1]):
+            return ''
+        else:
+            return line[1]
+    else:
+        # GLUE tasks
+        line = line.strip().split('\t')
+        if task == 'CoLA':
+            return line[-1]
+        elif task == 'MNLI':
+            return line[8] + ' ' + line[9]
+        elif task == 'MRPC':
+            return line[-2] + ' ' + line[-1]
+        elif task == 'QNLI':
+            return line[1] + ' ' + line[2]
+        elif task == 'QQP':
+            return line[3] + ' ' + line[4]
+        elif task == 'RTE':
+            return line[1] + ' ' + line[2]
+        elif task == 'SNLI':
+            return line[7] + ' ' + line[8]
+        elif task == 'SST-2':
+            return line[0]
+        elif task == 'STS-B':
+            return line[-3] + ' ' + line[-2]
+        elif task == 'WNLI':
+            return line[1] + ' ' + line[2]
+        else:
+            raise NotImplementedError
+
+def split_header(task, lines):
+    """Splits off the header line (if any) from the data lines. Only for GLUE tasks."""
+    if task in ["CoLA"]:
+        return [], lines
+    elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
+        return lines[0:1], lines[1:]
+    else:
+        raise ValueError("Unknown GLUE task.")
+
+def load_datasets(data_dir, task, do_test=False):
+    dataset = {}
+    if task == "MNLI":
+        splits = ["train", "dev_matched"]
+        if do_test:
+            splits += ['test_matched', 'test_mismatched']
+    else:
+        splits = ["train", "dev"]
+        if do_test:
+            splits.append('test')
+    for split in splits:
+        if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
+            filename = os.path.join(data_dir, f"{split}.csv")
+            dataset[split] = pd.read_csv(filename, header=None).values.tolist()
+        else:
+            filename = os.path.join(data_dir, f"{split}.tsv")
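+            # (GLUE-style files are TSV; a possible header line is split off below.)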
+ with open(filename, "r") as f: + lines = f.readlines() + header, content = split_header(task, lines) + dataset[split] = content + return dataset + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--do_test", action='store_true', help="Generate embeddings for test splits (test set is usually large, so we don't want to repeatedly generate embeddings for them)") + parser.add_argument("--sbert_model", type=str, default='roberta-large', help="Sentence BERT model name") + + parser.add_argument("--k", type=int, help="Number of training instances per label", default=16) + parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Path to few-shot data") + parser.add_argument("--seed", type=int, nargs="+", default=[42, 13, 21, 87, 100], help="Seeds for data splits") + parser.add_argument("--task", type=str, nargs="+", default=["SST-2", "sst-5", "mr", "cr", "mpqa", "subj", "trec", "CoLA", "MRPC", "QQP", "STS-B", "MNLI", "SNLI", "QNLI", "RTE"], help="Tasks") + + args = parser.parse_args() + + model = SentenceTransformer('{}-nli-stsb-mean-tokens'.format(args.sbert_model)) + model = model.cuda() + + for task in args.task: + for seed in args.seed: + folder = os.path.join(args.data_dir, task, '{}-{}'.format(args.k, seed)) + dataset = load_datasets(folder, task, do_test=args.do_test) + for split in dataset: + print('{}-{}-{}-{}'.format(task, args.k, seed, split)) + lines = dataset[split] + embeddings = [] + for line_id, line in tqdm(enumerate(lines)): + sent = get_sentence(task, line) + if line_id == 0: + print('|', sent) + emb = model.encode(sent) + embeddings.append(emb) + embeddings = np.stack(embeddings) + np.save(os.path.join(folder, "{}_sbert-{}.npy".format(split, args.sbert_model)), embeddings) + +if __name__ == '__main__': + main() diff --git a/LM-BFF/tools/sort_mapping.py b/LM-BFF/tools/sort_mapping.py new file mode 100755 index 0000000..a3ea209 --- /dev/null +++ b/LM-BFF/tools/sort_mapping.py @@ -0,0 +1,173 @@ +import argparse +import json +import numpy as np +import torch +from torch import device +import os + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") + parser.add_argument('--mapping_dir', type=str, help='Mapping directory') + + # These options should be kept as their default values + parser.add_argument("--k", type=int, default=16) + parser.add_argument("--log", type=str, default="log", help="Log path.") + parser.add_argument("--key", type=str, default='', help="Validation metric name") + parser.add_argument("--test_key", type=str, default="", help="Test metric name") + parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") + + args = parser.parse_args() + + condition = eval(args.condition) + + if len(args.key) == 0: + if condition['task_name'] == 'cola': + args.key = 'cola_dev_eval_mcc' + args.test_key = 'cola_test_eval_mcc' + print_name = 'CoLA' + elif condition['task_name'] == 'mrpc/acc': + args.key = 'mrpc_dev_eval_acc' + args.test_key = 'mrpc_test_eval_acc' + args.test_key2 = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'mrpc/f1': + args.key = 'mrpc_dev_eval_f1' + args.test_key2 = 'mrpc_test_eval_acc' + args.test_key = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'qqp/acc': + args.key = 'qqp_dev_eval_acc' + 
args.test_key = 'qqp_test_eval_acc' + args.test_key2 = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'qqp/f1': + args.key = 'qqp_dev_eval_f1' + args.test_key2 = 'qqp_test_eval_acc' + args.test_key = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'sts-b/pearson': + args.key = 'sts-b_dev_eval_pearson' + args.test_key = 'sts-b_test_eval_pearson' + args.test_key2 = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B' + elif condition['task_name'] == 'sts-b/spearmanr': + args.key = 'sts-b_dev_eval_spearmanr' + args.test_key2 = 'sts-b_test_eval_pearson' + args.test_key = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B' + elif condition['task_name'] == 'qnli': + args.key = 'qnli_dev_eval_acc' + args.test_key = 'qnli_test_eval_acc' + print_name = 'QNLI' + elif condition['task_name'] == 'sst-2': + args.key = 'sst-2_dev_eval_acc' + args.test_key = 'sst-2_test_eval_acc' + print_name = 'SST-2' + elif condition['task_name'] == 'snli': + args.key = 'snli_dev_eval_acc' + args.test_key = 'snli_test_eval_acc' + print_name = 'SNLI' + elif condition['task_name'] == 'mnli': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli_test_eval_mnli/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'mnli-mm': + condition['task_name'] = 'mnli' + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'rte': + args.key = 'rte_dev_eval_acc' + args.test_key = 'rte_test_eval_acc' + print_name = 'RTE' + elif condition['task_name'] == 'ag_news': + args.key = 'ag_news_dev_eval_acc' + args.test_key = 'ag_news_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yahoo_answers': + args.key = 'yahoo_answers_dev_eval_acc' + args.test_key = 'yahoo_answers_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yelp_review_full': + args.key = 'yelp_review_full_dev_eval_acc' + args.test_key = 'yelp_review_full_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mr': + args.key = 'mr_dev_eval_acc' + args.test_key = 'mr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'sst-5': + args.key = 'sst-5_dev_eval_acc' + args.test_key = 'sst-5_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'subj': + args.key = 'subj_dev_eval_acc' + args.test_key = 'subj_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'trec': + args.key = 'trec_dev_eval_acc' + args.test_key = 'trec_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'cr': + args.key = 'cr_dev_eval_acc' + args.test_key = 'cr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mpqa': + args.key = 'mpqa_dev_eval_acc' + args.test_key = 'mpqa_test_eval_acc' + print_name = condition['task_name'] + else: + raise NotImplementedError + + with open(args.log) as f: + result_list = [] + for line in f: + result_list.append(eval(line)) + + seed_result = {} + seed_result_mapping_id = {} # avoid duplication + + for item in result_list: + ok = True + for cond in condition: + if cond not in item or item[cond] != condition[cond]: + ok = False + break + + if ok: + seed = item['seed'] + if seed not in seed_result: + seed_result[seed] = [item] + 
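# first result for this seed: start its list and record the mapping_id so later duplicates are skipped +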
seed_result_mapping_id[seed] = {item['mapping_id']: 1} + else: + if item['mapping_id'] not in seed_result_mapping_id[seed]: + seed_result[seed].append(item) + seed_result_mapping_id[seed][item['mapping_id']] = 1 + + for seed in seed_result: + print("Seed %d has %d results" % (seed, len(seed_result[seed]))) + + # Load all mappings + with open(os.path.join(args.mapping_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f: + mappings = [] + for line in f: + mappings.append(line.strip()) + + # Write sorted mappings + fsort = open(os.path.join(args.mapping_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w') + fscore = open(os.path.join(args.mapping_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w') + + seed_result[seed].sort(key=lambda x: x[args.key], reverse=True) + for item in seed_result[seed]: + fsort.write(mappings[item['mapping_id']] + '\n') + fscore.write("%.5f %s\n" % (item[args.key], mappings[item['mapping_id']])) + +if __name__ == '__main__': + main() diff --git a/LM-BFF/tools/sort_prompt.py b/LM-BFF/tools/sort_prompt.py new file mode 100755 index 0000000..ef6420d --- /dev/null +++ b/LM-BFF/tools/sort_prompt.py @@ -0,0 +1,169 @@ +import argparse +import os + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--condition", type=str, help="A dictionary containing conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") + parser.add_argument('--prompt_dir', type=str, help='Prompt directory') + + # These options should be kept as their default values + parser.add_argument("--k", type=int, default=16) + parser.add_argument("--log", type=str, default="log", help="Log path.") + parser.add_argument("--key", type=str, default='', help="Validation metric name") + parser.add_argument("--test_key", type=str, default="", help="Test metric name") + parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") + + args = parser.parse_args() + + condition = eval(args.condition) + + if len(args.key) == 0: + if condition['task_name'] == 'cola': + args.key = 'cola_dev_eval_mcc' + args.test_key = 'cola_test_eval_mcc' + print_name = 'CoLA' + elif condition['task_name'] == 'mrpc/acc': + args.key = 'mrpc_dev_eval_acc' + args.test_key = 'mrpc_test_eval_acc' + args.test_key2 = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'mrpc/f1': + args.key = 'mrpc_dev_eval_f1' + args.test_key2 = 'mrpc_test_eval_acc' + args.test_key = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'qqp/acc': + args.key = 'qqp_dev_eval_acc' + args.test_key = 'qqp_test_eval_acc' + args.test_key2 = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'qqp/f1': + args.key = 'qqp_dev_eval_f1' + args.test_key2 = 'qqp_test_eval_acc' + args.test_key = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'sts-b/pearson': + args.key = 'sts-b_dev_eval_pearson' + args.test_key = 'sts-b_test_eval_pearson' + args.test_key2 = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B' + elif condition['task_name'] == 'sts-b/spearmanr': + args.key = 'sts-b_dev_eval_spearmanr' + args.test_key2 = 'sts-b_test_eval_pearson' + args.test_key = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B'
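+ # each remaining task is selected and reported with a single accuracy-style metric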
+ elif condition['task_name'] == 'qnli': + args.key = 'qnli_dev_eval_acc' + args.test_key = 'qnli_test_eval_acc' + print_name = 'QNLI' + elif condition['task_name'] == 'sst-2': + args.key = 'sst-2_dev_eval_acc' + args.test_key = 'sst-2_test_eval_acc' + print_name = 'SST-2' + elif condition['task_name'] == 'snli': + args.key = 'snli_dev_eval_acc' + args.test_key = 'snli_test_eval_acc' + print_name = 'SNLI' + elif condition['task_name'] == 'mnli': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli_test_eval_mnli/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'mnli-mm': + condition['task_name'] = 'mnli' + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'rte': + args.key = 'rte_dev_eval_acc' + args.test_key = 'rte_test_eval_acc' + print_name = 'RTE' + elif condition['task_name'] == 'ag_news': + args.key = 'ag_news_dev_eval_acc' + args.test_key = 'ag_news_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yahoo_answers': + args.key = 'yahoo_answers_dev_eval_acc' + args.test_key = 'yahoo_answers_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yelp_review_full': + args.key = 'yelp_review_full_dev_eval_acc' + args.test_key = 'yelp_review_full_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mr': + args.key = 'mr_dev_eval_acc' + args.test_key = 'mr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'sst-5': + args.key = 'sst-5_dev_eval_acc' + args.test_key = 'sst-5_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'subj': + args.key = 'subj_dev_eval_acc' + args.test_key = 'subj_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'trec': + args.key = 'trec_dev_eval_acc' + args.test_key = 'trec_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'cr': + args.key = 'cr_dev_eval_acc' + args.test_key = 'cr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mpqa': + args.key = 'mpqa_dev_eval_acc' + args.test_key = 'mpqa_test_eval_acc' + print_name = condition['task_name'] + else: + raise NotImplementedError + + with open(args.log) as f: + result_list = [] + for line in f: + result_list.append(eval(line)) + + seed_result = {} + seed_result_prompt_id = {} # avoid duplication + + for item in result_list: + ok = True + for cond in condition: + if cond not in item or item[cond] != condition[cond]: + ok = False + break + + if ok: + seed = item['seed'] + if seed not in seed_result: + seed_result[seed] = [item] + seed_result_prompt_id[seed] = {item['prompt_id']: 1} + else: + if item['prompt_id'] not in seed_result_prompt_id[seed]: + seed_result[seed].append(item) + seed_result_prompt_id[seed][item['prompt_id']] = 1 + + for seed in seed_result: + print("Seed %d has %d results" % (seed, len(seed_result[seed]))) + + # Load all prompts + with open(os.path.join(args.prompt_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f: + prompts = [] + for line in f: + prompts.append(line.strip()) + + # Write sorted prompts + fsort = open(os.path.join(args.prompt_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w') + fscore = open(os.path.join(args.prompt_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w') + + seed_result[seed].sort(key=lambda x: x[args.key], reverse=True) + for item in seed_result[seed]: 
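+ # prompts are written best-first, ordered by the dev-set metric in args.key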
+ fsort.write(prompts[item['prompt_id']] + '\n') + fscore.write("%.5f %s\n" % (item[args.key], prompts[item['prompt_id']])) + +if __name__ == '__main__': + main() diff --git a/LM-BFF/tools/sort_template.py b/LM-BFF/tools/sort_template.py new file mode 100755 index 0000000..64239b7 --- /dev/null +++ b/LM-BFF/tools/sort_template.py @@ -0,0 +1,169 @@ +import argparse +import os + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--condition", type=str, help="A dictionary containing conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)") + parser.add_argument('--template_dir', type=str, help='Template directory') + + # These options should be kept as their default values + parser.add_argument("--k", type=int, default=16) + parser.add_argument("--log", type=str, default="log", help="Log path.") + parser.add_argument("--key", type=str, default='', help="Validation metric name") + parser.add_argument("--test_key", type=str, default="", help="Test metric name") + parser.add_argument("--test_key2", type=str, default="", help="Second test metric name") + + args = parser.parse_args() + + condition = eval(args.condition) + + if len(args.key) == 0: + if condition['task_name'] == 'cola': + args.key = 'cola_dev_eval_mcc' + args.test_key = 'cola_test_eval_mcc' + print_name = 'CoLA' + elif condition['task_name'] == 'mrpc/acc': + args.key = 'mrpc_dev_eval_acc' + args.test_key = 'mrpc_test_eval_acc' + args.test_key2 = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'mrpc/f1': + args.key = 'mrpc_dev_eval_f1' + args.test_key2 = 'mrpc_test_eval_acc' + args.test_key = 'mrpc_test_eval_f1' + condition['task_name'] = 'mrpc' + print_name = 'MRPC' + elif condition['task_name'] == 'qqp/acc': + args.key = 'qqp_dev_eval_acc' + args.test_key = 'qqp_test_eval_acc' + args.test_key2 = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'qqp/f1': + args.key = 'qqp_dev_eval_f1' + args.test_key2 = 'qqp_test_eval_acc' + args.test_key = 'qqp_test_eval_f1' + condition['task_name'] = 'qqp' + print_name = 'QQP' + elif condition['task_name'] == 'sts-b/pearson': + args.key = 'sts-b_dev_eval_pearson' + args.test_key = 'sts-b_test_eval_pearson' + args.test_key2 = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B' + elif condition['task_name'] == 'sts-b/spearmanr': + args.key = 'sts-b_dev_eval_spearmanr' + args.test_key2 = 'sts-b_test_eval_pearson' + args.test_key = 'sts-b_test_eval_spearmanr' + condition['task_name'] = 'sts-b' + print_name = 'STS-B' + elif condition['task_name'] == 'qnli': + args.key = 'qnli_dev_eval_acc' + args.test_key = 'qnli_test_eval_acc' + print_name = 'QNLI' + elif condition['task_name'] == 'sst-2': + args.key = 'sst-2_dev_eval_acc' + args.test_key = 'sst-2_test_eval_acc' + print_name = 'SST-2' + elif condition['task_name'] == 'snli': + args.key = 'snli_dev_eval_acc' + args.test_key = 'snli_test_eval_acc' + print_name = 'SNLI' + elif condition['task_name'] == 'mnli': + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli_test_eval_mnli/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'mnli-mm': + condition['task_name'] = 'mnli' + args.key = 'mnli_dev_eval_mnli/acc' + args.test_key = 'mnli-mm_test_eval_mnli-mm/acc' + print_name = 'MNLI' + elif condition['task_name'] == 'rte': + args.key = 'rte_dev_eval_acc' +
args.test_key = 'rte_test_eval_acc' + print_name = 'RTE' + elif condition['task_name'] == 'ag_news': + args.key = 'ag_news_dev_eval_acc' + args.test_key = 'ag_news_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yahoo_answers': + args.key = 'yahoo_answers_dev_eval_acc' + args.test_key = 'yahoo_answers_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'yelp_review_full': + args.key = 'yelp_review_full_dev_eval_acc' + args.test_key = 'yelp_review_full_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mr': + args.key = 'mr_dev_eval_acc' + args.test_key = 'mr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'sst-5': + args.key = 'sst-5_dev_eval_acc' + args.test_key = 'sst-5_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'subj': + args.key = 'subj_dev_eval_acc' + args.test_key = 'subj_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'trec': + args.key = 'trec_dev_eval_acc' + args.test_key = 'trec_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'cr': + args.key = 'cr_dev_eval_acc' + args.test_key = 'cr_test_eval_acc' + print_name = condition['task_name'] + elif condition['task_name'] == 'mpqa': + args.key = 'mpqa_dev_eval_acc' + args.test_key = 'mpqa_test_eval_acc' + print_name = condition['task_name'] + else: + raise NotImplementedError + + with open(args.log) as f: + result_list = [] + for line in f: + result_list.append(eval(line)) + + seed_result = {} + seed_result_template_id = {} # avoid duplication + + for item in result_list: + ok = True + for cond in condition: + if cond not in item or item[cond] != condition[cond]: + ok = False + break + + if ok: + seed = item['seed'] + if seed not in seed_result: + seed_result[seed] = [item] + seed_result_template_id[seed] = {item['template_id']: 1} + else: + if item['template_id'] not in seed_result_template_id[seed]: + seed_result[seed].append(item) + seed_result_template_id[seed][item['template_id']] = 1 + + for seed in seed_result: + print("Seed %d has %d results" % (seed, len(seed_result[seed]))) + + # Load all templates + with open(os.path.join(args.template_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f: + templates = [] + for line in f: + templates.append(line.strip()) + + # Write sorted templates + fsort = open(os.path.join(args.template_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w') + fscore = open(os.path.join(args.template_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w') + + seed_result[seed].sort(key=lambda x: x[args.key], reverse=True) + for item in seed_result[seed]: + fsort.write(templates[item['template_id']] + '\n') + fscore.write("%.5f %s\n" % (item[args.key], templates[item['template_id']])) + +if __name__ == '__main__': + main() diff --git "a/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb" "b/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb" new file mode 100755 index 0000000..39dc502 --- /dev/null +++ "b/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb" @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f857d1db-7d6d-4c76-af4f-d508a4027192", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "head: error reading 'data/x-stance/': Is a directory\n" + ] + } + ], + "source": [ + "!head -2 data/x-stance/questions.en.json\n" + ] + }, + { + 
"cell_type": "code", + "execution_count": 42, + "id": "b6421d22-ca47-4ecf-b667-8f19f4cb035a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "questions.de.jsonl questions.en.jsonl\tquestions.fr.jsonl questions.it.jsonl\n" + ] + } + ], + "source": [ + "!ls data/x-stance" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d288f21-0b89-4974-bded-d5ca9ff24f82", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import pandas as pd\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "67b7a07f-26d9-4e6a-8473-2614f6b34887", + "metadata": {}, + "outputs": [], + "source": [ + "json_list = []\n", + "for line in open('data/x-stance/questions.it.jsonl'):\n", + " json_list.append(json.loads(line))\n", + "data = pd.DataFrame.from_dict(json_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "d505a983-7d08-4fa0-8281-99027e8edd4c", + "metadata": {}, + "outputs": [], + "source": [ + "k = 100\n", + "seed_list = [100,13,21,42,87]" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "824ff804-347d-4913-a3f0-20bd38cb4159", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100-100 100-13 100-21 100-42 100-87 16-100 16-13\t16-21 16-42 16-87\n" + ] + } + ], + "source": [ + "!ls /home/mist/projects/LM-BFF/data/k-shot/x-stance" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "5ddb0ab8-0dfb-44bc-99c3-69829c86d8a7", + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = '/home/mist/projects/LM-BFF/data/k-shot/'" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8b8b604e-5767-4bdb-b620-2d4c464c595c", + "metadata": {}, + "outputs": [], + "source": [ + "task = 'x-stance'" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "59dd94e5-b067-4644-832f-bc0e45d3486f", + "metadata": {}, + "outputs": [], + "source": [ + "for seed in seed_list:\n", + " task_dir = os.path.join(output_dir, task)\n", + " setting_dir = os.path.join(task_dir, f\"{k}-{seed}\")\n", + " os.makedirs(setting_dir, exist_ok=True)\n", + " data.sample(random_state=seed,n=k)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "531acdd0-cbf6-430a-91e2-66923236908c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>id</th>\n", + " <th>text</th>\n", + " <th>topic</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>2</td>\n", + " <td>Finden Sie es grundsätzlich richtig, dass der ...</td>\n", + " <td>Welfare</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>4</td>\n", + " <td>Soll zusätzlich zur bestehenden Mutterschaftsv...</td>\n", + " <td>Welfare</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>6</td>\n", + " <td>Die Invalidenversicherung spricht bei nicht ob...</td>\n", + " <td>Welfare</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>7</td>\n", + " <td>Würden Sie eine nationale Spitalplanung befürw...</td>\n", + " <td>Healthcare</td>\n", + " 
</tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>9</td>\n", + " <td>Finden Sie es richtig, dass einzelne ärztliche...</td>\n", + " <td>Healthcare</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>189</th>\n", + " <td>3464</td>\n", + " <td>Würden Sie eine Ausdehnung der rechtlichen Mög...</td>\n", + " <td>Security</td>\n", + " </tr>\n", + " <tr>\n", + " <th>190</th>\n", + " <td>3468</td>\n", + " <td>Soll die Schweiz Verhandlungen über den Beitri...</td>\n", + " <td>Foreign Policy</td>\n", + " </tr>\n", + " <tr>\n", + " <th>191</th>\n", + " <td>3469</td>\n", + " <td>Soll der Bundesrat ein Freihandelsabkommen mit...</td>\n", + " <td>Foreign Policy</td>\n", + " </tr>\n", + " <tr>\n", + " <th>192</th>\n", + " <td>3470</td>\n", + " <td>Eine Initiative fordert, dass die Haftungsrege...</td>\n", + " <td>Foreign Policy</td>\n", + " </tr>\n", + " <tr>\n", + " <th>193</th>\n", + " <td>3471</td>\n", + " <td>Befürworten Sie die Kandidatur der Schweiz für...</td>\n", + " <td>Foreign Policy</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>194 rows × 3 columns</p>\n", + "</div>" + ], + "text/plain": [ + " id text topic\n", + "0 2 Finden Sie es grundsätzlich richtig, dass der ... Welfare\n", + "1 4 Soll zusätzlich zur bestehenden Mutterschaftsv... Welfare\n", + "2 6 Die Invalidenversicherung spricht bei nicht ob... Welfare\n", + "3 7 Würden Sie eine nationale Spitalplanung befürw... Healthcare\n", + "4 9 Finden Sie es richtig, dass einzelne ärztliche... Healthcare\n", + ".. ... ... ...\n", + "189 3464 Würden Sie eine Ausdehnung der rechtlichen Mög... Security\n", + "190 3468 Soll die Schweiz Verhandlungen über den Beitri... Foreign Policy\n", + "191 3469 Soll der Bundesrat ein Freihandelsabkommen mit... Foreign Policy\n", + "192 3470 Eine Initiative fordert, dass die Haftungsrege... Foreign Policy\n", + "193 3471 Befürworten Sie die Kandidatur der Schweiz für... 
Foreign Policy\n", + "\n", + "[194 rows x 3 columns]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "bcf450cf-cb23-462f-b070-3529c1dfa86d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Infrastructure & Environment 31\n", + "Economy 23\n", + "Security 20\n", + "Immigration 19\n", + "Society 17\n", + "Education 16\n", + "Foreign Policy 16\n", + "Finances 15\n", + "Welfare 15\n", + "Healthcare 11\n", + "Political System 9\n", + "Digitisation 2\n", + "Name: topic, dtype: int64" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data['topic'].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b07328e8-5fcf-4fad-9468-49ff91652ef9", + "metadata": {}, + "outputs": [], + "source": [ + "for seed in seed_list:\n", + " data.sample(random_state=seed,n=k)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "8d7aa681-bd14-4c48-881a-9130a9b88edc", + "metadata": {}, + "outputs": [], + "source": [ + "label_encoding = {\n", + "'Infrastructure & Environment': 0,\n", + "'Economy': 1 , \n", + "'Security': 2 , \n", + "'Immigration': 3 ,\n", + "'Society': 4 ,\n", + "'Education': 5 ,\n", + "'Foreign Policy': 6 ,\n", + "'Finances': 7 ,\n", + "'Welfare':8 ,\n", + "'Healthcare':9 ,\n", + "'Political System': 10 , \n", + "'Digitisation':11 }" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "3c2fad2e-77cf-4eff-9dda-5773f59f4402", + "metadata": {}, + "outputs": [], + "source": [ + "data['label'] = data['topic'].apply(lambda x:label_encoding[x])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "6e4c0a41-31a8-4620-bfed-c776a21e0c1c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(194, 4)" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4fc29374-8495-43b7-829f-90c11bf8a974", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'pd' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-4-a66617debac6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data/x-stance/questions.en.jsonl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mjson_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m 
\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m'+1'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'stars'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m>=\u001b[0m\u001b[0;36m3\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'-1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrac\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined" + ] + } + ], + "source": [ + "\n", + "\n", + "data['label']=data.apply(lambda x:'+1' if x['stars']>=3 else '-1',axis=1)\n", + "data = data.sample(frac=1)\n", + "data['text']=data['text'].apply(lambda x:' '.join(x.replace('\\n','.').split(' ')[:500]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "551574a5-f461-4a6f-8c0a-97cf92747b74", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'data' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-9-c5d84736ba45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined" + ] + } + ], + "source": [ + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3187648-9c75-423d-bc21-ba655e30df6f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- libgit2 0.26.0