From cf1a4f6527d9f35a3bb63fa33bcd6eee07dfccd0 Mon Sep 17 00:00:00 2001
From: xyw <yunwanx@foxmail.com>
Date: Thu, 28 Oct 2021 21:31:53 +0800
Subject: [PATCH] v1

---
 LM-BFF/.gitignore                    | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/LICENSE                       |  21 +++++++++++++++++++++
 LM-BFF/README.md                     | 405 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/figs/lmbff.png                | Bin 0 -> 378387 bytes
 LM-BFF/nohup.out                     |   1 +
 LM-BFF/requirements.txt              |  37 +++++++++++++++++++++++++++++++++++++
 LM-BFF/run.py                        | 628 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/dataset.py                | 658 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/label_search.py           | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/models.py                 | 195 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/processors.py             | 626 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/src/trainer.py                | 473 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/ensemble.py             | 289 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/gather_result.py        | 160 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/generate_k_shot_data.py | 181 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/generate_labels.py      | 284 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/generate_template.py    | 375 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/get_sbert_embedding.py  | 105 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/sort_mapping.py         | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/sort_prompt.py          | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/tools/sort_template.py        | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 LM-BFF/未命名.ipynb               | 432 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 22 files changed, 5689 insertions(+)
 create mode 100755 LM-BFF/.gitignore
 create mode 100755 LM-BFF/LICENSE
 create mode 100755 LM-BFF/README.md
 create mode 100755 LM-BFF/figs/lmbff.png
 create mode 100755 LM-BFF/nohup.out
 create mode 100755 LM-BFF/requirements.txt
 create mode 100755 LM-BFF/run.py
 create mode 100755 LM-BFF/src/dataset.py
 create mode 100755 LM-BFF/src/label_search.py
 create mode 100755 LM-BFF/src/models.py
 create mode 100755 LM-BFF/src/processors.py
 create mode 100755 LM-BFF/src/trainer.py
 create mode 100755 LM-BFF/tools/ensemble.py
 create mode 100755 LM-BFF/tools/gather_result.py
 create mode 100755 LM-BFF/tools/generate_k_shot_data.py
 create mode 100755 LM-BFF/tools/generate_labels.py
 create mode 100755 LM-BFF/tools/generate_template.py
 create mode 100755 LM-BFF/tools/get_sbert_embedding.py
 create mode 100755 LM-BFF/tools/sort_mapping.py
 create mode 100755 LM-BFF/tools/sort_prompt.py
 create mode 100755 LM-BFF/tools/sort_template.py
 create mode 100755 LM-BFF/未命名.ipynb

diff --git a/LM-BFF/.gitignore b/LM-BFF/.gitignore
new file mode 100755
index 0000000..4cf2bba
--- /dev/null
+++ b/LM-BFF/.gitignore
@@ -0,0 +1,152 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+*.sh
+.venv
+.vscode
+data
+!data/k-shot/checksum
+log*
+runs
+result
+wandb
+ensemble_predict_results
+auto*
+my*
+slurm
diff --git a/LM-BFF/LICENSE b/LM-BFF/LICENSE
new file mode 100755
index 0000000..61d8849
--- /dev/null
+++ b/LM-BFF/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Princeton Natural Language Processing
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/LM-BFF/README.md b/LM-BFF/README.md
new file mode 100755
index 0000000..0853418
--- /dev/null
+++ b/LM-BFF/README.md
@@ -0,0 +1,405 @@
+# LM-BFF (**B**etter **F**ew-shot **F**ine-tuning of **L**anguage **M**odels)
+
+This is the implementation of the paper [Making Pre-trained Language Models Better Few-shot Learners](https://arxiv.org/pdf/2012.15723.pdf). LM-BFF is short for **b**etter **f**ew-shot **f**ine-tuning of **l**anguage **m**odels.
+
+## Quick links
+
+* [Overview](#overview)
+* [Requirements](#requirements)
+* [Prepare the data](#prepare-the-data)
+* [Run the model](#run-lm-bff)
+  * [Quick start](#quick-start)
+  * [Experiments with multiple runs](#experiments-with-multiple-runs)
+  * [Using demonstrations with filtering](#using-demonstrations-with-filtering)
+  * [Automatically searched prompt](#automatically-searched-prompt)
+  * [Ensemble](#ensemble-model)
+  * [Zero-shot experiments](#zero-shot-experiments)
+  * [How to design your own templates](#how-to-design-your-own-templates)
+* [Citation](#citation)
+
+
+## Overview
+
+![](./figs/lmbff.png)
+
+In this work we present LM-BFF, a suite of simple and complementary techniques for fine-tuning pre-trained language models on a small number of training examples. Our approach includes:
+
+1. Prompt-based fine-tuning together with a novel pipeline for automating prompt generation.
+2. A refined strategy for incorporating demonstrations into context.
+
+You can find more details of this work in our [paper](https://arxiv.org/pdf/2012.15723.pdf).
+
+## Requirements
+
+To run our code, please install all the dependency packages by using the following command:
+
+```
+pip install -r requirements.txt
+```
+
+**NOTE**: Different versions of packages (like `pytorch`, `transformers`, etc.) may lead to different results from the paper. However, the trend should still hold no matter what versions of packages you use.
+
+## Prepare the data
+
+We pack the original datasets (SST-2, SST-5, MR, CR, MPQA, Subj, TREC, CoLA, MNLI, SNLI, QNLI, RTE, MRPC, QQP, STS-B) [here](https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar). Please download it and extract the files to `./data/original`, or run the following commands:
+
+```bash
+cd data
+bash download_dataset.sh
+```
+
+Then use the following command (in the root directory) to generate the few-shot data we need:
+
+```bash
+python tools/generate_k_shot_data.py
+```
+
+See `tools/generate_k_shot_data.py` for more options. For the results in the paper, we use the default options: we take `K=16` and 5 different seeds (13, 21, 42, 87, 100). The few-shot data will be generated in `data/k-shot`. In the directory of each dataset, there will be folders named `$K-$SEED` indicating different dataset samples. You can use the following command to check whether the generated data are exactly the same as ours:
+
+```bash
+cd data/k-shot
+md5sum -c checksum
+```
+
+**NOTE**: During training, the model will generate/load cache files in the data folder. If your data have changed, make sure to delete all the cache files (file names starting with "cache").
+
+## Run LM-BFF
+
+### Quick start
+Our code is built on [transformers](https://github.com/huggingface/transformers) and we use its `3.4.0` version. Other versions of `transformers` might cause unexpected errors.
+
+Before running any experiments, create the result folder by `mkdir result` to save checkpoints. Then you can run our code with the following example:
+
+```bash
+python run.py \
+    --task_name SST-2 \
+    --data_dir data/k-shot/SST-2/16-42 \
+    --overwrite_output_dir \
+    --do_train \
+    --do_eval \
+    --do_predict \
+    --evaluate_during_training \
+    --model_name_or_path roberta-large \
+    --few_shot_type prompt-demo \
+    --num_k 16 \
+    --max_steps 1000 \
+    --eval_steps 100 \
+    --per_device_train_batch_size 2 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 0 \
+    --output_dir result/tmp \
+    --seed 42 \
+    --template "*cls**sent_0*_It_was*mask*.*sep+*" \
+    --mapping "{'0':'terrible','1':'great'}" \
+    --num_sample 16 \
+```
+
+Most arguments are inherited from `transformers` and are easy to understand. We further explain some of LM-BFF's arguments:
+
+* `few_shot_type`: There are three modes
+  * `finetune`: Standard fine-tuning
+  * `prompt`: Prompt-based fine-tuning.
+  * `prompt-demo`: Prompt-based fine-tuning with demonstrations.
+* `num_k`: Number of training instances for each class. We take `num_k`=16 in our paper. This argument is mainly used for indexing logs afterwards (because the training example numbers are actually decided by the data split you use).
+* `template`: Template for prompt-based fine-tuning. We will introduce the template format later.
+* `mapping`: Label word mapping for prompt-based fine-tuning. It is a string representation of a dictionary mapping label names to label words. **NOTE**: For RoBERTa, the model will automatically add a space before the word. See the paper appendix for details.
+* `num_sample`: When using demonstrations during inference, the number of demonstration samples for each input query. Say `num_sample`=16: we then sample 16 different sets of demonstrations for one input, run the forward pass separately for each, and average the logits over all 16 samples as the final prediction (see the sketch below).
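+
+For intuition, here is a minimal numpy sketch of this averaging step (it mirrors the logic in `run.py`'s `build_compute_metrics_fn`; the shapes are illustrative):
+
+```python
+import numpy as np
+
+num_sample, num_examples, num_labels = 16, 872, 2  # illustrative shapes
+# Predictions arrive ordered as num_sample consecutive copies of the eval set
+logits = np.random.randn(num_sample * num_examples, num_labels)
+
+# Average the logits over the num_sample demonstration sets per example
+avg_logits = logits.reshape(num_sample, -1, num_labels).mean(axis=0)
+preds = np.argmax(avg_logits, axis=1)  # one final prediction per example
+```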
+
+Also, this codebase supports BERT-series and RoBERTa-series pre-trained models from Huggingface's `transformers`. You can check [Huggingface's website](https://huggingface.co/models) for available models and pass any model with "bert" or "roberta" in its name to `--model_name_or_path`. Some examples are `bert-base-uncased`, `bert-large-uncased`, `roberta-base`, `roberta-large`, etc.
+
+To easily run our experiments, you can also use `run_experiment.sh` (this command runs prompt-based fine-tuning with demonstrations, no filtering, manual prompt):
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh
+```
+
+We have already defined the templates and label word mappings in it, so you only need to adjust a few hyper-parameters and `TAG` (you can use whatever tag you want; it just makes finding results easier). See `run_experiment.sh` for more options for these environment variables. You can also pass extra arguments:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--output_dir result/exp --max_seq_length 512"
+```
+
+### Experiments with multiple runs
+
+To carry out experiments with multiple data splits, following the evaluation protocol detailed in §3.3 of [our paper](https://arxiv.org/pdf/2012.15723.pdf) (grid search for each seed and aggregation of the results over 5 different seeds), you can use the following script:
+
+```bash
+for seed in 13 21 42 87 100
+do
+    for bs in 2 4 8
+    do
+        for lr in 1e-5 2e-5 5e-5
+        do
+            TAG=exp \
+            TYPE=prompt-demo \
+            TASK=SST-2 \
+            BS=$bs \
+            LR=$lr \
+            SEED=$seed \
+            MODEL=roberta-large \
+            bash run_experiment.sh
+        done
+    done
+done
+```
+
+All the results will be stored in `./log`. To gather all the results, run the following command:
+
+```bash
+python tools/gather_result.py --condition "{'tag': 'exp', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
+```
+
+Then the program will find all the trials that satisfy the condition in `./log`, and print the mean/std of the final results. Note that the task names are all lower-cased and if the task has more than one metric, you need to specify the major metric (used for taking the best validation trial) in the name (e.g., `mnli`, `mnli-mm`, `mrpc/acc`, `mrpc/f1`, `qqp/acc`, `qqp/f1`, `sts-b/pearson`, `sts-b/spearman`).
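+
+Each run appends its configuration and results as one Python-dict string per line to `./log` (see the end of `run.py`). A simplified sketch of how such a condition filter could work (this is not the actual `tools/gather_result.py`; it assumes matching records parse as plain literals):
+
+```python
+import ast
+
+def matching_runs(log_path, condition):
+    """Return all log records whose fields match the given condition dict."""
+    runs = []
+    with open(log_path) as f:
+        for line in f:
+            try:
+                record = ast.literal_eval(line.strip())  # one dict per line
+            except (ValueError, SyntaxError):
+                continue  # skip records that are not plain literals
+            if all(record.get(k) == v for k, v in condition.items()):
+                runs.append(record)
+    return runs
+
+runs = matching_runs("log", {"tag": "exp", "task_name": "sst-2"})
+```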
+
+### Using demonstrations with filtering
+
+To use the filtering mechanism when using demonstrations, we need to first generate [Sentence-BERT](https://github.com/UKPLab/sentence-transformers) embeddings. To generate embeddings for datasets in our paper, you can directly run
+
+```
+bash tools/get_sbert_embedding.sh roberta-large
+```
+
+`roberta-large` can also be replaced by `bert-base`, `bert-large`, `roberta-base` and `distilbert-base` (see [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) for details). See `tools/get_sbert_embedding.sh` and `tools/get_sbert_embedding.py` if you want to add more datasets.
+
+After generating the embeddings (they are saved as numpy files in the data folders), we can run the following command to do prompt-based fine-tuning with demonstrations and filtering:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--demo_filter --demo_filter_model sbert-roberta-large"
+```
+
+### Automatically searched prompt
+
+We provide our automatic search results in `auto_template` and `auto_label_mapping`. There are three types of files:
+
+* `SST-2/16-42.txt`: Initial search results for the SST-2 dataset with K=16 and SEED=42.
+* `SST-2/16-42.sort.txt`: The initial results after prompt-based fine-tuning, sorted by dev set performance.
+* `SST-2/16-42.score.txt`: Same as above, but with the dev set scores included.
+
+To use the best automatic template (`auto-T` in the paper), use the following command:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--template_path auto_template/SST-2/16-42.sort.txt --template_id 0"
+```
+
+You can also use the _i_-th automatic result by specifying a different `template_id`.
+
+Similarly, to use the automatic label word mapping (`auto-L` in the paper), use the following command:
+
+```bash
+TAG=exp TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--mapping_path auto_label_mapping/SST-2/16-42.sort.txt --mapping_id 0"
+```
+
+**NOTE**: Make sure to use the automatic search results corresponding to each data split seed.
+
+**Our final results (LM-BFF) use prompt-based fine-tuning with demonstrations, filtering, and automatic templates. For example:**
+
+```bash
+for seed in 13 21 42 87 100
+do
+    for bs in 2 4 8
+    do
+        for lr in 1e-5 2e-5 5e-5
+        do
+            TAG=LM-BFF \
+            TYPE=prompt-demo \
+            TASK=SST-2 \
+            BS=$bs \
+            LR=$lr \
+            SEED=$seed \
+            MODEL=roberta-large \
+            bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id 0 --demo_filter --demo_filter_model sbert-roberta-large"
+        done
+    done
+done
+
+python tools/gather_result.py --condition "{'tag': 'LM-BFF', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}"
+```
+
+#### Search for automatic templates
+
+If you want to try automatically generating templates by yourself, here are the instructions. Note that it is an extremely long process :)
+
+To get automatic templates, we first generate template candidates by using T5:
+
+```bash
+python tools/generate_template.py \
+    --output_dir my_auto_template \
+    --task_name SST-2 \
+    --seed 13 21 42 87 100 \
+    --t5_model t5-3b \
+    --beam 100
+```
+
+Here, `--t5_model` specifies the pre-trained T5 checkpoint to use and `--beam` specifies the beam search width. Note that the `t5-3b` model takes approximately 15GB of GPU memory; if your GPU cannot accommodate it, you can try smaller T5 models (e.g., `t5-base`).
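+
+For intuition, the generation step asks T5 to fill sentinel spans around the label word with template text. A rough, hypothetical sketch of the idea (not `tools/generate_template.py` itself; the input pattern and hyper-parameters here are illustrative):
+
+```python
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+tokenizer = T5Tokenizer.from_pretrained("t5-base")
+model = T5ForConditionalGeneration.from_pretrained("t5-base")
+
+# One training example wrapped in a fill-in pattern: T5 generates the
+# <extra_id_0>/<extra_id_1> spans, which become the template text.
+text = "A fun movie. <extra_id_0> great <extra_id_1>"
+inputs = tokenizer(text, return_tensors="pt")
+outputs = model.generate(inputs["input_ids"], num_beams=30,
+                         num_return_sequences=5, max_length=20)
+for seq in outputs:
+    print(tokenizer.decode(seq, skip_special_tokens=False))
+```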
+
+Then we do prompt-based fine-tuning with all the templates:
+
+```bash
+for template_id in {0..99}
+do
+    for seed in 13 21 42 87 100
+    do
+        # To save time, we fix these hyper-parameters
+        bs=8
+        lr=1e-5
+
+        # Since we only use dev performance here, use --no_predict to skip testing
+        TAG=exp-template \
+        TYPE=prompt \
+        TASK=SST-2 \
+        BS=$bs \
+        LR=$lr \
+        SEED=$seed \
+        MODEL=roberta-large \
+        bash run_experiment.sh "--template_path my_auto_template/SST-2/16-$seed.txt --template_id $template_id --no_predict"
+    done
+done
+```
+
+... and sort them based on dev set performance:
+
+```bash
+python tools/sort_template.py --condition "{'tag': 'exp-template', 'task_name': 'sst-2'}" --template_dir my_auto_template
+```
+
+The sorted results will be saved in `my_auto_template`, with the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
+
+#### Search for automatic label word mappings
+
+Similar to the process of automatic template search, we first generate candidate label word mappings by running:
+
+```bash
+bash tools/run_generate_labels.sh
+```
+
+You can modify the options in `tools/run_generate_labels.sh` to run this for different datasets or save mappings to different directories. After running the generation, the candidate label mappings will be saved in `my_auto_label_mapping/manual_template`.
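+
+Conceptually, candidate label words for a class are the vocabulary items that the masked LM ranks highest at the *mask* position across that class's training examples. A simplified sketch of the idea for one example (assuming RoBERTa and the manual SST-2 template; not the actual `tools/generate_labels.py`):
+
+```python
+import torch
+from transformers import RobertaForMaskedLM, RobertaTokenizer
+
+tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
+model = RobertaForMaskedLM.from_pretrained("roberta-large").eval()
+
+sent = "A fun movie. It was <mask>."  # template applied to one example
+inputs = tokenizer(sent, return_tensors="pt")
+mask_idx = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0]
+with torch.no_grad():
+    logits = model(**inputs, return_dict=True).logits[0, mask_idx]
+top_ids = torch.topk(logits, 10).indices.squeeze()
+print([tokenizer.decode([int(i)]).strip() for i in top_ids])  # candidates
+```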
+
+Then we do prompt-based fine-tuning with all the mappings:
+
+```bash
+for mapping_id in {0..99}
+do
+    for seed in 13 21 42 87 100
+    do
+        # To save time, we fix these hyper-parameters
+        bs=8
+        lr=1e-5
+
+        # Since we only use dev performance here, use --no_predict to skip testing
+        TAG=exp-mapping \
+        TYPE=prompt \
+        TASK=SST-2 \
+        BS=$bs \
+        LR=$lr \
+        SEED=$seed \
+        MODEL=roberta-large \
+        bash run_experiment.sh "--mapping_path my_auto_label_mapping/manual_template/SST-2/16-$seed.txt --mapping_id $mapping_id --no_predict"
+    done
+done
+```
+
+... and sort them based on dev set performance:
+
+```bash
+python tools/sort_mapping.py --condition "{'tag': 'exp-mapping', 'task_name': 'sst-2'}" --mapping_dir my_auto_label_mapping/manual_template
+```
+
+The sorted results will be saved in `my_auto_label_mapping/manual_template`, with the same format as described in [Automatically searched prompt](#automatically-searched-prompt).
+
+**Auto T + L**: We can also do a joint search of templates and label word mappings following these steps:
+
+1. First, do the automatic template search following [Search for automatic templates](#search-for-automatic-templates).
+2. The following steps are similar to automatic label mapping search, except for a few arguments. When running `tools/run_generate_labels.sh`, set `LOAD_TEMPLATES` to `true`; the template + mapping candidates will then be written to `my_auto_label_mapping/auto_template`.
+3. For the subsequent fine-tuning, change `--mapping_path` and `--mapping_id` to `--prompt_path` and `--prompt_id`.
+4. In the end, for re-ranking all the prompts, change `tools/sort_mapping.py` to `tools/sort_prompt.py` to get the final lists.
+
+### Ensemble model
+
+First we need to train models with different templates:
+
+```bash
+mkdir ensemble_predict_results
+for template_id in {0..19} # Use top 20 templates
+do
+    array_id=0
+    for seed in 13 21 42 87 100
+    do
+        for bs in 2 4 8
+        do
+            for lr in 1e-5 2e-5 5e-5
+            do
+                TAG=exp-ensemble \
+                TYPE=prompt-demo \
+                TASK=SST-2 \
+                BS=$bs \
+                LR=$lr \
+                SEED=$seed \
+                MODEL=roberta-large \
+                bash run_experiment.sh "--template_path auto_template/SST-2/16-$seed.sort.txt --template_id $template_id --model_id $template_id --array_id $array_id --save_logit --save_logit_dir ensemble_predict_results"
+
+                array_id=$(expr $array_id + 1)
+            done
+        done
+    done
+done
+```
+
+Looks a little complicated? It's actually pretty easy to understand: `--model_id` and `--array_id` are used to distinguish different runs, and `--save_logit` tells the program to save the prediction results for ensembling.
+
+After finishing the experiments, use the following command to get the ensemble results:
+
+```bash
+python tools/ensemble.py --condition "{'tag': 'exp-ensemble', 'task_name': 'sst-2', 'few_shot_type': 'prompt-demo'}" --n_models 20
+```
+
+where `--n_models` specifies how many models to use for the ensemble (it should match the number of templates used in the experiments).
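+
+Under the hood, the ensemble simply averages the saved per-model logits. A minimal sketch, assuming one saved `.npy` file per model following the `$TASK-$MODEL_ID-$ARRAY_ID.npy` naming from `--save_logit` (in practice, the best `array_id` per model is first selected by dev performance; it is fixed to 0 here for brevity):
+
+```python
+import numpy as np
+
+# Hypothetical file list: 20 models, array_id fixed to 0 for illustration
+logit_files = [f"ensemble_predict_results/SST-2-{m}-0.npy" for m in range(20)]
+ensemble_logits = np.mean([np.load(f) for f in logit_files], axis=0)
+preds = np.argmax(ensemble_logits, axis=1)  # ensembled predictions
+```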
+
+### Zero-shot experiments
+
+It's easy to run zero-shot experiments: just add the `--no_train` argument:
+
+```bash
+TAG=zero-shot TYPE=prompt TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train"
+```
+
+To do "GPT-3 style" in-context learning:
+
+```bash
+TAG=gpt3-in-context TYPE=prompt-demo TASK=SST-2 BS=2 LR=1e-5 SEED=42 MODEL=roberta-large bash run_experiment.sh "--no_train --num_sample 1 --gpt3_in_context_head --gpt3_in_context_num 32 --truncate_head --use_full_length"
+```
+
+### How to design your own templates
+
+Here are two template examples:
+
+For SST-2: `*cls**sent_0*_It_was*mask*.*sep+*` => `[CLS] {S0} It was [MASK]. [SEP]`
+
+For MNLI: `*cls**sent-_0*?*mask*,*+sentl_1**sep+*` => `[CLS] {S0}? [MASK], {S1} [SEP]`
+
+The template is composed of special tokens and variables (surrounded by `*`) and plain text (e.g., `It_was`, where each space is written as `_`). The special tokens and variables include:
+
+* `*cls*`, `*sep+*` and `*mask*`: The model's CLS, SEP and MASK special tokens (these differ across pre-trained models and tokenizers). `*sep+*` means the contents before and after this token have different segment embeddings (only for BERT).
+* `*sent_i*`: The i-th sentence.
+* `*sent-_i*`: The i-th sentence, discarding the last character.
+* `*sentl_i*`: The i-th sentence, lower-casing the first letter.
+* `*sentl-_i*`: The i-th sentence, discarding the last character and lower-casing the first letter.
+* `*+sent_i*`: The i-th sentence, adding an extra space at the beginning.
+* `*+sentl_i*`: The i-th sentence, adding an extra space at the beginning and lower-casing the first letter.
+
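+Putting these rules together, here is a small, illustrative Python sketch of how such a template string expands (a simplified stand-in for the real parser in `src/dataset.py`; it ignores segment ids and the `sent-`/`sentl` variants):
+
+```python
+def render(template, sents):
+    """Expand a template string for one input; illustrative only."""
+    specials = {"cls": "[CLS] ", "mask": " [MASK]", "sep": " [SEP]", "sep+": " [SEP]"}
+    out = []
+    for part in template.split("*"):  # variables alternate with plain text
+        if part in specials:
+            out.append(specials[part])
+        elif part.startswith("sent_"):
+            out.append(sents[int(part[len("sent_"):])])
+        else:
+            out.append(part.replace("_", " "))  # "_" encodes a space
+    return "".join(out)
+
+print(render("*cls**sent_0*_It_was*mask*.*sep+*", ["A fun movie."]))
+# -> "[CLS] A fun movie. It was [MASK]. [SEP]"
+```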
+
+## Bugs or questions?
+
+If you have any questions related to the code or the paper, feel free to email Tianyu (`tianyug@cs.princeton.edu`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please describe the problem in detail so we can help you better and more quickly!
+
+## Citation
+
+Please cite our paper if you use LM-BFF in your work:
+
+```bibtex
+@inproceedings{gao2021making,
+   title={Making Pre-trained Language Models Better Few-shot Learners},
+   author={Gao, Tianyu and Fisch, Adam and Chen, Danqi},
+   booktitle={Association for Computational Linguistics (ACL)},
+   year={2021}
+}
+```
diff --git a/LM-BFF/figs/lmbff.png b/LM-BFF/figs/lmbff.png
new file mode 100755
index 0000000..04dac6e
Binary files /dev/null and b/LM-BFF/figs/lmbff.png differ
diff --git a/LM-BFF/nohup.out b/LM-BFF/nohup.out
new file mode 100755
index 0000000..0e744ed
--- /dev/null
+++ b/LM-BFF/nohup.out
@@ -0,0 +1 @@
+
  0%|          | 0/18 [00:00<?, ?it/s]
  6%|▌         | 1/18 [02:41<45:41, 161.24s/it]
\ No newline at end of file
diff --git a/LM-BFF/requirements.txt b/LM-BFF/requirements.txt
new file mode 100755
index 0000000..e91061e
--- /dev/null
+++ b/LM-BFF/requirements.txt
@@ -0,0 +1,37 @@
+certifi==2020.12.5
+chardet==4.0.0
+click==7.1.2
+dataclasses
+filelock==3.0.12
+flake8==3.8.4
+future==0.18.2
+idna==2.10
+importlib-metadata==3.3.0
+joblib==1.0.0
+mccabe==0.6.1
+nltk==3.5
+numpy==1.19.4
+packaging==20.8
+pandas==1.1.5
+protobuf==3.14.0
+pycodestyle==2.6.0
+pyflakes==2.2.0
+pyparsing==2.4.7
+python-dateutil==2.8.1
+pytz==2020.5
+regex==2020.11.13
+requests==2.25.1
+sacremoses==0.0.43
+scikit-learn==0.24.0
+scipy==1.5.4
+sentence-transformers==0.4.0
+sentencepiece==0.1.94
+six==1.15.0
+threadpoolctl==2.1.0
+tokenizers==0.9.2
+torch==1.6.0
+tqdm==4.48.2
+transformers==3.4.0
+typing-extensions==3.7.4.3
+urllib3>=1.26.4
+zipp==3.4.0
diff --git a/LM-BFF/run.py b/LM-BFF/run.py
new file mode 100755
index 0000000..baa73cf
--- /dev/null
+++ b/LM-BFF/run.py
@@ -0,0 +1,628 @@
+"""Finetuning the library models for sequence classification on GLUE."""
+
+import dataclasses
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Optional
+import torch
+
+import numpy as np
+
+import transformers
+from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
+from transformers import GlueDataTrainingArguments as DataTrainingArguments
+from transformers import HfArgumentParser, TrainingArguments, set_seed
+
+from src.dataset import FewShotDataset
+from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
+from src.trainer import Trainer
+from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, bound_mapping
+
+from filelock import FileLock
+from datetime import datetime
+
+from copy import deepcopy
+from tqdm import tqdm
+import json
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+    # Few-shot type
+    #   - finetune: standard fine-tuning
+    #   - prompt: prompt-based fine-tuning
+    #   - prompt-demo: prompt-based fine-tuning with demonstrations
+    few_shot_type: str = field(
+        default='prompt-demo',
+        metadata={"help": "Few-shot learning model type. Choice: finetune, prompt, prompt-demo"}
+    )
+
+    # Only for BERT-type model
+    random_segment: bool = field(
+        default=False,
+        metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."}
+    )
+
+@dataclass
+class DynamicDataTrainingArguments(DataTrainingArguments):
+    """
+    Arguments for dynamic training.
+    """
+    num_k: Optional[int] = field(
+        default=16,
+        metadata={"help": "Number of training instances per class"}
+    )
+
+    num_sample: Optional[int] = field(
+        default=16,
+        metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"}
+    )
+
+    num_demo: Optional[int] = field(
+        default=1,
+        metadata={"help": "Number of demonstrations from each class"}
+    )
+
+    auto_demo: bool = field(
+        default=True,
+        metadata={"help": "Automatically generate template for using demonstrations"}
+    )
+
+    # For prompting
+    template: str = field(
+        default=None,
+        metadata={"help": "Template"}
+    )
+
+    mapping: str = field(
+        default=None,
+        metadata={"help": "Label word mapping"}
+    )
+
+    template_path: str = field(
+        default=None,
+        metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"}
+    )
+
+    mapping_path: str = field(
+        default=None,
+        metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. Do not set this when prompt_path is used"}
+    )
+
+    prompt_path: str = field(
+        default=None,
+        metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"}
+    )
+ 
+    template_id: int = field(
+        default=None,
+        metadata={"help": "Template id if using template_path"}
+    )
+
+    mapping_id: int = field(
+        default=None,
+        metadata={"help": "Mapping id if using template_path"}
+    )
+
+    prompt_id: int = field(
+        default=None,
+        metadata={"help": "Prompt id if using prompt_path"}
+    )
+
+    top_n_template: int = field(
+        default=None,
+        metadata={"help": "Use top-n template in the template path"}
+    )
+
+    # For logging
+    tag: str = field(
+        default='',
+        metadata={"help": "Set the tag and find the result easier in the log."}
+    )
+
+    # For filtering when using demonstrations
+    demo_filter: bool = field(
+        default=False,
+        metadata={"help": "Only use similar instances in demonstrations"}
+    )
+
+    demo_filter_rate: float = field(
+        default=0.5,
+        metadata={"help": "Only use top-x\% similar instances in demonstrations"}
+    )
+
+    demo_filter_model: str = field(
+        default=None,
+        metadata={"help": "Model name for demonstration filter embeddings. Will load embeddings based on the model name."}
+    )
+
+    debug_mode: bool = field(
+        default=False,
+        metadata={"help": "Debug mode"}
+    )
+
+    # For max length
+    double_demo: bool = field(
+        default=False,
+        metadata={"help": "Use double length for using demonstrations"}
+    )
+
+    first_sent_limit: int = field(
+        default=None,
+        metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
+    )
+
+    other_sent_limit: int = field(
+        default=None,
+        metadata={"help": "Limit the length of sentences other than the first sentence"}
+    )
+
+    use_full_length: bool = field(
+        default=None,
+        metadata={"help": "Use the full length (512)"}
+    )
+
+    # GPT-3's in-context learning
+    gpt3_in_context_head: bool = field(
+        default=False,
+        metadata={"help": "GPT-3's in-context learning (context at the beginning)"}
+    )
+
+    gpt3_in_context_tail: bool = field(
+        default=False,
+        metadata={"help": "GPT-3's in-context learning (context at the end)"}
+    )
+
+    gpt3_in_context_num: int = field(
+        default=32,
+        metadata={"help": "Number of context examples"}
+    )
+
+    truncate_head: bool = field(
+        default=False,
+        metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
+    )
+
+    # Do not set up the following fields. They are set up automatically.
+    prompt: bool = field(
+        default=False,
+        metadata={"help": "Whether to use prompt-based fine-tuning"}
+    )
+    template_list: list = field(
+        default=None,
+        metadata={"help": "(DO NOT List of templates (only initialized after the program starts."}
+    )
+
+
+@dataclass
+class DynamicTrainingArguments(TrainingArguments):
+    # For ensemble
+    array_id: int = field(
+        default=-1,
+        metadata={"help": "Array ID (contains seed and hyper-paramter search) to idenfity the model"}
+    )
+
+    model_id: int = field(
+        default=-1,
+        metadata={"help": "Model ID (contains template information) to identify the model"}
+    )
+
+    save_logit: bool = field(
+        default=False,
+        metadata={"help": "Save test file logit with name $TASK-$MODEL_ID-$ARRAY_ID.npy"}
+    )
+
+    save_logit_dir: str = field(
+        default=None,
+        metadata={"help": "Where to save the prediction result"}
+    )
+
+    # Regularization
+    fix_layers: int = field(
+        default=0,
+        metadata={"help": "Fix bottom-n layers when optimizing"}
+    )
+
+    # Training
+    save_at_last: bool = field(
+        default=False,
+        metadata={"help": "Instead of saving the best (dev performance) checkpoint, save the last checkpoint"}
+    )
+
+    # Turn off train/test
+    no_train: bool = field(
+        default=False,
+        metadata={"help": "No training"}
+    )
+    no_predict: bool = field(
+        default=False,
+        metadata={"help": "No test"}
+    )
+
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if 'prompt' in model_args.few_shot_type:
+        data_args.prompt = True
+
+    if training_args.no_train:
+        training_args.do_train = False
+    if training_args.no_predict:
+        training_args.do_predict = False
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+    )
+
+    # Load prompt/template/mapping file
+    if data_args.prompt:
+        if data_args.prompt_path is not None:
+            assert data_args.prompt_id is not None
+            prompt_list = []
+            with open(data_args.prompt_path) as f:
+                for line in f:
+                    line = line.strip()
+                    template, mapping = line.split('\t')
+                    prompt_list.append((template, mapping))
+
+            data_args.template, data_args.mapping = prompt_list[data_args.prompt_id] 
+            logger.info("Specify load the %d-th prompt: %s | %s" % (data_args.prompt_id, data_args.template, data_args.mapping))
+        else:
+            if data_args.template_path is not None:
+                with open(data_args.template_path) as f:
+                    data_args.template_list = []
+                    for line in f:
+                        line = line.strip()
+                        if len(line) > 0:
+                            data_args.template_list.append(line)
+
+                # Load top-n templates
+                if data_args.top_n_template is not None:
+                    data_args.template_list = data_args.template_list[:data_args.top_n_template]
+                logger.info("Load top-%d templates from %s" % (len(data_args.template_list), data_args.template_path))
+
+                # ... or load i-th template
+                if data_args.template_id is not None:
+                    data_args.template = data_args.template_list[data_args.template_id]
+                    data_args.template_list = None
+                    logger.info("Specify load the %d-th template: %s" % (data_args.template_id, data_args.template))
+
+            if data_args.mapping_path is not None:
+                assert data_args.mapping_id is not None # Only can use one label word mapping
+                with open(data_args.mapping_path) as f:
+                    mapping_list = []
+                    for line in f:
+                        line = line.strip()
+                        mapping_list.append(line)
+
+                data_args.mapping = mapping_list[data_args.mapping_id]
+                logger.info("Specify using the %d-th mapping: %s" % (data_args.mapping_id, data_args.mapping))
+
+    # Check save path
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
+
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+
+    # Set seed
+    set_seed(training_args.seed)
+
+    try:
+        num_labels = num_labels_mapping[data_args.task_name]
+        output_mode = output_modes_mapping[data_args.task_name]
+        logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))
+    except KeyError:
+        raise ValueError("Task not found: %s" % (data_args.task_name))
+
+    # Automatically generate template for using demonstrations
+    if data_args.auto_demo and model_args.few_shot_type == 'prompt-demo':
+        # GPT-3's in-context learning
+        if data_args.gpt3_in_context_head or data_args.gpt3_in_context_tail: 
+            logger.info("Automatically convert the template to GPT-3's in-context learning.")
+            assert data_args.template_list is None
+
+            old_template = data_args.template
+            new_template = old_template + ''
+            old_template = old_template.replace('*cls*', '')
+            # Single sentence or sentence pair?
+            sent_num = 1
+            if "_1" in old_template:
+                sent_num = 2
+            for instance_id in range(data_args.gpt3_in_context_num):
+                sub_template = old_template + ''
+                # Replace sent_id
+                for sent_id in range(sent_num):
+                    sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * instance_id + sent_id))
+                # Replace mask
+                sub_template = sub_template.replace("*mask*", "*labelx_{}*".format(instance_id))
+                if data_args.gpt3_in_context_tail:
+                    new_template = new_template + sub_template # Put context at the end
+                else:
+                    new_template = sub_template + new_template # Put context at the beginning
+            logger.info("| {} => {}".format(data_args.template, new_template))
+            data_args.template = new_template
+        else:
+            logger.info("Automatically convert the template to using demonstrations.")
+            if data_args.template_list is not None:
+                for i in range(len(data_args.template_list)):
+                    old_template = data_args.template_list[i]
+                    new_template = old_template + ''
+                    old_template = old_template.replace('*cls*', '')
+                    # Single sentence or sentence pair?
+                    sent_num = 1
+                    if "_1" in old_template:
+                        sent_num = 2
+                    for label_id in range(num_labels):
+                        sub_template = old_template + ''
+                        # Replace sent id
+                        for sent_id in range(sent_num):
+                            sub_template = sub_template.replace("_{}*".format(sent_id), "_{}*".format(sent_num + sent_num * label_id + sent_id))
+                        # Replace mask
+                        sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id))
+                        new_template = new_template + sub_template
+                    logger.info("| {} => {}".format(data_args.template_list[i], new_template))
+                    data_args.template_list[i] = new_template
+            else:
+                old_template = data_args.template
+                new_template = old_template + ''
+                old_template = old_template.replace('*cls*', '')
+                # Single sentence or sentence pair?
+                sent_num = 1
+                if "_1" in old_template:
+                    sent_num = 2
+                for label_id in range(num_labels):
+                    sub_template = old_template + ''
+                    # Replace sent id
+                    for sent_id in range(sent_num):
+                        sub_template = sub_template.replace("_{}".format(sent_id), "_{}".format(sent_num + sent_num * label_id + sent_id))
+                    # Replace mask
+                    sub_template = sub_template.replace("*mask*", "*label_{}*".format(label_id))
+                    new_template = new_template + sub_template
+                logger.info("| {} => {}".format(data_args.template, new_template))
+                data_args.template = new_template
+
+    # Create config
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+        cache_dir=model_args.cache_dir,
+    )
+
+    if 'prompt' in model_args.few_shot_type:
+        if config.model_type == 'roberta':
+            model_fn = RobertaForPromptFinetuning
+        elif config.model_type == 'bert':
+            model_fn = BertForPromptFinetuning
+        else:
+            raise NotImplementedError
+    elif model_args.few_shot_type == 'finetune':
+        model_fn = AutoModelForSequenceClassification
+    else:
+        raise NotImplementedError
+    special_tokens = []
+
+    # Create tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        additional_special_tokens=special_tokens,
+        cache_dir=model_args.cache_dir,
+    )
+
+    # Get our special datasets.
+    train_dataset = (
+        FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=("demo" in model_args.few_shot_type))
+    )
+    eval_dataset = (
+        FewShotDataset(data_args, tokenizer=tokenizer, mode="dev", use_demo=("demo" in model_args.few_shot_type))
+        if training_args.do_eval
+        else None
+    )
+    test_dataset = (
+        FewShotDataset(data_args, tokenizer=tokenizer, mode="test", use_demo=("demo" in model_args.few_shot_type))
+        if training_args.do_predict
+        else None
+    )
+
+    set_seed(training_args.seed)
+
+    model = model_fn.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+    )
+
+    # For BERT, increase the size of the segment (token type) embeddings
+    if config.model_type == 'bert':
+        model.resize_token_embeddings(len(tokenizer))
+        resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
+
+    # Pass dataset and argument information to the model
+    if data_args.prompt:
+        model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda()
+    if output_modes_mapping[data_args.task_name] == 'regression':
+        # lower / upper bounds
+        model.lb, model.ub = bound_mapping[data_args.task_name]
+    model.model_args = model_args
+    model.data_args = data_args
+    model.tokenizer = tokenizer
+
+    # Build metric
+    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
+        def compute_metrics_fn(p: EvalPrediction):
+            # Note: the eval dataloader is sequential, so the examples are in order.
+            # We average the logits over each sample for using demonstrations.
+            predictions = p.predictions
+            num_logits = predictions.shape[-1]
+            logits = predictions.reshape([eval_dataset.num_sample, -1, num_logits])
+            logits = logits.mean(axis=0)
+            
+            if num_logits == 1:
+                preds = np.squeeze(logits)
+            else:
+                preds = np.argmax(logits, axis=1)
+
+            # Just for sanity, assert label ids are the same.
+            label_ids = p.label_ids.reshape([eval_dataset.num_sample, -1])
+            label_ids_avg = label_ids.mean(axis=0)
+            label_ids_avg = label_ids_avg.astype(p.label_ids.dtype)
+            assert (label_ids_avg - label_ids[0]).mean() < 1e-2
+            label_ids = label_ids[0]
+
+            return compute_metrics_mapping[task_name](task_name, preds, label_ids)
+
+        return compute_metrics_fn
+    
+    # Initialize our Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        compute_metrics=build_compute_metrics_fn(data_args.task_name)
+    )
+
+    # Training
+    if training_args.do_train:
+        trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)
+        # Use the early stop, so do not save the model in the end (unless specify save_at_last)
+        if training_args.save_at_last:
+            trainer.save_model(training_args.output_dir)
+ 
+        if trainer.is_world_master():
+            tokenizer.save_pretrained(training_args.output_dir)
+            torch.save(model_args, os.path.join(training_args.output_dir, "model_args.bin"))
+            torch.save(data_args, os.path.join(training_args.output_dir, "data_args.bin"))
+        
+        # Reload the best checkpoint (for eval)
+        model = model_fn.from_pretrained(training_args.output_dir)
+        model = model.to(training_args.device)
+        trainer.model = model
+        if data_args.prompt:
+            model.label_word_list = torch.tensor(train_dataset.label_word_list).long().cuda()
+        if output_modes_mapping[data_args.task_name] == 'regression':
+            # lower / upper bounds
+            model.lb, model.ub = bound_mapping[data_args.task_name]
+        model.model_args = model_args
+        model.data_args = data_args
+        model.tokenizer = tokenizer
+
+    # Evaluation
+    final_result = {
+        'time': str(datetime.today()),
+    }
+
+    eval_results = {}
+    if training_args.do_eval:
+        logger.info("*** Validate ***")
+
+        eval_datasets = [eval_dataset]
+
+        for eval_dataset in eval_datasets:
+            trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name)
+            output = trainer.evaluate(eval_dataset=eval_dataset)
+            eval_result = output.metrics 
+
+            output_eval_file = os.path.join(
+                training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
+            )
+            if trainer.is_world_master():
+                with open(output_eval_file, "w") as writer:
+                    logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
+                    for key, value in eval_result.items():
+                        logger.info("  %s = %s", key, value)
+                        writer.write("%s = %s\n" % (key, value))
+                        final_result[eval_dataset.args.task_name + '_dev_' + key] = value
+            eval_results.update(eval_result)
+
+    test_results = {}
+    if training_args.do_predict:
+        logging.info("*** Test ***")
+        test_datasets = [test_dataset]
+        if data_args.task_name == "mnli":
+            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
+            test_datasets.append(
+                FewShotDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", use_demo=('demo' in model_args.few_shot_type))
+            )
+
+        for test_dataset in test_datasets:
+            trainer.compute_metrics = build_compute_metrics_fn(test_dataset.args.task_name)
+            output = trainer.evaluate(eval_dataset=test_dataset)
+            test_result = output.metrics
+
+            output_test_file = os.path.join(
+                training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
+            )
+            if trainer.is_world_master():
+                with open(output_test_file, "w") as writer:
+                    logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
+                    for key, value in test_result.items():
+                        logger.info("  %s = %s", key, value)
+                        writer.write("%s = %s\n" % (key, value))
+                        final_result[test_dataset.args.task_name + '_test_' + key] = value
+
+                if training_args.save_logit:
+                    predictions = output.predictions
+                    num_logits = predictions.shape[-1]
+                    logits = predictions.reshape([test_dataset.num_sample, -1, num_logits]).mean(axis=0)
+                    np.save(os.path.join(training_args.save_logit_dir, "{}-{}-{}.npy".format(test_dataset.task_name, training_args.model_id, training_args.array_id)), logits)
+
+            test_results.update(test_result)
+
+    with FileLock('log.lock'):
+        with open('log', 'a') as f:
+            final_result.update(vars(model_args))
+            final_result.update(vars(training_args))
+            final_result.update(vars(data_args))
+            if 'evaluation_strategy' in final_result:
+                final_result.pop('evaluation_strategy')
+            f.write(str(final_result) + '\n')
+    
+    return eval_results
+
+if __name__ == "__main__":
+    main()
diff --git a/LM-BFF/src/dataset.py b/LM-BFF/src/dataset.py
new file mode 100755
index 0000000..4617608
--- /dev/null
+++ b/LM-BFF/src/dataset.py
@@ -0,0 +1,658 @@
+"""Dataset utils for different data settings for GLUE."""
+
+import os
+import copy
+import logging
+import torch
+import numpy as np
+import time
+from filelock import FileLock
+import json
+import itertools
+import random
+import transformers
+from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, median_mapping
+from transformers.data.processors.utils import InputFeatures
+from transformers import DataProcessor, InputExample
+import dataclasses
+from dataclasses import dataclass
+from typing import List, Optional, Union
+from sentence_transformers import SentenceTransformer, util
+from copy import deepcopy
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+@dataclass(frozen=True)
+class OurInputFeatures(InputFeatures):
+    """
+    Inherit from Transformers' InputFeatures.
+    """
+
+    input_ids: List[int]
+    attention_mask: Optional[List[int]] = None
+    token_type_ids: Optional[List[int]] = None
+    label: Optional[Union[int, float]] = None
+    mask_pos: Optional[List[int]] = None # Position of the mask token
+    label_word_list: Optional[List[int]] = None # Label word mapping (dynamic)
+
+    def to_json_string(self):
+        """Serializes this instance to a JSON string."""
+        return json.dumps(dataclasses.asdict(self)) + "\n"
+
+def input_example_to_string(example, sep_token): 
+    if example.text_b is None:
+        return example.text_a
+    else:
+        # Warning: very simple hack here
+        return example.text_a + ' ' + sep_token + ' ' + example.text_b
+
+def input_example_to_tuple(example): 
+    if example.text_b is None:
+        if example.text_a is None or pd.isna(example.text_a):
+            logger.warning("Empty input")
+            return ['']
+        else:
+            return [example.text_a]
+    else:
+        return [example.text_a, example.text_b]
+
+def tokenize_multipart_input(
+    input_text_list, 
+    max_length, 
+    tokenizer, 
+    task_name=None, 
+    prompt=False, 
+    template=None,
+    label_word_list=None, 
+    first_sent_limit=None,
+    other_sent_limit=None,
+    gpt3=False,
+    truncate_head=False,
+    support_labels=None,
+):
+    def enc(text):
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    input_ids = []
+    attention_mask = []
+    token_type_ids = [] # Only for BERT
+    mask_pos = None # Position of the mask token
+
+    if prompt:
+        """
+        Concatenate all sentences and prompts based on the provided template.
+        Template example: '*cls*It_was*mask*.*sent_0**sep+**label_0*:*sent_1**sep+**label_1*:*sent_2**sep+*'
+        *xx* represent variables:
+            *cls*: cls_token
+            *mask*: mask_token
+            *sep*: sep_token
+            *sep+*: sep_token, also means +1 for segment id
+            *sent_i*: sentence i (input_text_list[i])
+            *sent-_i*: same as above, but delete the last token
+            *sentl_i*: same as above, but use lower case for the first word
+            *sentl-_i*: same as above, but use lower case for the first word and delete the last token
+            *+sent_i*: same as above, but add a space before the sentence
+            *+sentl_i*: same as above, but add a space before the sentence and use lower case for the first word
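+            *sentu_i*: same as above, but use upper case for the first word
+            *+sentu_i*: same as above, but add a space before the sentence and use upper case for the first word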
+            *label_i*: label_word_list[i]
+            *labelx_i*: label of the i-th support example (support_labels needed); only used in GPT-3's in-context learning
+
+        Use "_" to replace space.
+        PAY ATTENTION TO SPACE!! DO NOT leave a space before variables, as this leads to an extra space token.
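+
+        Example (hypothetical SST-2-style input): with input_text_list
+        ['A charming movie.'] and template '*cls**sent_0*_It_was*mask*.*sep+*',
+        the encoded sequence is roughly '<s>A charming movie. It was<mask>.</s>'
+        for RoBERTa.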
+        """
+        assert template is not None
+
+        special_token_mapping = {
+            'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id, 
+        }
+        template_list = template.split('*') # Get variable list in the template
+        segment_id = 0 # Current segment id. Segment id +1 if encountering sep+.
+
+        for part_id, part in enumerate(template_list):
+            new_tokens = []
+            segment_plus_1_flag = False
+            if part in special_token_mapping:
+                if part == 'cls' and 'T5' in type(tokenizer).__name__:
+                    # T5 does not have cls token
+                    continue
+                new_tokens.append(special_token_mapping[part])
+                if part == 'sep+':
+                    segment_plus_1_flag = True
+            elif part[:6] == 'label_':
+                # Note that label_word_list already has extra space, so do not add more space ahead of it.
+                label_id = int(part.split('_')[1])
+                label_word = label_word_list[label_id]
+                new_tokens.append(label_word)
+            elif part[:7] == 'labelx_':
+                instance_id = int(part.split('_')[1])
+                label_id = support_labels[instance_id]
+                label_word = label_word_list[label_id]
+                new_tokens.append(label_word)
+            elif part[:5] == 'sent_':
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(input_text_list[sent_id]) 
+            elif part[:6] == '+sent_':
+                # Add space
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(' ' + input_text_list[sent_id])
+            elif part[:6] == 'sent-_':
+                # Delete the last token
+                sent_id = int(part.split('_')[1])
+                new_tokens += enc(input_text_list[sent_id][:-1])
+            elif part[:6] == 'sentl_':
+                # Lower case the first token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(text)
+            elif part[:7] == '+sentl_':
+                # Lower case the first token and add space 
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(' ' + text)
+            elif part[:7] == 'sentl-_':
+                # Lower case the first token and discard the last token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].lower() + text[1:]
+                new_tokens += enc(text[:-1])
+            elif part[:6] == 'sentu_':
+                # Upper case the first token
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].upper() + text[1:]
+                new_tokens += enc(text)
+            elif part[:7] == '+sentu_':
+                # Upper case the first token and add space
+                sent_id = int(part.split('_')[1])
+                text = input_text_list[sent_id]
+                text = text[:1].upper() + text[1:]
+                new_tokens += enc(' ' + text)
+            else:
+                # Just natural language prompt
+                part = part.replace('_', ' ') 
+                # handle special case when T5 tokenizer might add an extra space
+                if len(part) == 1:
+                    new_tokens.append(tokenizer._convert_token_to_id(part))
+                else:
+                    new_tokens += enc(part)
+
+            if part[:4] == 'sent' or part[1:5] == 'sent':
+                # If this part is the sentence, limit the sentence length
+                sent_id = int(part.split('_')[1])
+                if sent_id == 0:
+                    if first_sent_limit is not None:
+                        new_tokens = new_tokens[:first_sent_limit]
+                else:
+                    if other_sent_limit is not None:
+                        new_tokens = new_tokens[:other_sent_limit]
+
+            input_ids += new_tokens
+            attention_mask += [1 for i in range(len(new_tokens))]
+            token_type_ids += [segment_id for i in range(len(new_tokens))]
+
+            if segment_plus_1_flag:
+                segment_id += 1
+    else:
+        input_ids = [tokenizer.cls_token_id]
+        attention_mask = [1]
+        token_type_ids = [0]
+
+        for sent_id, input_text in enumerate(input_text_list):
+            if input_text is None:
+                # Do not have text_b
+                continue
+            if pd.isna(input_text) or input_text is None:
+                # Empty input
+                input_text = ''
+            input_tokens = enc(input_text) + [tokenizer.sep_token_id]
+            input_ids += input_tokens
+            attention_mask += [1 for i in range(len(input_tokens))]
+            token_type_ids += [sent_id for i in range(len(input_tokens))]
+
+        if 'T5' in type(tokenizer).__name__: # T5 does not have CLS token
+            input_ids = input_ids[1:]
+            attention_mask = attention_mask[1:]
+            token_type_ids = token_type_ids[1:]
+
+    if first_sent_limit is not None and len(input_ids) > max_length:
+        # Even with the per-sentence limits, the total length exceeds max_length; report a warning
+        logger.warning("Input exceeds max_length limit: {}".format(tokenizer.decode(input_ids)))
+
+    # Padding
+    while len(input_ids) < max_length:
+        input_ids.append(tokenizer.pad_token_id)
+        attention_mask.append(0)
+        token_type_ids.append(0)
+
+    # Truncate
+    if len(input_ids) > max_length:
+        if truncate_head:
+            input_ids = input_ids[-max_length:]
+            attention_mask = attention_mask[-max_length:]
+            token_type_ids = token_type_ids[-max_length:]
+        else:
+            # Default is to truncate the tail
+            input_ids = input_ids[:max_length]
+            attention_mask = attention_mask[:max_length]
+            token_type_ids = token_type_ids[:max_length]
+
+    # Find mask token
+    if prompt:
+        mask_pos = [input_ids.index(tokenizer.mask_token_id)]
+        # Make sure that the masked position is inside the max_length
+        assert mask_pos[0] < max_length
+
+    result = {'input_ids': input_ids, 'attention_mask': attention_mask}
+    if 'Bert' in type(tokenizer).__name__:
+        # Only provide token type ids for BERT
+        result['token_type_ids'] = token_type_ids
+
+    if prompt:
+        result['mask_pos'] = mask_pos
+
+    return result
+
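+# A minimal usage sketch of tokenize_multipart_input (not executed here; the
+# tokenizer and label_word_list are assumed to be built as in FewShotDataset
+# below):
+#
+#   enc = tokenize_multipart_input(
+#       input_text_list=['A charming movie.'],
+#       max_length=32,
+#       tokenizer=tokenizer,
+#       prompt=True,
+#       template='*cls**sent_0*_It_was*mask*.*sep+*',
+#       label_word_list=label_word_list,
+#   )
+#   # enc['input_ids'] is padded/truncated to max_length and enc['mask_pos']
+#   # marks the *mask* slot whose label-word logits are compared.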
+
+
+class FewShotDataset(torch.utils.data.Dataset):
+    """Few-shot dataset."""
+
+    def __init__(self, args, tokenizer, cache_dir=None, mode="train", use_demo=False):
+        self.args = args
+        self.task_name = args.task_name
+        self.processor = processors_mapping[args.task_name]
+        self.tokenizer = tokenizer
+        self.mode = mode
+
+        # use_demo=True concatenates demonstrations to each input
+        self.use_demo = use_demo
+        if self.use_demo:
+            logger.info("Use demonstrations")
+        assert mode in ["train", "dev", "test"]
+
+        # Get label list and (for prompt) label word list
+        self.label_list = self.processor.get_labels()
+        self.num_labels = len(self.label_list)
+        if args.prompt:
+            assert args.mapping is not None
+            self.label_to_word = eval(args.mapping)
+
+            for key in self.label_to_word:
+                # For RoBERTa/BART/T5, tokenization also considers space, so we use space+word as label words.
+                if self.label_to_word[key][0] not in ['<', '[', '.', ',']:
+                    # Make sure space+word is in the vocabulary
+                    assert len(tokenizer.tokenize(' ' + self.label_to_word[key])) == 1
+                    self.label_to_word[key] = tokenizer._convert_token_to_id(tokenizer.tokenize(' ' + self.label_to_word[key])[0])
+                else:
+                    self.label_to_word[key] = tokenizer._convert_token_to_id(self.label_to_word[key])
+                logger.info("Label {} to word {} ({})".format(key, tokenizer._convert_id_to_token(self.label_to_word[key]), self.label_to_word[key]))
+            
+            if len(self.label_list) > 1:
+                self.label_word_list = [self.label_to_word[label] for label in self.label_list]
+            else:
+                # Regression task
+                # '0' represents low polarity and '1' represents high polarity.
+                self.label_word_list = [self.label_to_word[label] for label in ['0', '1']]
+        else:
+            self.label_to_word = None
+            self.label_word_list = None
+
+        # Multiple sampling: when using demonstrations, we sample different combinations of demonstrations during 
+        # inference and aggregate the results by averaging the logits. The number of different samples is num_sample.
+        if (mode == "train") or not self.use_demo:
+            # We do not do multiple sampling when not using demonstrations or when it's the training mode 
+            self.num_sample = 1
+        else:
+            self.num_sample = args.num_sample
+
+        # If we use multiple templates, we also need to do multiple sampling during inference.
+        if args.prompt and args.template_list is not None:
+            logger.info("There are %d templates. Multiply num_sample by %d" % (len(args.template_list), len(args.template_list)))
+            self.num_sample *= len(args.template_list)
+                
+        logger.info("Total num_sample for mode %s: %d" % (mode, self.num_sample))
+
+        # Load cache
+        # Cache name distinguishes mode, task name, tokenizer, and length. So if you change anything beyond these elements, make sure to clear your cache.
+        cached_features_file = os.path.join(
+            cache_dir if cache_dir is not None else args.data_dir,
+            "cached_{}_{}_{}_{}".format(
+                mode,
+                tokenizer.__class__.__name__,
+                str(args.max_seq_length),
+                args.task_name,
+            ),
+        )
+
+        logger.info(f"Creating/loading examples from dataset file at {args.data_dir}")
+
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
+
+            if os.path.exists(cached_features_file) and not args.overwrite_cache:
+                start = time.time()
+                self.support_examples, self.query_examples = torch.load(cached_features_file)
+                logger.info(
+                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                )
+            else:
+                logger.info(f"Creating features from dataset file at {args.data_dir}")
+
+                # The support examples are sourced from the training set.
+                self.support_examples = self.processor.get_train_examples(args.data_dir)
+
+                if mode == "dev":
+                    self.query_examples = self.processor.get_dev_examples(args.data_dir)
+                elif mode == "test":
+                    self.query_examples = self.processor.get_test_examples(args.data_dir)
+                else:
+                    self.query_examples = self.support_examples
+
+                start = time.time()
+                torch.save([self.support_examples, self.query_examples], cached_features_file)
+                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+                logger.info(
+                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                )
+
+        # For filtering in using demonstrations, load pre-calculated embeddings
+        if self.use_demo and args.demo_filter:
+            split_name = ''
+            if mode == 'train':
+                split_name = 'train'
+            elif mode == 'dev':
+                if args.task_name == 'mnli':
+                    split_name = 'dev_matched'
+                elif args.task_name == 'mnli-mm':
+                    split_name = 'dev_mismatched'
+                else:
+                    split_name = 'dev'
+            elif mode == 'test':
+                if args.task_name == 'mnli':
+                    split_name = 'test_matched'
+                elif args.task_name == 'mnli-mm':
+                    split_name = 'test_mismatched'
+                else:
+                    split_name = 'test'
+            else:
+                raise NotImplementedError
+
+            self.support_emb = np.load(os.path.join(args.data_dir, "train_{}.npy".format(args.demo_filter_model)))
+            self.query_emb = np.load(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model)))
+            logger.info("Load embeddings (for demonstration filtering) from {}".format(os.path.join(args.data_dir, "{}_{}.npy".format(split_name, args.demo_filter_model))))
+
+            assert len(self.support_emb) == len(self.support_examples)
+            assert len(self.query_emb) == len(self.query_examples)
+ 
+        # Size is expanded by num_sample
+        self.size = len(self.query_examples) * self.num_sample
+        
+        # Prepare examples (especially for using demonstrations)
+        support_indices = list(range(len(self.support_examples)))
+        self.example_idx = []
+        for sample_idx in range(self.num_sample):
+            for query_idx in range(len(self.query_examples)):
+                # If training, exclude the current example. Else keep all.
+                if self.use_demo and args.demo_filter:
+                    # Demonstration filtering
+                    candidate = [support_idx for support_idx in support_indices
+                                   if support_idx != query_idx or mode != "train"]
+                    sim_score = []
+                    for support_idx in candidate:
+                        sim_score.append((support_idx, util.pytorch_cos_sim(self.support_emb[support_idx], self.query_emb[query_idx])))
+                    sim_score.sort(key=lambda x: x[1], reverse=True)
+                    if self.num_labels == 1:
+                        # Regression task
+                        limit_each_label = int(len(sim_score) // 2 * args.demo_filter_rate)
+                        count_each_label = {'0': 0, '1': 0}
+                        context_indices = []
+
+                        if args.debug_mode:
+                            print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
+                        for support_idx, score in sim_score:
+                            if count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] < limit_each_label:
+                                count_each_label['0' if float(self.support_examples[support_idx].label) <= median_mapping[args.task_name] else '1'] += 1
+                                context_indices.append(support_idx)
+                                if args.debug_mode:
+                                    print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
+                    else:
+                        limit_each_label = int(len(sim_score) // self.num_labels * args.demo_filter_rate)
+                        count_each_label = {label: 0 for label in self.label_list}
+                        context_indices = []
+
+                        if args.debug_mode:
+                            print("Query %s: %s" % (self.query_examples[query_idx].label, self.query_examples[query_idx].text_a)) # debug
+                        for support_idx, score in sim_score:
+                            if count_each_label[self.support_examples[support_idx].label] < limit_each_label:
+                                count_each_label[self.support_examples[support_idx].label] += 1
+                                context_indices.append(support_idx)
+                                if args.debug_mode:
+                                    print("    %.4f %s | %s" % (score, self.support_examples[support_idx].label, self.support_examples[support_idx].text_a)) # debug
+                else:
+                    # Using demonstrations without filtering
+                    context_indices = [support_idx for support_idx in support_indices
+                               if support_idx != query_idx or mode != "train"]
+
+                # We'll subsample context_indices further later.
+                self.example_idx.append((query_idx, context_indices, sample_idx))
+
+        # If it is not training, we pre-process the data; otherwise, we process the data online.
+        if mode != "train":
+            self.features = []
+            for example_num, (query_idx, context_indices, bootstrap_idx) in enumerate(self.example_idx):
+                # The input (query) example
+                example = self.query_examples[query_idx]
+                # The demonstrations
+                supports = self.select_context([self.support_examples[i] for i in context_indices])
+
+                if args.template_list is not None:
+                    template = args.template_list[bootstrap_idx % len(args.template_list)] # Use templates in order
+                else:
+                    template = args.template
+
+                self.features.append(self.convert_fn(
+                    example=example,
+                    supports=supports,
+                    use_demo=self.use_demo,
+                    label_list=self.label_list,
+                    prompt=args.prompt,
+                    template=template,
+                    label_word_list=self.label_word_list,
+                    verbose=(example_num == 0),
+                ))
+        else:
+            self.features = None
+
+    def select_context(self, context_examples):
+        """
+        Select demonstrations from provided examples.
+        """
+        max_demo_per_label = 1
+        counts = {k: 0 for k in self.label_list}
+        if len(self.label_list) == 1:
+            # Regression
+            counts = {'0': 0, '1': 0}
+        selection = []
+
+        if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
+            # For GPT-3's in-context learning, we sample gpt3_in_context_num demonstrations randomly. 
+            order = np.random.permutation(len(context_examples))
+            for i in range(min(self.args.gpt3_in_context_num, len(order))):
+                selection.append(context_examples[order[i]])
+        else:
+            # Our sampling strategy
+            order = np.random.permutation(len(context_examples))
+
+            for i in order:
+                label = context_examples[i].label
+                if len(self.label_list) == 1:
+                    # Regression
+                    label = '0' if float(label) <= median_mapping[self.args.task_name] else '1'
+                if counts[label] < max_demo_per_label:
+                    selection.append(context_examples[i])
+                    counts[label] += 1
+                if sum(counts.values()) == len(counts) * max_demo_per_label:
+                    break
+        
+            assert len(selection) > 0
+        
+        return selection
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, i):
+        if self.features is None:
+            query_idx, context_indices, bootstrap_idx = self.example_idx[i]
+            # The input (query) example
+            example = self.query_examples[query_idx]
+            # The demonstrations
+            supports = self.select_context([self.support_examples[i] for i in context_indices])
+
+            if self.args.template_list is not None:
+                template = self.args.template_list[bootstrap_idx % len(self.args.template_list)] # Use templates in order
+            else:
+                template = self.args.template
+
+            features = self.convert_fn(
+                example=example,
+                supports=supports,
+                use_demo=self.use_demo,
+                label_list=self.label_list,
+                prompt=self.args.prompt,
+                template=template,
+                label_word_list=self.label_word_list,
+                verbose=False,
+            )
+        else:
+            features = self.features[i]
+            
+        return features
+
+    def get_labels(self):
+        return self.label_list
+
+
+    def convert_fn(
+        self,
+        example,
+        supports,
+        use_demo=False,
+        label_list=None,
+        prompt=False,
+        template=None,
+        label_word_list=None,
+        verbose=False
+    ):
+        """
+        Returns a list of processed "InputFeatures".
+        """
+        max_length = self.args.max_seq_length    
+
+        # Prepare labels
+        label_map = {label: i for i, label in enumerate(label_list)} # Mapping the label names to label ids
+        if len(label_list) == 1:
+            # Regression
+            label_map = {'0': 0, '1': 1}
+
+        # Get example's label id (for training/inference)
+        if example.label is None:
+            example_label = None
+        elif len(label_list) == 1:
+            # Regression
+            example_label = float(example.label)
+        else:
+            example_label = label_map[example.label]
+
+        # Prepare other features
+        if not use_demo:
+            # No using demonstrations
+            inputs = tokenize_multipart_input(
+                input_text_list=input_example_to_tuple(example),
+                max_length=max_length,
+                tokenizer=self.tokenizer,
+                task_name=self.args.task_name,
+                prompt=prompt,
+                template=template,
+                label_word_list=label_word_list,
+                first_sent_limit=self.args.first_sent_limit,
+                other_sent_limit=self.args.other_sent_limit,
+            )
+            features = OurInputFeatures(**inputs, label=example_label)
+
+        else:
+            # Using demonstrations
+
+            # Max length
+            if self.args.double_demo:
+                # When using demonstrations, double the maximum length
+                # Note that in this case, args.max_seq_length is the maximum length for a single sentence
+                max_length = max_length * 2
+            if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
+                # When using GPT-3's in-context learning, take the maximum tokenization length of the model (512)
+                max_length = 512
+
+            # All input sentences, including the query and the demonstrations, are put into augmented_examples, 
+            # and are numbered based on the order (starting from 0). For single sentence tasks, the input (query)
+            # is the sentence 0; for sentence-pair tasks, the input (query) is the sentence 0 and 1. Note that for GPT-3's 
+            # in-context learning, the input (query) might be at the end instead of the beginning (gpt3_in_context_head)
+            augmented_example = []
+            query_text = input_example_to_tuple(example) # Input sentence list for query
+            support_by_label = [[] for i in range(len(label_map))]
+
+            if self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail:
+                support_labels = []
+                augmented_example = query_text
+                for support_example in supports:
+                    augmented_example += input_example_to_tuple(support_example)
+                    current_label = support_example.label
+                    if len(label_list) == 1:
+                        current_label = '0' if float(current_label) <= median_mapping[self.args.task_name] else '1' # Regression
+                    support_labels.append(label_map[current_label])
+            else:
+                # Group support examples by label
+                for label_name, label_id in label_map.items():
+                    if len(label_list) == 1:
+                        # Regression
+                        for support_example in filter(lambda s: ('0' if float(s.label) <= median_mapping[self.args.task_name] else '1') == label_name, supports):
+                            support_by_label[label_id] += input_example_to_tuple(support_example)
+                    else:
+                        for support_example in filter(lambda s: s.label == label_name, supports):
+                            support_by_label[label_id] += input_example_to_tuple(support_example)
+
+                augmented_example = query_text
+                for label_id in range(len(label_map)):
+                    augmented_example += support_by_label[label_id]
+
+            # Tokenization (based on the template)
+            inputs = tokenize_multipart_input(
+                input_text_list=augmented_example,
+                max_length=max_length,
+                tokenizer=self.tokenizer,
+                task_name=self.args.task_name,
+                prompt=prompt,
+                template=template,
+                label_word_list=label_word_list,
+                first_sent_limit=self.args.first_sent_limit,
+                other_sent_limit=self.args.other_sent_limit,
+                truncate_head=self.args.truncate_head,
+                gpt3=self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail,
+                support_labels=None if not (self.args.gpt3_in_context_head or self.args.gpt3_in_context_tail) else support_labels
+            )
+            features = OurInputFeatures(**inputs, label=example_label)
+
+        if verbose:
+            logger.info("*** Example ***")
+            logger.info("guid: %s" % (example.guid))
+            logger.info("features: %s" % features)
+            logger.info("text: %s" % self.tokenizer.decode(features.input_ids))
+
+        return features
+
+
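+# A minimal construction sketch (mirrors the call sites in run.py; data_args,
+# tokenizer, and model_args are assumed to be built there):
+#
+#   train_dataset = FewShotDataset(data_args, tokenizer=tokenizer, mode="train",
+#                                  use_demo=('demo' in model_args.few_shot_type))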
+
diff --git a/LM-BFF/src/label_search.py b/LM-BFF/src/label_search.py
new file mode 100755
index 0000000..e83abd3
--- /dev/null
+++ b/LM-BFF/src/label_search.py
@@ -0,0 +1,148 @@
+"""Automatic label search helpers."""
+
+import itertools
+import torch
+import tqdm
+import multiprocessing
+import numpy as np
+import scipy.spatial as spatial
+import scipy.special as special
+import scipy.stats as stats
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def select_likely_words(train_logits, train_labels, k_likely=1000, vocab=None, is_regression=False):
+    """Pre-select likely words based on conditional likelihood."""
+    indices = []
+    if is_regression:
+        median = np.median(train_labels)
+        train_labels = (train_labels > median).astype(int)
+    num_labels = np.max(train_labels) + 1
+    for idx in range(num_labels):
+        label_logits = train_logits[train_labels == idx]
+        scores = label_logits.mean(axis=0)
+        kept = []
+        for i in np.argsort(-scores):
+            text = vocab[i]
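+            # In the RoBERTa/GPT-2 BPE vocabulary, "Ġ" marks a token that
+            # begins with a space, i.e. a whole word; skip subword pieces.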
+            if not text.startswith("Ġ"):
+                continue
+            kept.append(i)
+        indices.append(kept[:k_likely])
+    return indices
+
+
+def select_neighbors(distances, k_neighbors, valid):
+    """Select k nearest neighbors based on distance (filtered to be within the 'valid' set)."""
+    indices = np.argsort(distances)
+    neighbors = []
+    for i in indices:
+        if i not in valid:
+            continue
+        neighbors.append(i)
+    if k_neighbors > 0:
+        return neighbors[:k_neighbors]
+    return neighbors
+
+
+def init(train_logits, train_labels):
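+    # Pool initializer: runs once per worker process so that eval_pairing_*
+    # can read the (large) logits array from module globals instead of having
+    # it pickled for every pairing.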
+    global logits, labels
+    logits = train_logits
+    labels = train_labels
+
+
+def eval_pairing_acc(pairing):
+    global logits, labels
+    label_logits = np.take(logits, pairing, axis=-1)
+    preds = np.argmax(label_logits, axis=-1)
+    correct = np.sum(preds == labels)
+    return correct / len(labels)
+
+
+def eval_pairing_corr(pairing):
+    global logits, labels
+    if pairing[0] == pairing[1]:
+        return -1
+    label_logits = np.take(logits, pairing, axis=-1)
+    label_probs = special.softmax(label_logits, axis=-1)[:, 1]
+    pearson_corr = stats.pearsonr(label_probs, labels)[0]
+    return pearson_corr
+
+
+def find_labels(
+    model,
+    train_logits,
+    train_labels,
+    seed_labels=None,
+    k_likely=1000,
+    k_neighbors=None,
+    top_n=-1,
+    vocab=None,
+    is_regression=False,
+):
+    # Get top indices based on conditional likelihood using the LM.
+    likely_indices = select_likely_words(
+        train_logits=train_logits,
+        train_labels=train_labels,
+        k_likely=k_likely,
+        vocab=vocab,
+        is_regression=is_regression)
+
+    logger.info("Top labels (conditional) per class:")
+    for i, inds in enumerate(likely_indices):
+        logger.info("\t| Label %d: %s", i, ", ".join([vocab[i] for i in inds[:10]]))
+
+    # Convert to sets.
+    valid_indices = [set(inds) for inds in likely_indices]
+
+    # If specified, further re-rank according to nearest neighbors of seed labels.
+    # Otherwise, keep ranking as is (based on conditional likelihood only).
+    if seed_labels:
+        assert vocab is not None
+        seed_ids = [vocab.index(l) for l in seed_labels]
+        vocab_vecs = model.lm_head.decoder.weight.detach().cpu().numpy()
+        seed_vecs = np.take(vocab_vecs, seed_ids, axis=0)
+
+        # [num_labels, vocab_size]
+        label_distances = spatial.distance.cdist(seed_vecs, vocab_vecs, metric="cosine")
+
+        # Establish label candidates (as k nearest neighbors).
+        label_candidates = []
+        logger.info("Re-ranked by nearest neighbors:")
+        for i, distances in enumerate(label_distances):
+            label_candidates.append(select_neighbors(distances, k_neighbors, valid_indices[i]))
+            logger.info("\t| Label: %s", seed_labels[i])
+            logger.info("\t| Neighbors: %s", " ".join([vocab[idx] for idx in label_candidates[i]]))
+    else:
+        label_candidates = likely_indices
+
+    # Brute-force search all valid pairings.
+    pairings = list(itertools.product(*label_candidates))
+
+    if is_regression:
+        eval_pairing = eval_pairing_corr
+        metric = "corr"
+    else:
+        eval_pairing = eval_pairing_acc
+        metric = "acc"
+
+    # Score each pairing.
+    pairing_scores = []
+    with multiprocessing.Pool(initializer=init, initargs=(train_logits, train_labels)) as workers:
+        with tqdm.tqdm(total=len(pairings)) as pbar:
+            chunksize = max(10, int(len(pairings) / 1000))
+            for score in workers.imap(eval_pairing, pairings, chunksize=chunksize):
+                pairing_scores.append(score)
+                pbar.update()
+
+    # Take top-n.
+    best_idx = np.argsort(-np.array(pairing_scores))
+    if top_n > 0:
+        best_idx = best_idx[:top_n]
+    best_scores = [pairing_scores[i] for i in best_idx]
+    best_pairings = [pairings[i] for i in best_idx]
+
+    logger.info("Automatically searched pairings:")
+    for i, indices in enumerate(best_pairings):
+        logger.info("\t| %s (%s = %2.2f)", " ".join([vocab[j] for j in indices]), metric, best_scores[i])
+
+    return best_pairings
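+
+# A minimal usage sketch (hypothetical call; train_logits/train_labels come from
+# forward passes over the few-shot training set and vocab from the tokenizer):
+#
+#   best = find_labels(model, train_logits, train_labels,
+#                      seed_labels=['great', 'terrible'],
+#                      k_likely=1000, k_neighbors=100, top_n=10,
+#                      vocab=tokenizer.convert_ids_to_tokens(list(range(len(tokenizer)))))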
diff --git a/LM-BFF/src/models.py b/LM-BFF/src/models.py
new file mode 100755
index 0000000..bfdfcc3
--- /dev/null
+++ b/LM-BFF/src/models.py
@@ -0,0 +1,195 @@
+"""Custom models for few-shot learning specific operations."""
+
+import torch
+import torch.nn as nn
+import transformers
+from transformers.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, BertOnlyMLMHead
+from transformers.modeling_roberta import RobertaForSequenceClassification, RobertaModel, RobertaLMHead, RobertaClassificationHead
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+import logging
+logger = logging.getLogger(__name__)
+
+def resize_token_type_embeddings(model, new_num_types: int, random_segment: bool):
+    """
+    Resize the segment (token type) embeddings for BERT
+    """
+    if hasattr(model, 'bert'):
+        old_token_type_embeddings = model.bert.embeddings.token_type_embeddings
+    else:
+        raise NotImplementedError
+    new_token_type_embeddings = nn.Embedding(new_num_types, old_token_type_embeddings.weight.size(1))
+    if not random_segment:
+        new_token_type_embeddings.weight.data[:old_token_type_embeddings.weight.size(0)] = old_token_type_embeddings.weight.data
+
+    model.config.type_vocab_size = new_num_types
+    if hasattr(model, 'bert'):
+        model.bert.embeddings.token_type_embeddings = new_token_type_embeddings
+    else:
+        raise NotImplementedError
+
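+# Usage sketch (hypothetical call site; BERT checkpoints ship with only two
+# segment types, so templates that bump the segment id via *sep+* several times
+# need a larger type vocabulary):
+#
+#   resize_token_type_embeddings(model, new_num_types=10, random_segment=False)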
+
+class BertForPromptFinetuning(BertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.bert = BertModel(config)
+        self.cls = BertOnlyMLMHead(config)
+        self.init_weights()
+
+        # These attributes should be assigned once the model is initialized
+        self.model_args = None
+        self.data_args = None
+        self.label_word_list = None
+
+        # For regression
+        self.lb = None
+        self.ub = None
+
+        # For label search.
+        self.return_full_softmax = None
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        mask_pos=None,
+        labels=None,
+    ):
+        batch_size = input_ids.size(0)
+
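+        # mask_pos arrives as shape [batch_size, 1]; squeeze it to [batch_size]
+        # so it can index sequence_output below.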
+        if mask_pos is not None:
+            mask_pos = mask_pos.squeeze()
+
+        # Encode everything
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids
+        )
+
+        # Get <mask> token representation
+        sequence_output, pooled_output = outputs[:2]
+        sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
+
+        # Logits over vocabulary tokens
+        prediction_mask_scores = self.cls(sequence_mask_output)
+
+        # Exit early and only return mask logits.
+        if self.return_full_softmax:
+            if labels is not None:
+                return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
+            return prediction_mask_scores
+
+        # Return logits for each label
+        logits = []
+        for label_id in range(len(self.label_word_list)):
+            logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
+        logits = torch.cat(logits, -1)
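+        # logits: [batch_size, num_label_words], one column per label word.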
+
+        # Regression task
+        if self.config.num_labels == 1:
+            logsoftmax = nn.LogSoftmax(-1)
+            logits = logsoftmax(logits) # Log prob of right polarity
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # Regression task
+                loss_fct = nn.KLDivLoss(log_target=True)
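+                # Map the scalar target in [lb, ub] to a two-way distribution
+                # (P(low), P(high)) and compare it with the log-softmax over
+                # the two polarity label words.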
+                labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
+                loss = loss_fct(logits.view(-1, 2), labels)
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+        output = (logits,)
+        if self.num_labels == 1:
+            # Regression output
+            output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
+        return ((loss,) + output) if loss is not None else output
+
+
+
+class RobertaForPromptFinetuning(BertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.roberta = RobertaModel(config)
+        self.classifier = RobertaClassificationHead(config)
+        self.lm_head = RobertaLMHead(config)
+        self.init_weights()
+
+        # These attributes should be assigned once the model is initialized
+        self.model_args = None
+        self.data_args = None
+        self.label_word_list = None
+
+        # For regression
+        self.lb = None
+        self.ub = None
+
+        # For auto label search.
+        self.return_full_softmax = None
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        mask_pos=None,
+        labels=None,
+    ):
+        batch_size = input_ids.size(0)
+
+        if mask_pos is not None:
+            mask_pos = mask_pos.squeeze()
+
+        # Encode everything
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask
+        )
+
+        # Get <mask> token representation
+        sequence_output, pooled_output = outputs[:2]
+        sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]
+
+        # Logits over vocabulary tokens
+        prediction_mask_scores = self.lm_head(sequence_mask_output)
+
+        # Exit early and only return mask logits.
+        if self.return_full_softmax:
+            if labels is not None:
+                return torch.zeros(1, out=prediction_mask_scores.new()), prediction_mask_scores
+            return prediction_mask_scores
+
+        # Return logits for each label
+        logits = []
+        for label_id in range(len(self.label_word_list)):
+            logits.append(prediction_mask_scores[:, self.label_word_list[label_id]].unsqueeze(-1))
+        logits = torch.cat(logits, -1)
+
+        # Regression task
+        if self.config.num_labels == 1:
+            logsoftmax = nn.LogSoftmax(-1)
+            logits = logsoftmax(logits) # Log prob of right polarity
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # Regression task
+                loss_fct = nn.KLDivLoss(log_target=True)
+                labels = torch.stack([1 - (labels.view(-1) - self.lb) / (self.ub - self.lb), (labels.view(-1) - self.lb) / (self.ub - self.lb)], -1)
+                loss = loss_fct(logits.view(-1, 2), labels)
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+        output = (logits,)
+        if self.num_labels == 1:
+            # Regression output
+            output = (torch.exp(logits[..., 1].unsqueeze(-1)) * (self.ub - self.lb) + self.lb,)
+        return ((loss,) + output) if loss is not None else output
diff --git a/LM-BFF/src/processors.py b/LM-BFF/src/processors.py
new file mode 100755
index 0000000..b0da9a5
--- /dev/null
+++ b/LM-BFF/src/processors.py
@@ -0,0 +1,626 @@
+"""Dataset utils for different data settings for GLUE."""
+
+import os
+import copy
+import logging
+import torch
+import numpy as np
+import time
+from filelock import FileLock
+import json
+import itertools
+import random
+import transformers
+from transformers.data.processors.utils import InputFeatures
+from transformers import DataProcessor, InputExample
+from transformers.data.processors.glue import *
+from transformers.data.metrics import glue_compute_metrics
+import dataclasses
+from dataclasses import dataclass, asdict
+from typing import List, Optional, Union
+from sentence_transformers import SentenceTransformer, util
+from copy import deepcopy
+import pandas as pd
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MrpcProcessor(DataProcessor):
+    """Processor for the MRPC data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence1"].numpy().decode("utf-8"),
+            tensor_dict["sentence2"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, i)
+            text_a = line[3]
+            text_b = line[4]
+            label = line[0]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class MnliProcessor(DataProcessor):
+    """Processor for the MultiNLI data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["premise"].numpy().decode("utf-8"),
+            tensor_dict["hypothesis"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
+
+    def get_labels(self):
+        """See base class."""
+        return ["contradiction", "entailment", "neutral"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[8]
+            text_b = line[9]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class MnliMismatchedProcessor(MnliProcessor):
+    """Processor for the MultiNLI Mismatched data set (GLUE version)."""
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
+
+
+class SnliProcessor(DataProcessor):
+    """Processor for the MultiNLI data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["premise"].numpy().decode("utf-8"),
+            tensor_dict["hypothesis"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["contradiction", "entailment", "neutral"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[7]
+            text_b = line[8]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class ColaProcessor(DataProcessor):
+    """Processor for the CoLA data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence"].numpy().decode("utf-8"),
+            None,
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        text_index = 3
+        examples = []
+        for (i, line) in enumerate(lines):
+            guid = "%s-%s" % (set_type, i)
+            text_a = line[text_index]
+            label = line[1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+        return examples
+
+
+class Sst2Processor(DataProcessor):
+    """Processor for the SST-2 data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence"].numpy().decode("utf-8"),
+            None,
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        text_index = 0
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, i)
+            text_a = line[text_index]
+            label = line[1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+        return examples
+
+
+class StsbProcessor(DataProcessor):
+    """Processor for the STS-B data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence1"].numpy().decode("utf-8"),
+            tensor_dict["sentence2"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return [None]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[7]
+            text_b = line[8]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class QqpProcessor(DataProcessor):
+    """Processor for the QQP data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["question1"].numpy().decode("utf-8"),
+            tensor_dict["question2"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        q1_index = 3
+        q2_index = 4
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            try:
+                text_a = line[q1_index]
+                text_b = line[q2_index]
+                label = line[5]
+            except IndexError:
+                continue
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class QnliProcessor(DataProcessor):
+    """Processor for the QNLI data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["question"].numpy().decode("utf-8"),
+            tensor_dict["sentence"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["entailment", "not_entailment"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[1]
+            text_b = line[2]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class RteProcessor(DataProcessor):
+    """Processor for the RTE data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence1"].numpy().decode("utf-8"),
+            tensor_dict["sentence2"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["entailment", "not_entailment"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[1]
+            text_b = line[2]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class WnliProcessor(DataProcessor):
+    """Processor for the WNLI data set (GLUE version)."""
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence1"].numpy().decode("utf-8"),
+            tensor_dict["sentence2"].numpy().decode("utf-8"),
+            str(tensor_dict["label"].numpy()),
+        )
+
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+    def get_labels(self):
+        """See base class."""
+        return ["0", "1"]
+
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            guid = "%s-%s" % (set_type, line[0])
+            text_a = line[1]
+            text_b = line[2]
+            label = line[-1]
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+        return examples
+
+
+class TextClassificationProcessor(DataProcessor):
+    """
+    Data processor for text classification datasets (mr, sst-5, subj, trec, cr, mpqa).
+    """
+
+    def __init__(self, task_name):
+        self.task_name = task_name 
+
+    def get_example_from_tensor_dict(self, tensor_dict):
+        """See base class."""
+        return InputExample(
+            tensor_dict["idx"].numpy(),
+            tensor_dict["sentence"].numpy().decode("utf-8"),
+            None,
+            str(tensor_dict["label"].numpy()),
+        )
+  
+    def get_train_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(pd.read_csv(os.path.join(data_dir, "train.csv"), header=None).values.tolist(), "train")
+
+    def get_dev_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(pd.read_csv(os.path.join(data_dir, "dev.csv"), header=None).values.tolist(), "dev")
+
+    def get_test_examples(self, data_dir):
+        """See base class."""
+        return self._create_examples(pd.read_csv(os.path.join(data_dir, "test.csv"), header=None).values.tolist(), "test")
+
+    def get_labels(self):
+        """See base class."""
+        if self.task_name == "mr":
+            return list(range(2))
+        elif self.task_name == "sst-5":
+            return list(range(5))
+        elif self.task_name == "subj":
+            return list(range(2))
+        elif self.task_name == "trec":
+            return list(range(6))
+        elif self.task_name == "cr":
+            return list(range(2))
+        elif self.task_name == "mpqa":
+            return list(range(2))
+        else:
+            raise Exception("task_name not supported.")
+        
+    def _create_examples(self, lines, set_type):
+        """Creates examples for the training, dev and test sets."""
+        examples = []
+        for (i, line) in enumerate(lines):
+            guid = "%s-%s" % (set_type, i)
+            if self.task_name == "ag_news":
+                examples.append(InputExample(guid=guid, text_a=line[1] + '. ' + line[2], short_text=line[1] + ".", label=line[0]))
+            elif self.task_name == "yelp_review_full":
+                examples.append(InputExample(guid=guid, text_a=line[1], short_text=line[1], label=line[0]))
+            elif self.task_name == "yahoo_answers":
+                text = line[1]
+                if not pd.isna(line[2]):
+                    text += ' ' + line[2]
+                if not pd.isna(line[3]):
+                    text += ' ' + line[3]
+                examples.append(InputExample(guid=guid, text_a=text, short_text=line[1], label=line[0])) 
+            elif self.task_name in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
+                examples.append(InputExample(guid=guid, text_a=line[1], label=line[0]))
+            else:
+                raise Exception("Task_name not supported.")
+
+        return examples
+
+
+def text_classification_metrics(task_name, preds, labels):
+    return {"acc": (preds == labels).mean()}
+
+# Add your task to the following mappings
+
+processors_mapping = {
+    "cola": ColaProcessor(),
+    "mnli": MnliProcessor(),
+    "mnli-mm": MnliMismatchedProcessor(),
+    "mrpc": MrpcProcessor(),
+    "sst-2": Sst2Processor(),
+    "sts-b": StsbProcessor(),
+    "qqp": QqpProcessor(),
+    "qnli": QnliProcessor(),
+    "rte": RteProcessor(),
+    "wnli": WnliProcessor(),
+    "snli": SnliProcessor(),
+    "mr": TextClassificationProcessor("mr"),
+    "sst-5": TextClassificationProcessor("sst-5"),
+    "subj": TextClassificationProcessor("subj"),
+    "trec": TextClassificationProcessor("trec"),
+    "cr": TextClassificationProcessor("cr"),
+    "mpqa": TextClassificationProcessor("mpqa")
+}
+
+num_labels_mapping = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
+    "snli": 3,
+    "mr": 2,
+    "sst-5": 5,
+    "subj": 2,
+    "trec": 6,
+    "cr": 2,
+    "mpqa": 2
+}
+
+output_modes_mapping = {
+    "cola": "classification",
+    "mnli": "classification",
+    "mnli-mm": "classification",
+    "mrpc": "classification",
+    "sst-2": "classification",
+    "sts-b": "regression",
+    "qqp": "classification",
+    "qnli": "classification",
+    "rte": "classification",
+    "wnli": "classification",
+    "snli": "classification",
+    "mr": "classification",
+    "sst-5": "classification",
+    "subj": "classification",
+    "trec": "classification",
+    "cr": "classification",
+    "mpqa": "classification"
+}
+
+# Each entry is a function that takes (task_name, preds, labels) and returns a dict of metrics
+compute_metrics_mapping = {
+    "cola": glue_compute_metrics,
+    "mnli": glue_compute_metrics,
+    "mnli-mm": glue_compute_metrics,
+    "mrpc": glue_compute_metrics,
+    "sst-2": glue_compute_metrics,
+    "sts-b": glue_compute_metrics,
+    "qqp": glue_compute_metrics,
+    "qnli": glue_compute_metrics,
+    "rte": glue_compute_metrics,
+    "wnli": glue_compute_metrics,
+    "snli": text_classification_metrics,
+    "mr": text_classification_metrics,
+    "sst-5": text_classification_metrics,
+    "subj": text_classification_metrics,
+    "trec": text_classification_metrics,
+    "cr": text_classification_metrics,
+    "mpqa": text_classification_metrics,
+}
+
+# For regression tasks only: median of the label range
+median_mapping = {
+    "sts-b": 2.5
+}
+
+bound_mapping = {
+    "sts-b": (0, 5)
+}
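+
+# A minimal sketch of how these mappings are typically consumed together
+# (the task name and data directory below are hypothetical):
+#
+#   task = "sst-2"
+#   processor = processors_mapping[task]
+#   examples = processor.get_train_examples("data/k-shot/SST-2/16-42")
+#   metrics_fn = compute_metrics_mapping[task]
+#   metrics = metrics_fn(task, preds, labels)  # given model preds and gold labels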
diff --git a/LM-BFF/src/trainer.py b/LM-BFF/src/trainer.py
new file mode 100755
index 0000000..b22a527
--- /dev/null
+++ b/LM-BFF/src/trainer.py
@@ -0,0 +1,473 @@
+########## The following part is copied from Transformers' trainer (3.4.0) ########## 
+
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
+"""
+
+import collections
+import inspect
+import math
+import os
+import re
+import shutil
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from packaging import version
+from torch import nn
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import RandomSampler, SequentialSampler
+
+import transformers
+from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from transformers.file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
+from transformers.integrations import (
+    default_hp_search_backend,
+    is_comet_available,
+    is_optuna_available,
+    is_ray_available,
+    is_tensorboard_available,
+    is_wandb_available,
+    run_hp_search_optuna,
+    run_hp_search_ray,
+)
+from transformers.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
+from transformers.modeling_utils import PreTrainedModel
+from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_callback import (
+    CallbackHandler,
+    DefaultFlowCallback,
+    PrinterCallback,
+    ProgressCallback,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+)
+from transformers.trainer_pt_utils import (
+    DistributedTensorGatherer,
+    SequentialDistributedSampler,
+    distributed_broadcast_scalars,
+    distributed_concat,
+    get_tpu_sampler,
+    nested_concat,
+    nested_detach,
+    nested_numpify,
+    nested_xla_mesh_reduce,
+    reissue_pt_warnings,
+)
+from transformers.trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    BestRun,
+    EvalPrediction,
+    HPSearchBackend,
+    PredictionOutput,
+    TrainOutput,
+    default_compute_objective,
+    default_hp_space,
+    set_seed,
+)
+from transformers.training_args import TrainingArguments
+from transformers.utils import logging
+
+from tqdm import tqdm, trange
+
+_use_native_amp = False
+_use_apex = False
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+    from transformers.utils.notebook import NotebookProgressCallback
+
+    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
+if version.parse(torch.__version__) < version.parse("1.6"):
+    from transformers.file_utils import is_apex_available
+
+    if is_apex_available():
+        from apex import amp
+    _use_apex = True
+else:
+    _use_native_amp = True
+    from torch.cuda.amp import autocast
+
+if version.parse(torch.__version__) < version.parse("1.2"):
+    _use_ddp_no_sync = False
+else:
+    _use_ddp_no_sync = True
+
+if is_datasets_available():
+    import datasets
+
+if is_torch_tpu_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    import torch_xla.distributed.parallel_loader as pl
+
+if is_tensorboard_available():
+    from transformers.integrations import TensorBoardCallback
+
+    DEFAULT_CALLBACKS.append(TensorBoardCallback)
+
+
+if is_wandb_available():
+    from transformers.integrations import WandbCallback
+
+    DEFAULT_CALLBACKS.append(WandbCallback)
+
+if is_comet_available():
+    from transformers.integrations import CometCallback
+
+    DEFAULT_CALLBACKS.append(CometCallback)
+
+if is_optuna_available():
+    import optuna
+
+if is_ray_available():
+    from ray import tune
+
+logger = logging.get_logger(__name__)
+
+########## The above part is copied from Transformers' trainer (3.4.0) ########## 
+
+def default_dev_objective(metrics):
+    """
+    Default objective for picking the best checkpoint on the development set.
+    """
+    if "eval_mnli/acc" in metrics:
+        return metrics["eval_mnli/acc"]
+    elif "eval_mnli-mm/acc" in metrics:
+        return metrics["eval_mnli-mm/acc"]
+    elif "eval_f1" in metrics:
+        return metrics["eval_f1"]
+    elif "eval_mcc" in metrics:
+        return metrics["eval_mcc"]
+    elif "eval_pearson" in metrics:
+        return metrics["eval_pearson"]
+    elif "eval_acc" in metrics:
+        return metrics["eval_acc"]
+ 
+    raise Exception("No metric founded for {}".format(metrics))
+
+class Trainer(transformers.Trainer):
+    """
+    Extends Transformers' Trainer with a layer-fixing option and best-checkpoint
+    selection based on a development-set objective.
+    """
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """
+        Based on Transformers' default implementation, with an added layer-fixing
+        option: the parameters of the bottom `fix_layers` layers are frozen and
+        only the top layers are further fine-tuned.
+        """
+        if self.optimizer is None:
+            params = {}
+            for n, p in self.model.named_parameters():
+                if self.args.fix_layers > 0:
+                    if 'encoder.layer' in n:
+                        try:
+                            # Parse the layer index from names like
+                            # "...encoder.layer.11.attention.self.query.weight".
+                            layer_num = int(n[n.find('encoder.layer') + 14:].split('.')[0])
+                        except ValueError:
+                            raise Exception("Cannot parse the layer number from parameter name: {}".format(n))
+                        if layer_num >= self.args.fix_layers:
+                            print('yes', n)  # top layer: will be fine-tuned
+                            params[n] = p
+                        else:
+                            print('no ', n)  # bottom layer: frozen
+                    elif 'embeddings' in n:
+                        print('no ', n)  # embeddings are frozen when fixing layers
+                    else:
+                        print('yes', n)  # non-encoder parameters are fine-tuned
+                        params[n] = p
+                else:
+                    params[n] = p
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in params.items() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in params.items() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
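+            # Illustrative note (parameter names are the usual BERT/RoBERTa ones):
+            # "encoder.layer.23.attention.self.query.weight" falls into the first
+            # group and receives weight decay, while "...query.bias" and any
+            # "LayerNorm.weight" parameter fall into the no-decay group.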
+            self.optimizer = AdamW(
+                optimizer_grouped_parameters,
+                lr=self.args.learning_rate,
+                betas=(self.args.adam_beta1, self.args.adam_beta2),
+                eps=self.args.adam_epsilon,
+            )
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
+            )
+
+    def train(self, model_path=None, dev_objective=None):
+        """
+        Main training entry point.
+
+        The training logic is directly borrowed from transformers.Trainer (version 3.0.2).
+        Adds best-checkpoint selection on a development objective (a simple form of
+        early stopping) and returns a (TrainOutput, best_objective) tuple.
+        """
+        self.best_dir = None
+        self.objective = -float("inf")
+        self.dev_objective = dev_objective if dev_objective is not None else default_dev_objective
+
+        # Data loading.
+        train_dataloader = self.get_train_dataloader()
+        num_update_steps_per_epoch = max(len(train_dataloader) // self.args.gradient_accumulation_steps, 1)
+        if self.args.max_steps > 0:
+            t_total = self.args.max_steps
+            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
+                self.args.max_steps % num_update_steps_per_epoch > 0
+            )
+        else:
+            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
+            num_train_epochs = self.args.num_train_epochs
+
+        self.create_optimizer_and_scheduler(num_training_steps=t_total)
+        optimizer = self.optimizer
+        scheduler = self.lr_scheduler
+
+        # Check if saved optimizer or scheduler states exist
+        if (
+            model_path is not None
+            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
+            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
+        ):
+            # Load in optimizer and scheduler states
+            optimizer.load_state_dict(
+                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
+            )
+            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+
+        model = self.model
+
+        if self.args.fp16 and _use_apex:
+            if not transformers.is_apex_available():
+                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization)
+        if self.args.n_gpu > 1:
+            model = torch.nn.DataParallel(model)
+
+        # Distributed training (should be after apex fp16 initialization)
+        if self.args.local_rank != -1:
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.args.local_rank],
+                output_device=self.args.local_rank,
+                find_unused_parameters=True,
+            )
+
+        # Train
+        if transformers.is_torch_tpu_available():
+            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
+        else:
+            total_train_batch_size = (
+                self.args.train_batch_size
+                * self.args.gradient_accumulation_steps
+                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
+            )
+        logger.info("***** Running training *****")
+        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
+        logger.info("  Num Epochs = %d", num_train_epochs)
+        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
+        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
+        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
+        logger.info("  Total optimization steps = %d", t_total)
+
+        self.global_step = 0
+        self.epoch = 0
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+        # Check if continuing training from a checkpoint
+        if model_path is not None:
+            # set global_step to global_step of last saved checkpoint from model path
+            try:
+                self.global_step = int(model_path.split("-")[-1].split("/")[0])
+                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
+                steps_trained_in_current_epoch = self.global_step % (
+                    len(train_dataloader) // self.args.gradient_accumulation_steps
+                )
+
+                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+                logger.info("  Continuing training from epoch %d", epochs_trained)
+                logger.info("  Continuing training from global step %d", self.global_step)
+                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+            except ValueError:
+                self.global_step = 0
+                logger.info("  Starting fine-tuning.")
+
+        tr_loss = torch.tensor(0.0).to(self.args.device)
+        logging_loss_scalar = 0.0
+        model.zero_grad()
+        train_iterator = trange(
+            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master()
+        )
+        for epoch in train_iterator:
+            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
+                train_dataloader.sampler.set_epoch(epoch)
+
+            if transformers.is_torch_tpu_available():
+                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
+                    self.args.device
+                )
+                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master())
+            else:
+                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
+
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if self.args.past_index >= 0:
+                self._past = None
+
+            for step, inputs in enumerate(epoch_iterator):
+
+                # Skip past any already trained steps if resuming training
+                if steps_trained_in_current_epoch > 0:
+                    steps_trained_in_current_epoch -= 1
+                    continue
+
+                tr_loss += self.training_step(model, inputs)
+
+                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
+                    # last step in epoch but step is always smaller than gradient_accumulation_steps
+                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
+                    and (step + 1) == len(epoch_iterator)
+                ):
+                    if self.args.fp16 and _use_native_amp:
+                        self.scaler.unscale_(optimizer)
+                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+                    elif self.args.fp16:
+                        norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
+                    else:
+                        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+
+                    if transformers.is_torch_tpu_available():
+                        xm.optimizer_step(optimizer)
+                    elif self.args.fp16 and _use_native_amp:
+                        self.scaler.step(optimizer)
+                        self.scaler.update()
+                    else:
+                        optimizer.step()
+
+                    scheduler.step()
+                    model.zero_grad()
+                    self.global_step += 1
+                    self.epoch = epoch + (step + 1) / len(epoch_iterator)
+
+                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
+                        self.global_step == 1 and self.args.logging_first_step
+                    ):
+                        logs = {}
+                        tr_loss_scalar = tr_loss.item()
+                        logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
+                        logs["norm"] = norm.item()
+                        # backward compatibility for pytorch schedulers
+                        logs["learning_rate"] = (
+                            scheduler.get_last_lr()[0]
+                            if version.parse(torch.__version__) >= version.parse("1.4")
+                            else scheduler.get_lr()[0]
+                        )
+                        logging_loss_scalar = tr_loss_scalar
+
+                        self.log(logs)
+
+                    # ----------------------------------------------------------------------
+                    # BEGIN CHANGES.
+                    # ----------------------------------------------------------------------
+
+                    metrics = None
+                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
+                        output = self.evaluate()
+                        metrics = output.metrics
+                        objective = self.dev_objective(metrics)
+                        if objective > self.objective:
+                            logger.info("Best dev result: {}".format(objective))
+                            self.objective = objective
+                            self.save_model(self.args.output_dir) 
+
+                    # ----------------------------------------------------------------------
+                    # END CHANGES.
+                    # ----------------------------------------------------------------------
+
+
+                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+                    epoch_iterator.close()
+                    break
+            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
+                train_iterator.close()
+                break
+            if self.args.tpu_metrics_debug or self.args.debug:
+                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+                xm.master_print(met.metrics_report())
+
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
+        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+        return TrainOutput(self.global_step, tr_loss / self.global_step), self.objective
+
+
+    """
+    Difference compared to original implementation: return output instead of output.metrics (so there is also the logits)
+    """
+    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
+        """
+        Run evaluation and returns metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are
+        task-dependent (pass it to the init :obj:`compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (:obj:`Dataset`, `optional`):
+                Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
+                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
+                the :obj:`__len__` method.
+
+        Returns:
+            A :obj:`PredictionOutput` containing the predictions (logits), the label ids,
+            and the metrics computed from the predictions (including the evaluation loss).
+        """
+        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+
+        output = self.prediction_loop(eval_dataloader, description="Evaluation")
+
+        self.log(output.metrics)
+
+        if self.args.tpu_metrics_debug or self.args.debug:
+            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+            xm.master_print(met.metrics_report())
+
+        return output
diff --git a/LM-BFF/tools/ensemble.py b/LM-BFF/tools/ensemble.py
new file mode 100755
index 0000000..8f914e1
--- /dev/null
+++ b/LM-BFF/tools/ensemble.py
@@ -0,0 +1,289 @@
+import argparse
+import pandas as pd
+import json
+import numpy as np
+import torch
+import os
+from torch import device
+from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
+from transformers import GlueDataTrainingArguments, glue_compute_metrics
+from transformers.data.metrics import simple_accuracy
+from transformers.data.processors.glue import glue_processors
+
+def get_glue_label(task, line):
+    if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+        line = line.strip().split('\t')
+        if task == 'CoLA':
+            return line[1]
+        elif task == 'MNLI':
+            return line[-1]
+        elif task == 'MRPC':
+            return line[0]
+        elif task == 'QNLI':
+            return line[-1]
+        elif task == 'QQP':
+            return line[-1]
+        elif task == 'RTE':
+            return line[-1]
+        elif task == 'SNLI':
+            return line[-1]
+        elif task == 'SST-2':
+            return line[-1]
+        elif task == 'STS-B':
+            return 0 if float(line[-1]) < 2.5 else 1
+        elif task == 'WNLI':
+            return line[-1]
+        else:
+            raise NotImplementedError
+    else:
+        raise NotImplementedError
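+
+# Illustrative examples (the tab-separated line contents are hypothetical):
+#   get_glue_label("SST-2", "a charming film\t1\n") -> "1"
+#   get_glue_label("STS-B", "id\tsent1\tsent2\t4.2\n") -> 1  (score thresholded at 2.5)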
+
+def get_labels(data_dir, k, seed, task, print_name):
+    if print_name in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']:
+        data = pd.read_csv(os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.csv'), header=None).values.tolist()
+        label_ids = np.zeros((len(data)), dtype=np.uint8)
+        for i, example in enumerate(data):
+            label_ids[i] = example[0]
+    elif print_name in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+        lines = []
+        file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test.tsv')
+        if task == 'mnli':
+            file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_matched.tsv')
+        elif task == 'mnli-mm':
+            file_name = os.path.join(data_dir, print_name, '{}-{}'.format(k, seed), 'test_mismatched.tsv')
+
+        for line in open(file_name):
+            lines.append(line.strip())
+
+        if task != 'cola':
+            lines = lines[1:]
+        label_list = glue_processors[task]().get_labels()
+        label_map = {k: i for i, k in enumerate(label_list)}
+        if task == 'sts-b':
+            label_map = {0: 0, 1: 1}
+        label_ids = np.zeros((len(lines)))
+        for line_id, line in enumerate(lines):
+            label_ids[line_id] = label_map[get_glue_label(print_name, line)]
+    return label_ids
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--n_models", type=int, help="Number of models")
+    parser.add_argument("--k", type=int, default=16, help="Number of training instances per label")
+    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
+    
+    # These options should usually be kept as their default values
+    parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Data directory")
+    parser.add_argument("--save_logit_dir", type=str, default="ensemble_predict_results", help="Directory to store the logit file.")
+    parser.add_argument("--log", type=str, default="log", help="Log path.")
+    parser.add_argument("--key", type=str, default='', help="Validation metric name")
+    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
+    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
+
+    args = parser.parse_args()
+
+    # The condition is parsed as a Python dict literal, e.g. "{'tag': 'exp', 'task_name': 'sst-2'}"
+    condition = eval(args.condition)
+
+    if len(args.key) == 0:
+        if condition['task_name'] == 'cola':
+            args.key = 'cola_dev_eval_mcc'
+            args.test_key = 'cola_test_eval_mcc'
+        elif condition['task_name'] == 'mrpc/acc':
+            args.key = 'mrpc_dev_eval_acc'
+            args.test_key = 'mrpc_test_eval_acc'
+            args.test_key2 = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+        elif condition['task_name'] == 'mrpc/f1':
+            args.key = 'mrpc_dev_eval_f1'
+            args.test_key2 = 'mrpc_test_eval_acc'
+            args.test_key = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+        elif condition['task_name'] == 'qqp/acc':
+            args.key = 'qqp_dev_eval_acc'
+            args.test_key = 'qqp_test_eval_acc'
+            args.test_key2 = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+        elif condition['task_name'] == 'qqp/f1':
+            args.key = 'qqp_dev_eval_f1'
+            args.test_key2 = 'qqp_test_eval_acc'
+            args.test_key = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+        elif condition['task_name'] == 'sts-b/pearson':
+            args.key = 'sts-b_dev_eval_pearson'
+            args.test_key = 'sts-b_test_eval_pearson'
+            args.test_key2 = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+        elif condition['task_name'] == 'sts-b/spearmanr':
+            args.key = 'sts-b_dev_eval_spearmanr'
+            args.test_key2 = 'sts-b_test_eval_pearson'
+            args.test_key = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+        elif condition['task_name'] == 'qnli':
+            args.key = 'qnli_dev_eval_acc'
+            args.test_key = 'qnli_test_eval_acc'
+        elif condition['task_name'] == 'sst-2':
+            args.key = 'sst-2_dev_eval_acc'
+            args.test_key = 'sst-2_test_eval_acc'
+        elif condition['task_name'] == 'snli':
+            args.key = 'snli_dev_eval_acc'
+            args.test_key = 'snli_test_eval_acc'
+        elif condition['task_name'] == 'mnli':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli_test_eval_mnli/acc'
+        elif condition['task_name'] == 'mnli-mm':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
+        elif condition['task_name'] == 'rte':
+            args.key = 'rte_dev_eval_acc'
+            args.test_key = 'rte_test_eval_acc'
+        elif condition['task_name'] == 'ag_news':
+            args.key = 'ag_news_dev_eval_acc'
+            args.test_key = 'ag_news_test_eval_acc'
+        elif condition['task_name'] == 'yahoo_answers':
+            args.key = 'yahoo_answers_dev_eval_acc'
+            args.test_key = 'yahoo_answers_test_eval_acc'
+        elif condition['task_name'] == 'yelp_review_full':
+            args.key = 'yelp_review_full_dev_eval_acc'
+            args.test_key = 'yelp_review_full_test_eval_acc'
+        elif condition['task_name'] == 'mr':
+            args.key = 'mr_dev_eval_acc'
+            args.test_key = 'mr_test_eval_acc'
+        elif condition['task_name'] == 'sst-5':
+            args.key = 'sst-5_dev_eval_acc'
+            args.test_key = 'sst-5_test_eval_acc'
+        elif condition['task_name'] == 'subj':
+            args.key = 'subj_dev_eval_acc'
+            args.test_key = 'subj_test_eval_acc'
+        elif condition['task_name'] == 'trec':
+            args.key = 'trec_dev_eval_acc'
+            args.test_key = 'trec_test_eval_acc'
+        elif condition['task_name'] == 'cr':
+            args.key = 'cr_dev_eval_acc'
+            args.test_key = 'cr_test_eval_acc'
+        elif condition['task_name'] == 'mpqa':
+            args.key = 'mpqa_dev_eval_acc'
+            args.test_key = 'mpqa_test_eval_acc'
+        else:
+            raise NotImplementedError
+
+    with open(args.log) as f:
+        result_list = []
+        for line in f:
+            result_list.append(eval(line))
+    
+    seed_result = {}
+    seed_best = {}
+    
+    # Gather all logs satisfying the conditions
+    for item in result_list:
+        ok = True
+        for cond in condition:
+            if cond == 'task_name' and condition['task_name'] == 'mnli-mm':
+                if cond not in item or item[cond] != 'mnli':
+                    ok = False
+                    break
+            else: 
+                if cond not in item or item[cond] != condition[cond]:
+                    ok = False
+                    break
+        if 'model_id' not in item or 'array_id' not in item:
+            ok = False
+        
+        if ok:
+            seed = int(item['data_dir'].split('-')[-1])
+            model_id = item['model_id']
+            array_id = item['array_id']
+           
+            if model_id >= 0 and model_id < args.n_models:
+                if seed not in seed_result:
+                    seed_result[seed] = {}
+                    seed_best[seed] = {}
+                if model_id not in seed_result[seed]:
+                    seed_result[seed][model_id] = []
+                    seed_best[seed][model_id] = {args.key: -1e9} 
+
+                seed_result[seed][model_id].append(item)
+                if item[args.key] > seed_best[seed][model_id][args.key]:
+                    seed_best[seed][model_id] = item
+    
+    final_result_dev = np.zeros((len(seed_result), args.n_models))
+    final_result_test = np.zeros((len(seed_result), args.n_models))
+    final_result_test2 = np.zeros((len(seed_result), args.n_models))
+
+    logit_file_list = {}
+    for seed in seed_result:
+        logit_file_list[seed] = []
+
+    # Get the results for each model and pick the best dev trial for each model/seed
+    for model_id in range(args.n_models):
+        for i, seed in enumerate(seed_result):
+            final_result_dev[i][model_id] = seed_best[seed][model_id][args.key]
+            final_result_test[i][model_id] = seed_best[seed][model_id][args.test_key]
+            if len(args.test_key2) > 0:
+                final_result_test2[i][model_id] = seed_best[seed][model_id][args.test_key2]
+
+            logit_file_list[seed].append("{}-{}-{}.npy".format(condition['task_name'], model_id, seed_best[seed][model_id]["array_id"]))
+
+        s = "Model %d | val: mean +- std: %.1f +- %.1f | test: mean +- std: %.1f (%.1f) (median %.1f)" % (model_id, final_result_dev[:, model_id].mean() * 100, final_result_dev[:, model_id].std() * 100, final_result_test[:, model_id].mean() * 100, final_result_test[:, model_id].std() * 100, np.median(final_result_test[:, model_id]) * 100)
+        if len(args.test_key2) > 0:
+            s += " / %.1f +- %.1f (median %.1f)" % (final_result_test2[:, model_id].mean() * 100, final_result_test2[:, model_id].std() * 100, np.median(final_result_test2[:, model_id]) * 100)
+        print(s)
+
+    # Map lower-case names to official names (data folder name)
+    data_dir_mapping = {
+        'cola': 'CoLA',
+        'mrpc': 'MRPC',
+        'qqp': 'QQP',
+        'sts-b': 'STS-B',
+        'sst-2': 'SST-2',
+        'snli': 'SNLI',
+        'mnli': 'MNLI',
+        'mnli-mm': 'MNLI',
+        'rte': 'RTE',
+        'ag_news': 'ag_news',
+        'yahoo_answers': 'yahoo_answers',
+        'yelp_review_full': 'yelp_review_full',
+        'sst-5': 'sst-5',
+        'mr': 'mr',
+        'cr': 'cr',
+        'mpqa': 'mpqa',
+        'subj': 'subj',
+        'trec': 'trec'
+    }
+
+    tokenizer = AutoTokenizer.from_pretrained('roberta-large')
+    ensemble_result = np.zeros((len(seed_result)))
+    ensemble_result2 = np.zeros((len(seed_result))) # for second metric
+
+    # Ensemble for each seed
+    for seed_id, seed in enumerate(seed_result):
+        labels = get_labels(args.data_dir, args.k, seed, condition['task_name'], data_dir_mapping[condition['task_name']])
+
+        # Logits
+        mean_logits = None
+        for fname in logit_file_list[seed]:
+            logits = np.load(os.path.join(args.save_logit_dir, fname))
+            if mean_logits is None:
+                mean_logits = logits
+            else:
+                mean_logits += logits
+        mean_logits /= len(logit_file_list[seed])
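+        # The loop above simply averages the saved per-model logits for this
+        # seed, i.e. it is equivalent to np.mean over the loaded arrays along axis 0.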
+        
+        # Compute metrics
+        preds = mean_logits.argmax(-1)
+        if condition['task_name'] in ['sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec']:
+            metric = {"acc": simple_accuracy(preds, labels)}
+        else:
+            metric = glue_compute_metrics(condition['task_name'], preds, labels)
+
+        ensemble_result[seed_id] = metric[args.test_key.split('_')[-1]]
+        if len(args.test_key2) > 0:
+            ensemble_result2[seed_id] = metric[args.test_key2.split('_')[-1]]
+     
+    s = "mean +- std: %.1f (%.1f) (median %.1f)" % (ensemble_result.mean() * 100, ensemble_result.std() * 100, np.median(ensemble_result) * 100)
+    if len(args.test_key2) > 0:
+        s += " / %.1f (%.1f) (median %.1f)" % (ensemble_result2.mean() * 100, ensemble_result2.std() * 100, np.median(ensemble_result2) * 100)
+    print(s)
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/gather_result.py b/LM-BFF/tools/gather_result.py
new file mode 100755
index 0000000..83c947d
--- /dev/null
+++ b/LM-BFF/tools/gather_result.py
@@ -0,0 +1,160 @@
+import argparse
+import json
+import numpy as np
+import torch
+from torch import device
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
+    
+    # These options should be kept as their default values
+    parser.add_argument("--log", type=str, default="log", help="Log path.")
+    parser.add_argument("--key", type=str, default='', help="Validation metric name")
+    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
+    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
+
+    args = parser.parse_args()
+
+    # The condition is parsed as a Python dict literal, e.g. "{'tag': 'exp', 'task_name': 'sst-2'}"
+    condition = eval(args.condition)
+
+    if len(args.key) == 0:
+        if condition['task_name'] == 'cola':
+            args.key = 'cola_dev_eval_mcc'
+            args.test_key = 'cola_test_eval_mcc'
+        elif condition['task_name'] == 'mrpc/acc':
+            args.key = 'mrpc_dev_eval_acc'
+            args.test_key = 'mrpc_test_eval_acc'
+            args.test_key2 = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+        elif condition['task_name'] == 'mrpc/f1':
+            args.key = 'mrpc_dev_eval_f1'
+            args.test_key2 = 'mrpc_test_eval_acc'
+            args.test_key = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+        elif condition['task_name'] == 'qqp/acc':
+            args.key = 'qqp_dev_eval_acc'
+            args.test_key = 'qqp_test_eval_acc'
+            args.test_key2 = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+        elif condition['task_name'] == 'qqp/f1':
+            args.key = 'qqp_dev_eval_f1'
+            args.test_key2 = 'qqp_test_eval_acc'
+            args.test_key = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+        elif condition['task_name'] == 'sts-b/pearson':
+            args.key = 'sts-b_dev_eval_pearson'
+            args.test_key = 'sts-b_test_eval_pearson'
+            args.test_key2 = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+        elif condition['task_name'] == 'sts-b/spearmanr':
+            args.key = 'sts-b_dev_eval_spearmanr'
+            args.test_key2 = 'sts-b_test_eval_pearson'
+            args.test_key = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+        elif condition['task_name'] == 'qnli':
+            args.key = 'qnli_dev_eval_acc'
+            args.test_key = 'qnli_test_eval_acc'
+        elif condition['task_name'] == 'sst-2':
+            args.key = 'sst-2_dev_eval_acc'
+            args.test_key = 'sst-2_test_eval_acc'
+        elif condition['task_name'] == 'snli':
+            args.key = 'snli_dev_eval_acc'
+            args.test_key = 'snli_test_eval_acc'
+        elif condition['task_name'] == 'mnli':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli_test_eval_mnli/acc'
+        elif condition['task_name'] == 'mnli-mm':
+            condition['task_name'] = 'mnli'
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
+        elif condition['task_name'] == 'rte':
+            args.key = 'rte_dev_eval_acc'
+            args.test_key = 'rte_test_eval_acc'
+        elif condition['task_name'] == 'ag_news':
+            args.key = 'ag_news_dev_eval_acc'
+            args.test_key = 'ag_news_test_eval_acc'
+        elif condition['task_name'] == 'yahoo_answers':
+            args.key = 'yahoo_answers_dev_eval_acc'
+            args.test_key = 'yahoo_answers_test_eval_acc'
+        elif condition['task_name'] == 'yelp_review_full':
+            args.key = 'yelp_review_full_dev_eval_acc'
+            args.test_key = 'yelp_review_full_test_eval_acc'
+        elif condition['task_name'] == 'mr':
+            args.key = 'mr_dev_eval_acc'
+            args.test_key = 'mr_test_eval_acc'
+        elif condition['task_name'] == 'sst-5':
+            args.key = 'sst-5_dev_eval_acc'
+            args.test_key = 'sst-5_test_eval_acc'
+        elif condition['task_name'] == 'subj':
+            args.key = 'subj_dev_eval_acc'
+            args.test_key = 'subj_test_eval_acc'
+        elif condition['task_name'] == 'trec':
+            args.key = 'trec_dev_eval_acc'
+            args.test_key = 'trec_test_eval_acc'
+        elif condition['task_name'] == 'cr':
+            args.key = 'cr_dev_eval_acc'
+            args.test_key = 'cr_test_eval_acc'
+        elif condition['task_name'] == 'mpqa':
+            args.key = 'mpqa_dev_eval_acc'
+            args.test_key = 'mpqa_test_eval_acc'
+        else:
+            raise NotImplementedError
+
+    with open(args.log) as f:
+        result_list = []
+        for line in f:
+            result_list.append(eval(line))
+    
+    seed_result = {}
+    seed_best = {}
+
+    for item in result_list:
+        ok = True
+        for cond in condition:
+            if isinstance(condition[cond], list):
+                if cond not in item or (item[cond] not in condition[cond]):
+                    ok = False
+                    break
+            else:
+                if cond not in item or (item[cond] != condition[cond]):
+                    ok = False
+                    break
+        
+        if ok:
+            seed = item['data_dir'].split('-')[-1] + '-' + str(item['seed'])
+            if seed not in seed_result:
+                seed_result[seed] = [item]
+                seed_best[seed] = item
+            else:
+                seed_result[seed].append(item)
+                if item[args.key] > seed_best[seed][args.key]:
+                    seed_best[seed] = item
+    
+    final_result_dev = np.zeros((len(seed_best)))
+    final_result_test = np.zeros((len(seed_best)))
+    final_result_test2 = np.zeros((len(seed_best)))
+    for i, seed in enumerate(seed_best):
+        final_result_dev[i] = seed_best[seed][args.key]
+        final_result_test[i] = seed_best[seed][args.test_key]
+        if len(args.test_key2) > 0:
+            final_result_test2[i] = seed_best[seed][args.test_key2]
+        print("%s: best dev (%.5f) test (%.5f) %s | total trials: %d" % (
+            seed,
+            seed_best[seed][args.key],
+            seed_best[seed][args.test_key],
+            "test2 (%.5f)" % (seed_best[seed][args.test_key2]) if len(args.test_key2) > 0 else "",
+            len(seed_result[seed])
+        ))
+        s = ''
+        for k in ['per_device_train_batch_size', 'gradient_accumulation_steps', 'learning_rate', 'eval_steps', 'max_steps']:
+            s += '| {}: {} '.format(k, seed_best[seed][k])
+        print('    ' + s)
+
+    s = "mean +- std: %.1f (%.1f) (median %.1f)" % (final_result_test.mean() * 100, final_result_test.std() * 100, np.median(final_result_test) * 100)
+    if len(args.test_key2) > 0:
+        s += "second metric: %.1f (%.1f) (median %.1f)" % (final_result_test2.mean() * 100, final_result_test2.std() * 100, np.median(final_result_test2) * 100)
+    print(s)
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/generate_k_shot_data.py b/LM-BFF/tools/generate_k_shot_data.py
new file mode 100755
index 0000000..b7f4222
--- /dev/null
+++ b/LM-BFF/tools/generate_k_shot_data.py
@@ -0,0 +1,181 @@
+"""This script samples K examples randomly without replacement from the original data."""
+
+import argparse
+import os
+import numpy as np
+import pandas as pd
+from pandas import DataFrame
+
+def get_label(task, line):
+    if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+        # GLUE style
+        line = line.strip().split('\t')
+        if task == 'CoLA':
+            return line[1]
+        elif task == 'MNLI':
+            return line[-1]
+        elif task == 'MRPC':
+            return line[0]
+        elif task == 'QNLI':
+            return line[-1]
+        elif task == 'QQP':
+            return line[-1]
+        elif task == 'RTE':
+            return line[-1]
+        elif task == 'SNLI':
+            return line[-1]
+        elif task == 'SST-2':
+            return line[-1]
+        elif task == 'STS-B':
+            return 0 if float(line[-1]) < 2.5 else 1
+        elif task == 'WNLI':
+            return line[-1]
+        else:
+            raise NotImplementedError
+    else:
+        return line[0]
+
+def load_datasets(data_dir, tasks):
+    datasets = {}
+    for task in tasks:
+        if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+            # GLUE style (tsv)
+            dataset = {}
+            dirname = os.path.join(data_dir, task)
+            if task == "MNLI":
+                splits = ["train", "dev_matched", "dev_mismatched"]
+            else:
+                splits = ["train", "dev"]
+            for split in splits:
+                filename = os.path.join(dirname, f"{split}.tsv")
+                with open(filename, "r") as f:
+                    lines = f.readlines()
+                dataset[split] = lines
+            datasets[task] = dataset
+        else:
+            # Other datasets (csv)
+            dataset = {}
+            dirname = os.path.join(data_dir, task)
+            splits = ["train", "test"]
+            for split in splits:
+                filename = os.path.join(dirname, f"{split}.csv")
+                dataset[split] = pd.read_csv(filename, header=None)
+            datasets[task] = dataset
+    return datasets
+
+def split_header(task, lines):
+    """
+    Splits off the header line (if the task file has one) and returns
+    (header_lines, data_lines). Only for GLUE tasks.
+    """
+    if task in ["CoLA"]:
+        return [], lines
+    elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
+        return lines[0:1], lines[1:]
+    else:
+        raise ValueError("Unknown GLUE task.")
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--k", type=int, default=16,
+        help="Training examples for each class.")
+    parser.add_argument("--task", type=str, nargs="+", 
+        default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
+        help="Task names")
+    parser.add_argument("--seed", type=int, nargs="+", 
+        default=[100, 13, 21, 42, 87],
+        help="Random seeds")
+
+    parser.add_argument("--data_dir", type=str, default="data/original", help="Path to original data")
+    parser.add_argument("--output_dir", type=str, default="data", help="Output path")
+    parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)") 
+
+    args = parser.parse_args()
+    args.output_dir = os.path.join(args.output_dir, args.mode)
+
+    k = args.k
+    print("K =", k)
+    datasets = load_datasets(args.data_dir, args.task)
+
+    for seed in args.seed:
+        print("Seed = %d" % (seed))
+        for task, dataset in datasets.items():
+            # Set random seed
+            np.random.seed(seed)
+
+            # Shuffle the training set
+            print("| Task = %s" % (task))
+            if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+                # GLUE style 
+                train_header, train_lines = split_header(task, dataset["train"])
+                np.random.shuffle(train_lines)
+            else:
+                # Other datasets 
+                train_lines = dataset['train'].values.tolist()
+                np.random.shuffle(train_lines)
+
+            # Set up dir
+            task_dir = os.path.join(args.output_dir, task)
+            setting_dir = os.path.join(task_dir, f"{k}-{seed}")
+            os.makedirs(setting_dir, exist_ok=True)
+
+            # Write test splits
+            if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+                # GLUE style
+                # Use the original development set as the test set (the original test sets are not publicly available)
+                for split, lines in dataset.items():
+                    if split.startswith("train"):
+                        continue
+                    split = split.replace('dev', 'test')
+                    with open(os.path.join(setting_dir, f"{split}.tsv"), "w") as f:
+                        for line in lines:
+                            f.write(line)
+            else:
+                # Other datasets
+                # Use the original test sets
+                dataset['test'].to_csv(os.path.join(setting_dir, 'test.csv'), header=False, index=False)  
+            
+            # Get label list for balanced sampling
+            label_list = {}
+            for line in train_lines:
+                label = get_label(task, line)
+                if label not in label_list:
+                    label_list[label] = [line]
+                else:
+                    label_list[label].append(line)
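+
+            # Balanced few-shot sampling: per label, the first k shuffled
+            # examples form the train set and the next k (or 10*k in the
+            # k-shot-10x mode) form the dev set, as written out below.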
+            
+            if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+                with open(os.path.join(setting_dir, "train.tsv"), "w") as f:
+                    for line in train_header:
+                        f.write(line)
+                    for label in label_list:
+                        for line in label_list[label][:k]:
+                            f.write(line)
+                name = "dev.tsv"
+                if task == 'MNLI':
+                    name = "dev_matched.tsv"
+                with open(os.path.join(setting_dir, name), "w") as f:
+                    for line in train_header:
+                        f.write(line)
+                    for label in label_list:
+                        dev_rate = 11 if '10x' in args.mode else 2
+                        for line in label_list[label][k:k*dev_rate]:
+                            f.write(line)
+            else:
+                new_train = []
+                for label in label_list:
+                    for line in label_list[label][:k]:
+                        new_train.append(line)
+                new_train = DataFrame(new_train)
+                new_train.to_csv(os.path.join(setting_dir, 'train.csv'), header=False, index=False)
+            
+                new_dev = []
+                for label in label_list:
+                    dev_rate = 11 if '10x' in args.mode else 2
+                    for line in label_list[label][k:k*dev_rate]:
+                        new_dev.append(line)
+                new_dev = DataFrame(new_dev)
+                new_dev.to_csv(os.path.join(setting_dir, 'dev.csv'), header=False, index=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/LM-BFF/tools/generate_labels.py b/LM-BFF/tools/generate_labels.py
new file mode 100755
index 0000000..3acc47f
--- /dev/null
+++ b/LM-BFF/tools/generate_labels.py
@@ -0,0 +1,284 @@
+"""Finetuning the library models for sequence classification on GLUE."""
+
+import os, sys, inspect
+currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
+parentdir = os.path.dirname(currentdir)
+sys.path.insert(0, parentdir)
+
+import logging
+import json
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from transformers import GlueDataTrainingArguments as DataTrainingArguments
+from transformers import HfArgumentParser, TrainingArguments, set_seed
+
+from src.label_search import find_labels
+from src.dataset import FewShotDataset
+from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
+from src.trainer import Trainer
+from src.processors import output_modes_mapping, num_labels_mapping
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+    random_segment: bool = field(
+        default=False, metadata={"help": "Whether to randomly reinitialize the token type embeddings when resizing them (used for BERT models below)"}
+    )
+
+
+@dataclass
+class DynamicDataTrainingArguments(DataTrainingArguments):
+    """
+    Arguments for dynamic training.
+    """
+    # For prompting
+    template: str = field(
+        default=None,
+        metadata={"help": "Template"}
+    )
+
+    mapping: str = field(
+        default=None,
+        metadata={"help": "Label word mapping"}
+    )
+
+    debug_mode: bool = field(
+        default=False,
+        metadata={"help": "Debug mode"}
+    )
+
+    first_sent_limit: int = field(
+        default=None,
+        metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
+    )
+
+    other_sent_limit: int = field(
+        default=None,
+        metadata={"help": "Limit the length of sentences other than the first sentence"}
+    )
+
+    use_full_length: bool = field(
+        default=None,
+        metadata={"help": "Use the full length (512)"}
+    )
+
+    truncate_head: bool = field(
+        default=False,
+        metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
+    )
+
+    use_space_word: bool = field(
+        default=True,
+        metadata={"help": "Use space words (e.g., Gpositive) instead of original words."}
+    )
+
+    use_seed_labels: bool = field(
+        default=False,
+        metadata={"help": "Regularize using seed labels"},
+    )
+
+    k_likely: int = field(
+        default=100,
+        metadata={"help": "Take the top-k most (conditionally) likely labels per class."}
+    )
+
+    k_neighbors: int = field(
+        default=50,
+        metadata={"help": "Re-rank by nearest neighbor, and take the top k."}
+    )
+
+    n_pairs: int = field(
+        default=32,
+        metadata={"help": "Number of label pairings to use."}
+    )
+
+    output_file: str = field(
+        default="out",
+        metadata={"help": "Output file"}
+    )
+
+    append_output_file: bool = field(
+        default=False,
+    )
+
+    write_template: bool = field(
+        default=False,
+    )
+
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, TrainingArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Fix prompt to be true.
+    data_args.prompt = True
+    data_args.num_sample = 1
+    data_args.template_list = None
+    data_args.gpt3_in_context_head = False
+    data_args.gpt3_in_context_tail = False
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+    )
+
+    # Check save path
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")
+
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        training_args.local_rank,
+        training_args.device,
+        training_args.n_gpu,
+        bool(training_args.local_rank != -1),
+        training_args.fp16,
+    )
+    logger.info("Training/evaluation parameters %s", training_args)
+
+    # Set seed
+    set_seed(training_args.seed)
+
+    try:
+        num_labels = num_labels_mapping[data_args.task_name]
+        output_mode = output_modes_mapping[data_args.task_name]
+        logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))
+    except KeyError:
+        raise ValueError("Task not found: %s" % (data_args.task_name))
+
+    # Create config
+    config = AutoConfig.from_pretrained(
+        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+        num_labels=num_labels,
+        finetuning_task=data_args.task_name,
+        cache_dir=model_args.cache_dir,
+    )
+
+    if config.model_type == 'roberta':
+        model_fn = RobertaForPromptFinetuning
+    elif config.model_type == 'bert':
+        model_fn = BertForPromptFinetuning
+    else:
+        raise NotImplementedError
+    special_tokens = []
+
+    # Create tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+        additional_special_tokens=special_tokens,
+        cache_dir=model_args.cache_dir,
+    )
+
+    set_seed(training_args.seed)
+
+    model = model_fn.from_pretrained(
+        model_args.model_name_or_path,
+        from_tf=bool(".ckpt" in model_args.model_name_or_path),
+        config=config,
+        cache_dir=model_args.cache_dir,
+    )
+
+    # For BERT, increase the size of the segment (token type) embeddings
+    if config.model_type == 'bert':
+        model.resize_token_embeddings(len(tokenizer))
+        resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)
+
+    # Pass dataset and argument information to the model
+    model.model_args = model_args
+    model.data_args = data_args
+    model.tokenizer = tokenizer
+    model.return_full_softmax = True
+
+    # Initialize our Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=None,
+        eval_dataset=None,
+    )
+
+    # First we compute zero-shot logits on all of the examples.
+    dataset = FewShotDataset(data_args, tokenizer=tokenizer, mode="train", use_demo=False)
+
+    # Predict logits.
+    dataloader = trainer.get_eval_dataloader(dataset)
+    output = trainer.prediction_loop(dataloader, description="Evaluation")
+    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions
+    labels = output.label_ids
+
+    # Assign words to labels.
+    if data_args.use_seed_labels:
+        if data_args.use_space_word:
+            seed_labels = {k: "Ġ" + v for k, v in eval(data_args.mapping).items()}
+        else:
+            seed_labels = eval(data_args.mapping)
+        seed_labels = [seed_labels[label] for label in dataset.get_labels()]
+    else:
+        seed_labels = None
+
+    vocab = list(tokenizer.get_vocab())
+
+    # Find best labels.
+    label_pairings = find_labels(
+        model=trainer.model,
+        train_logits=logits,
+        train_labels=labels,
+        seed_labels=seed_labels,
+        k_likely=data_args.k_likely,
+        k_neighbors=data_args.k_neighbors,
+        top_n=data_args.n_pairs,
+        vocab=vocab,
+        is_regression=config.num_labels == 1)
+
+    labels = dataset.get_labels()
+    if config.num_labels == 1:
+        labels = ['0', '1']
+
+    output_parent = os.path.dirname(data_args.output_file)
+    if output_parent:
+        os.makedirs(output_parent, exist_ok=True)
+    if data_args.append_output_file:
+        mode = "a"
+    else:
+        mode = "w"
+
+    # Write to output.
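+    # Each line is a JSON label-word mapping, optionally preceded by the template
+    # and a tab. Illustrative example for a binary sentiment task:
+    # {"0": "terrible", "1": "great"}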
+    with open(data_args.output_file, mode) as f:
+        for pairing in label_pairings:
+            words = [vocab[i][len("Ġ"):] for i in pairing]
+            mapping = {labels[i]: words[i] for i in range(len(labels))}
+            if data_args.write_template:
+                f.write(data_args.template + "\t")
+            f.write(json.dumps(mapping) + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/LM-BFF/tools/generate_template.py b/LM-BFF/tools/generate_template.py
new file mode 100755
index 0000000..1c5f615
--- /dev/null
+++ b/LM-BFF/tools/generate_template.py
@@ -0,0 +1,375 @@
+import transformers
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+import argparse
+import torch
+import os
+from tqdm import tqdm
+import json
+import pandas as pd
+
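+# Template mini-language (illustrative summary): templates are split on '*', and each
+# part is either a special token ('cls', 'mask', 'sep+', '<extra_id_i>'), a label
+# slot ('label'), a sentence slot ('sent_0', '+sent_1', 'sentu_0', ...; a leading
+# '+' adds a space, 'u'/'l' upper/lower-case the first character, a trailing '-'
+# drops the last character), or literal text with '_' standing in for spaces.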
+def get_text(template, input_text_tuple, label, tokenizer, mapping):
+    def enc(text):
+        return tokenizer.encode(text, add_special_tokens=False)
+    special_token_mapping = {'cls': tokenizer.cls_token_id, 'mask': tokenizer.mask_token_id, 'sep': tokenizer.sep_token_id, 'sep+': tokenizer.sep_token_id}
+    for i in range(10):
+        special_token_mapping["<extra_id_%d>" % (i)] = tokenizer._convert_token_to_id("<extra_id_%d>" % (i))
+    template_list = template.split('*')
+    input_ids = []
+    for part in template_list:
+        new_tokens = []
+        if part in special_token_mapping:
+            if part == 'cls' and 'T5' in type(tokenizer).__name__:
+                # T5 does not have cls token
+                continue
+            new_tokens.append(special_token_mapping[part])
+        elif part[:5] == 'label':
+            new_tokens += enc(' ' + mapping[label])
+        elif part[:5] == 'sent_':
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(input_text_tuple[sent_id])
+        elif part[:6] == '+sent_':
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(' ' + input_text_tuple[sent_id]) # add space
+        elif part[:6] == 'sent-_':
+            # Delete the last character (usually the final punctuation mark)
+            sent_id = int(part.split('_')[1])
+            new_tokens += enc(input_text_tuple[sent_id][:-1])
+        elif part[:7] == '+sentl_':
+            # Lower case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(' ' + text)
+        elif part[:7] == '+sentu_':
+            # Upper case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].upper() + text[1:]
+            new_tokens += enc(' ' + text)
+        elif part[:6] == 'sentl_':
+            # Lower case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(text)
+        elif part[:6] == 'sentu_':
+            # Upper case the first token
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].upper() + text[1:]
+            new_tokens += enc(text)
+        elif part[:7] == 'sentl-_':
+            # Lower case the first token and delete the last character
+            sent_id = int(part.split('_')[1])
+            text = input_text_tuple[sent_id]
+            text = text[:1].lower() + text[1:]
+            new_tokens += enc(text[:-1])
+        else:
+            part = part.replace('_', ' ') # spaces cannot appear in the command, so '_' stands in for a space
+            # Handle the special case where the T5 tokenizer might add an extra space
+            if len(part) == 1:
+                new_tokens.append(tokenizer._convert_token_to_id(part))
+            else:
+                new_tokens += enc(part)
+
+        input_ids += new_tokens
+    return input_ids
+
+def generate(dataset, template, model, tokenizer, target_number, mapping, beam, label=None, length_limit=None, truncate=None):
+    """
+    Generate templates based on given inputs
+
+    label: Only use instances with this label (deprecated)
+    length_limit: Generate content at least as long as length_limit (deprecated)
+    """
+    input_texts = []
+    input_tensors = []
+    max_length = 0
+
+    # Process the inputs
+    for item in dataset:
+        if label is None or item['label'] == label:
+            input_text = get_text(template, item['text'], item['label'], tokenizer, mapping)
+            if truncate is not None:
+                if truncate == 'head':
+                    input_text = input_text[-256:]
+                elif truncate == 'tail':
+                    input_text = input_text[:256]
+                else:
+                    raise NotImplementedError
+            input_ids = torch.tensor(input_text).long()
+            max_length = max(max_length, input_ids.size(-1))
+            input_tensors.append(input_ids)
+
+    # Concatenate inputs as a batch
+    input_ids = torch.zeros((len(input_tensors), max_length)).long()
+    attention_mask = torch.zeros((len(input_tensors), max_length)).long()
+    for i in range(len(input_tensors)):
+        input_ids[i, :input_tensors[i].size(-1)] = input_tensors[i]
+        attention_mask[i, :input_tensors[i].size(-1)] = 1
+
+    assert len(input_tensors) > 0
+
+    # Print a few examples
+    print('####### example #######')
+    for example_ids in input_ids[:3]:
+        print(tokenizer.decode(example_ids))
+    print('####### example #######\n')
+
+    #input_ids = input_ids.cuda()
+    #attention_mask = attention_mask.cuda()
+
+    # Maximum length of the generated content
+    max_length = 20
+
+    start_mask = tokenizer._convert_token_to_id('<extra_id_0>')
+    ori_decoder_input_ids = torch.zeros((input_ids.size(0), max_length)).long()
+    ori_decoder_input_ids[..., 0] = model.config.decoder_start_token_id
+
+    # decoder_input_ids: decoder inputs for the next auto-regressive generation step
+    # ll: log likelihood
+    # output_id: which part of the generated content we are at
+    # output: generated content so far
+    # last_length (deprecated): how long we have generated for this part
+    current_output = [{'decoder_input_ids': ori_decoder_input_ids, 'll': 0, 'output_id': 1, 'output': [], 'last_length': -1}]
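+    # Illustrative walk-through: with target_number=2, the beam first fills the
+    # <extra_id_0> span (output_id=1); generating the next sentinel or </s>
+    # closes the current span and advances output_id, and candidates with
+    # output_id > target_number are carried over unchanged as finished.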
+    for i in tqdm(range(max_length - 2)):
+        new_current_output = []
+        for item in current_output:
+            if item['output_id'] > target_number:
+                # Enough contents
+                new_current_output.append(item)
+                continue
+            decoder_input_ids = item['decoder_input_ids']
+
+            # Forward
+            batch_size = 10
+            turn = input_ids.size(0) // batch_size
+            if input_ids.size(0) % batch_size != 0:
+                turn += 1
+            aggr_output = []
+            for t in range(turn):
+                start = t * batch_size
+                end = min((t + 1) * batch_size, input_ids.size(0))
+
+                with torch.no_grad():
+                    #aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids.cuda()[start:end])[0])
+                    aggr_output.append(model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids[start:end])[0])
+            aggr_output = torch.cat(aggr_output, 0)
+
+            # Gather results across all input sentences, and sort generated tokens by log likelihood
+            aggr_output = aggr_output.mean(0)
+            log_denominator = torch.logsumexp(aggr_output[i], -1).item()
+            ids = list(range(model.config.vocab_size))
+            ids.sort(key=lambda x: aggr_output[i][x].item(), reverse=True)
+            ids = ids[:beam+3]
+            
+            for word_id in ids:
+                output_id = item['output_id']
+
+                if word_id == start_mask - output_id or word_id == tokenizer._convert_token_to_id('</s>'):
+                    # Finish one part
+                    if length_limit is not None and item['last_length'] < length_limit[output_id - 1]:
+                        check = False
+                    else:
+                        check = True
+                    output_id += 1
+                    last_length = 0
+                else:
+                    last_length = item['last_length'] + 1
+                    check = True
+
+                output_text = item['output'] + [word_id]
+                ll = item['ll'] + aggr_output[i][word_id] - log_denominator
+                new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.size())
+                new_decoder_input_ids[:] = decoder_input_ids
+                new_decoder_input_ids[..., i + 1] = word_id
+
+                # Forbid single space token, "....", and ".........."
+                if word_id in [3, 19794, 22354]:
+                    check = False
+
+                # Forbid continuous "."
+                if len(output_text) > 1 and output_text[-2] == 5 and output_text[-1] == 5:
+                    check = False
+
+                if check:
+                    # Add new results to beam search pool
+                    new_item = {'decoder_input_ids': new_decoder_input_ids, 'll': ll, 'output_id': output_id, 'output': output_text, 'last_length': last_length}
+                    new_current_output.append(new_item)
+
+        if len(new_current_output) == 0:
+            break
+
+        new_current_output.sort(key=lambda x: x['ll'], reverse=True)
+        new_current_output = new_current_output[:beam]
+        current_output = new_current_output
+
+    result = []
+    print("####### generated results #######")
+    for item in current_output:
+        generate_text = ''
+        for token in item['output']:
+            generate_text += tokenizer._convert_id_to_token(token)
+        print('--------------')
+        print('score:', item['ll'].item())
+        print('generated ids', item['output'])
+        print('generated text', generate_text)
+        result.append(generate_text)
+    print("####### generated results #######\n")
+
+    return result
+
+def load_dataset(task, data_dir):
+    if task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI", "CoLA"]:
+        lines = open(os.path.join(data_dir, 'train.tsv')).readlines()
+        if task != 'CoLA':
+            lines = lines[1:]
+
+        dataset = []
+        for line in lines:
+            line = line.strip().split('\t')
+            if task == 'CoLA':
+                dataset.append({'label': line[1], 'text': [line[-1]]})
+            elif task == 'MNLI':
+                dataset.append({'label': line[-1], 'text': [line[8], line[9]]})
+            elif task == 'MRPC':
+                dataset.append({'label': line[0], 'text': [line[-2], line[-1]]})
+            elif task == 'QNLI':
+                dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
+            elif task == 'QQP':
+                dataset.append({'label': line[-1], 'text': [line[3], line[4]]})
+            elif task == 'RTE':
+                dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
+            elif task == 'SNLI':
+                dataset.append({'label': line[-1], 'text': [line[7], line[8]]})
+            elif task == 'SST-2':
+                dataset.append({'label': line[-1], 'text': [line[0]]})
+            elif task == 'STS-B':
+                dataset.append({'label': '0' if float(line[-1]) < 2.5 else '1', 'text': [line[-3], line[-2]]})
+            elif task == 'WNLI':
+                dataset.append({'label': line[-1], 'text': [line[1], line[2]]})
+            else:
+                raise NotImplementedError
+    else:
+        lines = pd.read_csv(os.path.join(data_dir, 'train.csv')).values.tolist()
+        dataset = []
+        for line in lines:
+            dataset.append({'label': line[0], 'text': [line[1]]})
+
+    return dataset
+
+def search_template(model, tokenizer, task_name, k, seed, beam, output_dir, data_dir):
+    print('#', task_name, k, seed, beam)
+    dataset_path = os.path.join(data_dir, task_name, "{}-{}".format(k, seed))
+    dataset = load_dataset(task_name, dataset_path)
+    print('|', 'dataset examples')
+    print('|', dataset[0])
+    print('|', dataset[-1])
+    print()
+    
+    # Manual label word mappings
+    map_of_mapping = {
+        'SST-2': {'0':'terrible','1':'great'},
+        'sst-5': {0:'terrible',1:'bad',2:'okay',3:'good',4:'great'},
+        'mr': {0:'terrible',1:'great'},
+        'cr': {0:'terrible',1:'great'},
+        'subj': {0:'subjective',1:'objective'},
+        'trec': {0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'},
+        'mpqa': {0:'terrible',1:'great'},
+        'CoLA': {'0':'incorrect','1':'correct'},
+        'MRPC': {'0':'No','1':'Yes'},
+        'QQP': {'0':'No','1':'Yes'},
+        'STS-B': {'0':'No','1':'Yes'},
+        'MNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'},
+        'SNLI': {'contradiction':'No','entailment':'Yes','neutral':'Maybe'},
+        'QNLI': {'not_entailment':'No','entailment':'Yes'},
+        'RTE': {'not_entailment':'No','entailment':'Yes'}
+    }
+
+    mapping = map_of_mapping[task_name]
+    print('|', 'mapping')
+    print('|', mapping)
+
+    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(os.path.join(output_dir, task_name), exist_ok=True)
+    f = open(os.path.join(output_dir, task_name, "{}-{}.txt".format(k, seed)), 'w')
+
+    if task_name in ['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'mpqa']:
+        # Single sentence tasks
+        # We try two kinds of templates: one with [MASK] near the end of the sentence and one with [MASK] near the beginning
+        template = "*cls**sentu_0**<extra_id_0>**label**<extra_id_1>**sep+*"
+        generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='head')[:beam//2]
+
+        print("####### generated templates #######")
+        for text in generate_text:
+            # Transform T5 outputs to our template format
+            text = text.replace('<extra_id_0>', '*cls**sent_0*')
+            text = text.replace('<extra_id_1>', '*mask*')
+            text = text.replace('<extra_id_2>', '*sep+*')
+            text = text.replace('</s>', '*sep+*')
+            text = text.replace('▁', '_')
+            print(text)
+            f.write(text + '\n')
+        print("####### generated templates #######\n")
+
+        template = "*cls*.*<extra_id_0>**label**<extra_id_1>**+sentu_0**sep+*"
+        generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None, truncate='tail')[:beam//2]
+        print("####### generated templates #######")
+        for text in generate_text:
+            # Transform T5 outputs to our template format
+            text = text.replace('<extra_id_0>', '*cls*')
+            text = text.replace('<extra_id_1>', '*mask*')
+            text = text.replace('<extra_id_2>', '*+sent_0**sep+*')
+            text = text.replace('</s>', '*+sent_0**sep+*')
+            text = text.replace('▁', '_')
+            print(text)
+            f.write(text + '\n')
+        print("####### generated templates #######\n")
+
+    elif task_name in ['MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE']:
+        # Sentence pair tasks
+        # We always put [MASK] between the two sentences
+        template = "*cls**sent-_0**<extra_id_0>**label**<extra_id_1>**+sentl_1**sep+*"
+        generate_text = generate(dataset, template, model, tokenizer, target_number=2, mapping=mapping, beam=beam, label=None)
+        print("####### generated templates #######")
+        for text in generate_text:
+            # Transform T5 outputs to our template format
+            text = text.replace('<extra_id_0>', '*cls**sent-_0*')
+            text = text.replace('<extra_id_1>', '*mask*')
+            text = text.replace('<extra_id_2>', '*+sentl_1**sep+*')
+            text = text.replace('</s>', '*+sentl_1**sep+*')
+            text = text.replace('▁', '_')
+            print(text)
+            f.write(text + '\n')
+        print("####### generated templates #######\n")
+    else:
+        raise NotImplementedError
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--t5_model', type=str, default='t5-3b', help='T5 pre-trained model')
+    parser.add_argument('--seed', type=int, nargs='+', default=[42, 13, 21, 100, 87], help="Data split seeds")
+    parser.add_argument('--task_name', type=str, nargs='+', default=['SST-2', 'sst-5', 'mr', 'cr', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'], help="Task names")
+    parser.add_argument('--output_dir', type=str, help="Output directory")
+
+    parser.add_argument('--data_dir', type=str, default="data/k-shot", help="Data directory")
+    parser.add_argument('--beam', type=int, default=100, help="Beam search width")
+    parser.add_argument('--k', type=int, default=16, help="Number of training instances per label")
+ 
+    args = parser.parse_args()
+
+    model = T5ForConditionalGeneration.from_pretrained(args.t5_model)
+    tokenizer = T5Tokenizer.from_pretrained(args.t5_model)
+    tokenizer.sep_token = '</s>'
+
+    #model = model.cuda()
+    model.eval()
+
+    for task_name in args.task_name:
+        for seed in args.seed:
+            search_template(model=model, tokenizer=tokenizer, task_name=task_name, k=args.k, seed=seed, beam=args.beam, output_dir=args.output_dir, data_dir=args.data_dir)
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/get_sbert_embedding.py b/LM-BFF/tools/get_sbert_embedding.py
new file mode 100755
index 0000000..806243c
--- /dev/null
+++ b/LM-BFF/tools/get_sbert_embedding.py
@@ -0,0 +1,105 @@
+from sentence_transformers import SentenceTransformer, util
+import argparse
+import os
+import numpy as np
+from tqdm import tqdm
+import pandas as pd
+
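+# Flatten one example into a single SBERT input string; sentence-pair tasks
+# (e.g., MNLI premise/hypothesis) concatenate the two fields with a space.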
+def get_sentence(task, line):
+    if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
+        # Text classification tasks
+        if line[1] is None or pd.isna(line[1]):
+            return ''
+        else:
+            return line[1]
+    else:
+        # GLUE tasks
+        line = line.strip().split('\t')
+        if task == 'CoLA':
+            return line[-1]
+        elif task == 'MNLI':
+            return line[8] + ' ' + line[9]
+        elif task == 'MRPC':
+            return line[-2] + ' ' + line[-1]
+        elif task == 'QNLI':
+            return line[1] + ' ' + line[2]
+        elif task == 'QQP':
+            return line[3] + ' ' + line[4]
+        elif task == 'RTE':
+            return line[1] + ' ' + line[2]
+        elif task == 'SNLI':
+            return line[7] + ' ' + line[8]
+        elif task == 'SST-2':
+            return line[0]
+        elif task == 'STS-B':
+            return line[-3] + ' ' + line[-2]
+        elif task == 'WNLI':
+            return line[1] + ' ' + line[2]
+        else:
+            raise NotImplementedError
+
+def split_header(task, lines):
+    """Returns if the task file has a header or not."""
+    if task in ["CoLA"]:
+        return [], lines
+    elif task in ["MNLI", "MRPC", "QNLI", "QQP", "RTE", "SNLI", "SST-2", "STS-B", "WNLI"]:
+        return lines[0:1], lines[1:]
+    else:
+        raise ValueError("Unknown GLUE task.")
+
+def load_datasets(data_dir, task, do_test=False):
+    dataset = {}
+    if task == "MNLI":
+        splits = ["train", "dev_matched"]
+        if do_test:
+            splits += ['test_matched', 'test_mismatched'] 
+    else:
+        splits = ["train", "dev"]
+        if do_test:
+            splits.append('test')
+    for split in splits:
+        if task in ['mr', 'sst-5', 'subj', 'trec', 'cr', 'mpqa']:
+            filename = os.path.join(data_dir, f"{split}.csv")
+            dataset[split] = pd.read_csv(filename, header=None).values.tolist()
+        else:
+            filename = os.path.join(data_dir, f"{split}.tsv")
+            with open(filename, "r") as f:
+                lines = f.readlines()
+                header, content = split_header(task, lines)
+            dataset[split] = content
+    return dataset
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--do_test", action='store_true', help="Generate embeddings for test splits (test set is usually large, so we don't want to repeatedly generate embeddings for them)")
+    parser.add_argument("--sbert_model", type=str, default='roberta-large', help="Sentence BERT model name")
+
+    parser.add_argument("--k", type=int, help="Number of training instances per label", default=16)
+    parser.add_argument("--data_dir", type=str, default="data/k-shot", help="Path to few-shot data")
+    parser.add_argument("--seed", type=int, nargs="+", default=[42, 13, 21, 87, 100], help="Seeds for data splits")
+    parser.add_argument("--task", type=str, nargs="+", default=["SST-2", "sst-5", "mr", "cr", "mpqa", "subj", "trec", "CoLA", "MRPC", "QQP", "STS-B", "MNLI", "SNLI", "QNLI", "RTE"], help="Tasks")
+
+    args = parser.parse_args()
+
+    model = SentenceTransformer('{}-nli-stsb-mean-tokens'.format(args.sbert_model))
+    model = model.cuda()
+
+    for task in args.task:
+        for seed in args.seed:
+            folder = os.path.join(args.data_dir, task, '{}-{}'.format(args.k, seed))
+            dataset = load_datasets(folder, task, do_test=args.do_test)
+            for split in dataset:
+                print('{}-{}-{}-{}'.format(task, args.k, seed, split))
+                lines = dataset[split]
+                embeddings = []
+                for line_id, line in tqdm(enumerate(lines)):
+                    sent = get_sentence(task, line)
+                    if line_id == 0:
+                        print('|', sent)
+                    emb = model.encode(sent)
+                    embeddings.append(emb)
+                embeddings = np.stack(embeddings)
+                np.save(os.path.join(folder, "{}_sbert-{}.npy".format(split, args.sbert_model)), embeddings)
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/sort_mapping.py b/LM-BFF/tools/sort_mapping.py
new file mode 100755
index 0000000..a3ea209
--- /dev/null
+++ b/LM-BFF/tools/sort_mapping.py
@@ -0,0 +1,173 @@
+import argparse
+import json
+import numpy as np
+import torch
+from torch import device
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
+    parser.add_argument('--mapping_dir', type=str, help='Mapping directory')
+
+    # These options should be kept as their default values
+    parser.add_argument("--k", type=int, default=16)
+    parser.add_argument("--log", type=str, default="log", help="Log path.")
+    parser.add_argument("--key", type=str, default='', help="Validation metric name")
+    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
+    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
+
+    args = parser.parse_args()
+
+    condition = eval(args.condition)
+
+    if len(args.key) == 0:
+        if condition['task_name'] == 'cola':
+            args.key = 'cola_dev_eval_mcc'
+            args.test_key = 'cola_test_eval_mcc'
+            print_name = 'CoLA'
+        elif condition['task_name'] == 'mrpc/acc':
+            args.key = 'mrpc_dev_eval_acc'
+            args.test_key = 'mrpc_test_eval_acc'
+            args.test_key2 = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'mrpc/f1':
+            args.key = 'mrpc_dev_eval_f1'
+            args.test_key2 = 'mrpc_test_eval_acc'
+            args.test_key = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'qqp/acc':
+            args.key = 'qqp_dev_eval_acc'
+            args.test_key = 'qqp_test_eval_acc'
+            args.test_key2 = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'qqp/f1':
+            args.key = 'qqp_dev_eval_f1'
+            args.test_key2 = 'qqp_test_eval_acc'
+            args.test_key = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'sts-b/pearson':
+            args.key = 'sts-b_dev_eval_pearson'
+            args.test_key = 'sts-b_test_eval_pearson'
+            args.test_key2 = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'sts-b/spearmanr':
+            args.key = 'sts-b_dev_eval_spearmanr'
+            args.test_key2 = 'sts-b_test_eval_pearson'
+            args.test_key = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'qnli':
+            args.key = 'qnli_dev_eval_acc'
+            args.test_key = 'qnli_test_eval_acc'
+            print_name = 'QNLI'
+        elif condition['task_name'] == 'sst-2':
+            args.key = 'sst-2_dev_eval_acc'
+            args.test_key = 'sst-2_test_eval_acc'
+            print_name = 'SST-2'
+        elif condition['task_name'] == 'snli':
+            args.key = 'snli_dev_eval_acc'
+            args.test_key = 'snli_test_eval_acc'
+            print_name = 'SNLI'
+        elif condition['task_name'] == 'mnli':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli_test_eval_mnli/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'mnli-mm':
+            condition['task_name'] = 'mnli'
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'rte':
+            args.key = 'rte_dev_eval_acc'
+            args.test_key = 'rte_test_eval_acc'
+            print_name = 'RTE'
+        elif condition['task_name'] == 'ag_news':
+            args.key = 'ag_news_dev_eval_acc'
+            args.test_key = 'ag_news_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yahoo_answers':
+            args.key = 'yahoo_answers_dev_eval_acc'
+            args.test_key = 'yahoo_answers_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yelp_review_full':
+            args.key = 'yelp_review_full_dev_eval_acc'
+            args.test_key = 'yelp_review_full_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mr':
+            args.key = 'mr_dev_eval_acc'
+            args.test_key = 'mr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'sst-5':
+            args.key = 'sst-5_dev_eval_acc'
+            args.test_key = 'sst-5_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'subj':
+            args.key = 'subj_dev_eval_acc'
+            args.test_key = 'subj_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'trec':
+            args.key = 'trec_dev_eval_acc'
+            args.test_key = 'trec_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'cr':
+            args.key = 'cr_dev_eval_acc'
+            args.test_key = 'cr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mpqa':
+            args.key = 'mpqa_dev_eval_acc'
+            args.test_key = 'mpqa_test_eval_acc'
+            print_name = condition['task_name']
+        else:
+            raise NotImplementedError
+
+    with open(args.log) as f:
+        result_list = []
+        for line in f:
+            result_list.append(eval(line))
+    
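+    # Each log line is expected to be a Python dict literal written by the training
+    # script, e.g. (illustrative): {'seed': 42, 'mapping_id': 3, 'cola_dev_eval_mcc': 0.21, ...}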
+    seed_result = {}
+    seed_result_mapping_id = {} # avoid duplication
+
+    for item in result_list:
+        ok = True
+        for cond in condition:
+            if cond not in item or item[cond] != condition[cond]:
+                ok = False
+                break
+        
+        if ok:
+            seed = item['seed']
+            if seed not in seed_result:
+                seed_result[seed] = [item]
+                seed_result_mapping_id[seed] = {item['mapping_id']: 1}
+            else:
+                if item['mapping_id'] not in seed_result_mapping_id[seed]:
+                    seed_result[seed].append(item)
+                    seed_result_mapping_id[seed][item['mapping_id']] = 1
+
+    for seed in seed_result:
+        print("Seed %d has %d results" % (seed, len(seed_result[seed])))
+
+        # Load all mappings
+        with open(os.path.join(args.mapping_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
+            mappings = []
+            for line in f:
+                mappings.append(line.strip())
+
+        # Write sorted mappings
+        fsort = open(os.path.join(args.mapping_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
+        fscore = open(os.path.join(args.mapping_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
+
+        seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
+        for item in seed_result[seed]:
+            fsort.write(mappings[item['mapping_id']] + '\n')
+            fscore.write("%.5f %s\n" % (item[args.key], mappings[item['mapping_id']]))
+        fsort.close()
+        fscore.close()
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/sort_prompt.py b/LM-BFF/tools/sort_prompt.py
new file mode 100755
index 0000000..ef6420d
--- /dev/null
+++ b/LM-BFF/tools/sort_prompt.py
@@ -0,0 +1,173 @@
+import argparse
+import json
+import numpy as np
+import torch
+from torch import device
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
+    parser.add_argument('--prompt_dir', type=str, help='Prompt directory')
+
+    # These options should be kept as their default values
+    parser.add_argument("--k", type=int, default=16)
+    parser.add_argument("--log", type=str, default="log", help="Log path.")
+    parser.add_argument("--key", type=str, default='', help="Validation metric name")
+    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
+    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
+
+    args = parser.parse_args()
+
+    condition = eval(args.condition)
+
+    if len(args.key) == 0:
+        if condition['task_name'] == 'cola':
+            args.key = 'cola_dev_eval_mcc'
+            args.test_key = 'cola_test_eval_mcc'
+            print_name = 'CoLA'
+        elif condition['task_name'] == 'mrpc/acc':
+            args.key = 'mrpc_dev_eval_acc'
+            args.test_key = 'mrpc_test_eval_acc'
+            args.test_key2 = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'mrpc/f1':
+            args.key = 'mrpc_dev_eval_f1'
+            args.test_key2 = 'mrpc_test_eval_acc'
+            args.test_key = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'qqp/acc':
+            args.key = 'qqp_dev_eval_acc'
+            args.test_key = 'qqp_test_eval_acc'
+            args.test_key2 = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'qqp/f1':
+            args.key = 'qqp_dev_eval_f1'
+            args.test_key2 = 'qqp_test_eval_acc'
+            args.test_key = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'sts-b/pearson':
+            args.key = 'sts-b_dev_eval_pearson'
+            args.test_key = 'sts-b_test_eval_pearson'
+            args.test_key2 = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'sts-b/spearmanr':
+            args.key = 'sts-b_dev_eval_spearmanr'
+            args.test_key2 = 'sts-b_test_eval_pearson'
+            args.test_key = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'qnli':
+            args.key = 'qnli_dev_eval_acc'
+            args.test_key = 'qnli_test_eval_acc'
+            print_name = 'QNLI'
+        elif condition['task_name'] == 'sst-2':
+            args.key = 'sst-2_dev_eval_acc'
+            args.test_key = 'sst-2_test_eval_acc'
+            print_name = 'SST-2'
+        elif condition['task_name'] == 'snli':
+            args.key = 'snli_dev_eval_acc'
+            args.test_key = 'snli_test_eval_acc'
+            print_name = 'SNLI'
+        elif condition['task_name'] == 'mnli':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli_test_eval_mnli/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'mnli-mm':
+            condition['task_name'] = 'mnli'
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'rte':
+            args.key = 'rte_dev_eval_acc'
+            args.test_key = 'rte_test_eval_acc'
+            print_name = 'RTE'
+        elif condition['task_name'] == 'ag_news':
+            args.key = 'ag_news_dev_eval_acc'
+            args.test_key = 'ag_news_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yahoo_answers':
+            args.key = 'yahoo_answers_dev_eval_acc'
+            args.test_key = 'yahoo_answers_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yelp_review_full':
+            args.key = 'yelp_review_full_dev_eval_acc'
+            args.test_key = 'yelp_review_full_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mr':
+            args.key = 'mr_dev_eval_acc'
+            args.test_key = 'mr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'sst-5':
+            args.key = 'sst-5_dev_eval_acc'
+            args.test_key = 'sst-5_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'subj':
+            args.key = 'subj_dev_eval_acc'
+            args.test_key = 'subj_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'trec':
+            args.key = 'trec_dev_eval_acc'
+            args.test_key = 'trec_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'cr':
+            args.key = 'cr_dev_eval_acc'
+            args.test_key = 'cr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mpqa':
+            args.key = 'mpqa_dev_eval_acc'
+            args.test_key = 'mpqa_test_eval_acc'
+            print_name = condition['task_name']
+        else:
+            raise NotImplementedError
+
+    with open(args.log) as f:
+        result_list = []
+        for line in f:
+            result_list.append(eval(line))
+    
+    seed_result = {}
+    seed_result_prompt_id = {} # avoid duplication
+
+    for item in result_list:
+        ok = True
+        for cond in condition:
+            if cond not in item or item[cond] != condition[cond]:
+                ok = False
+                break
+        
+        if ok:
+            seed = item['seed']
+            if seed not in seed_result:
+                seed_result[seed] = [item]
+                seed_result_prompt_id[seed] = {item['prompt_id']: 1}
+            else:
+                if item['prompt_id'] not in seed_result_prompt_id[seed]:
+                    seed_result[seed].append(item)
+                    seed_result_prompt_id[seed][item['prompt_id']] = 1
+
+    for seed in seed_result:
+        print("Seed %d has %d results" % (seed, len(seed_result[seed])))
+
+        # Load all prompts
+        with open(os.path.join(args.prompt_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
+            prompts = []
+            for line in f:
+                prompts.append(line.strip())
+
+        # Write sorted prompts
+        fsort = open(os.path.join(args.prompt_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
+        fscore = open(os.path.join(args.prompt_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
+
+        seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
+        for item in seed_result[seed]:
+            fsort.write(prompts[item['prompt_id']] + '\n')
+            fscore.write("%.5f %s\n" % (item[args.key], prompts[item['prompt_id']]))
+        fsort.close()
+        fscore.close()
+
+if __name__ == '__main__':
+    main()
diff --git a/LM-BFF/tools/sort_template.py b/LM-BFF/tools/sort_template.py
new file mode 100755
index 0000000..64239b7
--- /dev/null
+++ b/LM-BFF/tools/sort_template.py
@@ -0,0 +1,173 @@
+import argparse
+import json
+import numpy as np
+import torch
+from torch import device
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--condition", type=str, help="A dictionary contains conditions that the experiment results need to fulfill (e.g., tag, task_name, few_shot_type)")
+    parser.add_argument('--template_dir', type=str, help='Template directory')
+
+    # These options should be kept as their default values
+    parser.add_argument("--k", type=int, default=16)
+    parser.add_argument("--log", type=str, default="log", help="Log path.")
+    parser.add_argument("--key", type=str, default='', help="Validation metric name")
+    parser.add_argument("--test_key", type=str, default="", help="Test metric name")
+    parser.add_argument("--test_key2", type=str, default="", help="Second test metric name")
+
+    args = parser.parse_args()
+
+    condition = eval(args.condition)
+
+    if len(args.key) == 0:
+        if condition['task_name'] == 'cola':
+            args.key = 'cola_dev_eval_mcc'
+            args.test_key = 'cola_test_eval_mcc'
+            print_name = 'CoLA'
+        elif condition['task_name'] == 'mrpc/acc':
+            args.key = 'mrpc_dev_eval_acc'
+            args.test_key = 'mrpc_test_eval_acc'
+            args.test_key2 = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'mrpc/f1':
+            args.key = 'mrpc_dev_eval_f1'
+            args.test_key2 = 'mrpc_test_eval_acc'
+            args.test_key = 'mrpc_test_eval_f1'
+            condition['task_name'] = 'mrpc'
+            print_name = 'MRPC'
+        elif condition['task_name'] == 'qqp/acc':
+            args.key = 'qqp_dev_eval_acc'
+            args.test_key = 'qqp_test_eval_acc'
+            args.test_key2 = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'qqp/f1':
+            args.key = 'qqp_dev_eval_f1'
+            args.test_key2 = 'qqp_test_eval_acc'
+            args.test_key = 'qqp_test_eval_f1'
+            condition['task_name'] = 'qqp'
+            print_name = 'QQP'
+        elif condition['task_name'] == 'sts-b/pearson':
+            args.key = 'sts-b_dev_eval_pearson'
+            args.test_key = 'sts-b_test_eval_pearson'
+            args.test_key2 = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'sts-b/spearmanr':
+            args.key = 'sts-b_dev_eval_spearmanr'
+            args.test_key2 = 'sts-b_test_eval_pearson'
+            args.test_key = 'sts-b_test_eval_spearmanr'
+            condition['task_name'] = 'sts-b'
+            print_name = 'STS-B'
+        elif condition['task_name'] == 'qnli':
+            args.key = 'qnli_dev_eval_acc'
+            args.test_key = 'qnli_test_eval_acc'
+            print_name = 'QNLI'
+        elif condition['task_name'] == 'sst-2':
+            args.key = 'sst-2_dev_eval_acc'
+            args.test_key = 'sst-2_test_eval_acc'
+            print_name = 'SST-2'
+        elif condition['task_name'] == 'snli':
+            args.key = 'snli_dev_eval_acc'
+            args.test_key = 'snli_test_eval_acc'
+            print_name = 'SNLI'
+        elif condition['task_name'] == 'mnli':
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli_test_eval_mnli/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'mnli-mm':
+            condition['task_name'] = 'mnli'
+            args.key = 'mnli_dev_eval_mnli/acc'
+            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
+            print_name = 'MNLI'
+        elif condition['task_name'] == 'rte':
+            args.key = 'rte_dev_eval_acc'
+            args.test_key = 'rte_test_eval_acc'
+            print_name = 'RTE'
+        elif condition['task_name'] == 'ag_news':
+            args.key = 'ag_news_dev_eval_acc'
+            args.test_key = 'ag_news_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yahoo_answers':
+            args.key = 'yahoo_answers_dev_eval_acc'
+            args.test_key = 'yahoo_answers_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'yelp_review_full':
+            args.key = 'yelp_review_full_dev_eval_acc'
+            args.test_key = 'yelp_review_full_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mr':
+            args.key = 'mr_dev_eval_acc'
+            args.test_key = 'mr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'sst-5':
+            args.key = 'sst-5_dev_eval_acc'
+            args.test_key = 'sst-5_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'subj':
+            args.key = 'subj_dev_eval_acc'
+            args.test_key = 'subj_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'trec':
+            args.key = 'trec_dev_eval_acc'
+            args.test_key = 'trec_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'cr':
+            args.key = 'cr_dev_eval_acc'
+            args.test_key = 'cr_test_eval_acc'
+            print_name = condition['task_name']
+        elif condition['task_name'] == 'mpqa':
+            args.key = 'mpqa_dev_eval_acc'
+            args.test_key = 'mpqa_test_eval_acc'
+            print_name = condition['task_name']
+        else:
+            raise NotImplementedError
+
+    with open(args.log) as f:
+        result_list = []
+        for line in f:
+            result_list.append(eval(line))
+    
+    seed_result = {}
+    seed_result_template_id = {} # avoid duplication
+
+    for item in result_list:
+        ok = True
+        for cond in condition:
+            if cond not in item or item[cond] != condition[cond]:
+                ok = False
+                break
+        
+        if ok:
+            seed = item['seed']
+            if seed not in seed_result:
+                seed_result[seed] = [item]
+                seed_result_template_id[seed] = {item['template_id']: 1}
+            else:
+                if item['template_id'] not in seed_result_template_id[seed]:
+                    seed_result[seed].append(item)
+                    seed_result_template_id[seed][item['template_id']] = 1
+
+    for seed in seed_result:
+        print("Seed %d has %d results" % (seed, len(seed_result[seed])))
+
+        # Load all templates
+        with open(os.path.join(args.template_dir, print_name, "{}-{}.txt".format(args.k, seed))) as f:
+            templates = []
+            for line in f:
+                templates.append(line.strip())
+
+        # Write sorted templates
+        fsort = open(os.path.join(args.template_dir, print_name, "{}-{}.sort.txt".format(args.k, seed)), 'w')
+        fscore = open(os.path.join(args.template_dir, print_name, "{}-{}.score.txt".format(args.k, seed)), 'w')
+
+        seed_result[seed].sort(key=lambda x: x[args.key], reverse=True)
+        for item in seed_result[seed]:
+            fsort.write(templates[item['template_id']] + '\n')
+            fscore.write("%.5f %s\n" % (item[args.key], templates[item['template_id']]))
+        fsort.close()
+        fscore.close()
+
+if __name__ == '__main__':
+    main()
diff --git "a/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb" "b/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb"
new file mode 100755
index 0000000..39dc502
--- /dev/null
+++ "b/LM-BFF/\346\234\252\345\221\275\345\220\215.ipynb"
@@ -0,0 +1,432 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f857d1db-7d6d-4c76-af4f-d508a4027192",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "head: error reading 'data/x-stance/': Is a directory\n"
+     ]
+    }
+   ],
+   "source": [
+    "!head -2 data/x-stance/questions.en.json\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "id": "b6421d22-ca47-4ecf-b667-8f19f4cb035a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "questions.de.jsonl  questions.en.jsonl\tquestions.fr.jsonl  questions.it.jsonl\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls  data/x-stance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "1d288f21-0b89-4974-bded-d5ca9ff24f82",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import pandas as pd\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "id": "67b7a07f-26d9-4e6a-8473-2614f6b34887",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "json_list = []\n",
+    "for line in open('data/x-stance/questions.it.jsonl'):\n",
+    "  json_list.append(json.loads(line))\n",
+    "data = pd.DataFrame.from_dict(json_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "id": "d505a983-7d08-4fa0-8281-99027e8edd4c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "k = 100\n",
+    "seed_list = [100,13,21,42,87]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "id": "824ff804-347d-4913-a3f0-20bd38cb4159",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "100-100  100-13  100-21  100-42  100-87  16-100  16-13\t16-21  16-42  16-87\n"
+     ]
+    }
+   ],
+   "source": [
+    "!ls /home/mist/projects/LM-BFF/data/k-shot/x-stance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "5ddb0ab8-0dfb-44bc-99c3-69829c86d8a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "output_dir = '/home/mist/projects/LM-BFF/data/k-shot/'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "8b8b604e-5767-4bdb-b620-2d4c464c595c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "task = 'x-stance'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "59dd94e5-b067-4644-832f-bc0e45d3486f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for seed in seed_list:\n",
+    "    task_dir = os.path.join(output_dir, task)\n",
+    "    setting_dir = os.path.join(task_dir, f\"{k}-{seed}\")\n",
+    "    os.makedirs(setting_dir, exist_ok=True)\n",
+    "    data.sample(random_state=seed,n=k)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "id": "531acdd0-cbf6-430a-91e2-66923236908c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>id</th>\n",
+       "      <th>text</th>\n",
+       "      <th>topic</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>2</td>\n",
+       "      <td>Finden Sie es grundsätzlich richtig, dass der ...</td>\n",
+       "      <td>Welfare</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>4</td>\n",
+       "      <td>Soll zusätzlich zur bestehenden Mutterschaftsv...</td>\n",
+       "      <td>Welfare</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>6</td>\n",
+       "      <td>Die Invalidenversicherung spricht bei nicht ob...</td>\n",
+       "      <td>Welfare</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>7</td>\n",
+       "      <td>Würden Sie eine nationale Spitalplanung befürw...</td>\n",
+       "      <td>Healthcare</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>9</td>\n",
+       "      <td>Finden Sie es richtig, dass einzelne ärztliche...</td>\n",
+       "      <td>Healthcare</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>189</th>\n",
+       "      <td>3464</td>\n",
+       "      <td>Würden Sie eine Ausdehnung der rechtlichen Mög...</td>\n",
+       "      <td>Security</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>190</th>\n",
+       "      <td>3468</td>\n",
+       "      <td>Soll die Schweiz Verhandlungen über den Beitri...</td>\n",
+       "      <td>Foreign Policy</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>191</th>\n",
+       "      <td>3469</td>\n",
+       "      <td>Soll der Bundesrat ein Freihandelsabkommen mit...</td>\n",
+       "      <td>Foreign Policy</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>192</th>\n",
+       "      <td>3470</td>\n",
+       "      <td>Eine Initiative fordert, dass die Haftungsrege...</td>\n",
+       "      <td>Foreign Policy</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>193</th>\n",
+       "      <td>3471</td>\n",
+       "      <td>Befürworten Sie die Kandidatur der Schweiz für...</td>\n",
+       "      <td>Foreign Policy</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>194 rows × 3 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "       id                                               text           topic\n",
+       "0       2  Finden Sie es grundsätzlich richtig, dass der ...         Welfare\n",
+       "1       4  Soll zusätzlich zur bestehenden Mutterschaftsv...         Welfare\n",
+       "2       6  Die Invalidenversicherung spricht bei nicht ob...         Welfare\n",
+       "3       7  Würden Sie eine nationale Spitalplanung befürw...      Healthcare\n",
+       "4       9  Finden Sie es richtig, dass einzelne ärztliche...      Healthcare\n",
+       "..    ...                                                ...             ...\n",
+       "189  3464  Würden Sie eine Ausdehnung der rechtlichen Mög...        Security\n",
+       "190  3468  Soll die Schweiz Verhandlungen über den Beitri...  Foreign Policy\n",
+       "191  3469  Soll der Bundesrat ein Freihandelsabkommen mit...  Foreign Policy\n",
+       "192  3470  Eine Initiative fordert, dass die Haftungsrege...  Foreign Policy\n",
+       "193  3471  Befürworten Sie die Kandidatur der Schweiz für...  Foreign Policy\n",
+       "\n",
+       "[194 rows x 3 columns]"
+      ]
+     },
+     "execution_count": 45,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "id": "bcf450cf-cb23-462f-b070-3529c1dfa86d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Infrastructure & Environment    31\n",
+       "Economy                         23\n",
+       "Security                        20\n",
+       "Immigration                     19\n",
+       "Society                         17\n",
+       "Education                       16\n",
+       "Foreign Policy                  16\n",
+       "Finances                        15\n",
+       "Welfare                         15\n",
+       "Healthcare                      11\n",
+       "Political System                 9\n",
+       "Digitisation                     2\n",
+       "Name: topic, dtype: int64"
+      ]
+     },
+     "execution_count": 49,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data['topic'].value_counts()"
+   ]
+  },
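+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7e2d4b1a-9f3c-4a8e-b6d0-2c5e8f1a3b4d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch (not used by the cells above): a stratified sample that\n",
+    "# keeps the skewed topic distribution shown by value_counts() roughly\n",
+    "# intact; the random_state here is an arbitrary choice.\n",
+    "stratified = data.groupby('topic', group_keys=False).apply(\n",
+    "    lambda g: g.sample(frac=k / len(data), random_state=42))"
+   ]
+  },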
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "id": "8d7aa681-bd14-4c48-881a-9130a9b88edc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "label_encoding = {\n",
+    "'Infrastructure & Environment': 0,\n",
+    "'Economy': 1                      ,  \n",
+    "'Security':    2                   , \n",
+    "'Immigration':   3                  ,\n",
+    "'Society': 4                        ,\n",
+    "'Education':   5                    ,\n",
+    "'Foreign Policy': 6                 ,\n",
+    "'Finances': 7                       ,\n",
+    "'Welfare':8                         ,\n",
+    "'Healthcare':9                      ,\n",
+    "'Political System':  10              , \n",
+    "'Digitisation':11 }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "3c2fad2e-77cf-4eff-9dda-5773f59f4402",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data['label'] = data['topic'].apply(lambda x:label_encoding[x])"
+   ]
+  },
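+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4c8e0d2-6a1f-4c3b-8e7d-9f0a1b2c3d4e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check (added here as a sketch): every topic in the data should\n",
+    "# have an entry in label_encoding, otherwise map() above yields NaN.\n",
+    "assert set(data['topic'].unique()) <= set(label_encoding)"
+   ]
+  },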
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "id": "6e4c0a41-31a8-4620-bfed-c776a21e0c1c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(194, 4)"
+      ]
+     },
+     "execution_count": 40,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4fc29374-8495-43b7-829f-90c11bf8a974",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'pd' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-4-a66617debac6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'data/x-stance/questions.en.jsonl'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m   \u001b[0mjson_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjson_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m'+1'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'stars'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m>=\u001b[0m\u001b[0;36m3\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'-1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfrac\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'pd' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "\n",
+    "data['label']=data.apply(lambda x:'+1' if x['stars']>=3 else '-1',axis=1)\n",
+    "data = data.sample(frac=1)\n",
+    "data['text']=data['text'].apply(lambda x:' '.join(x.replace('\\n','.').split(' ')[:500]))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "551574a5-f461-4a6f-8c0a-97cf92747b74",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'data' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-9-c5d84736ba45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'data' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "data"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
--
libgit2 0.26.0