Commit 59302e20 by 20210828028

v1

parent 8e48a3b2
LM-BFF @ d3f96076
Subproject commit d3f960766ece2006bfc83b1567b149f146eb2c05
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Pattern-Exploiting Training (PET)
This repository contains the code for [Exploiting Cloze Questions for Few-Shot Text Classification and Natural Language Inference](https://arxiv.org/abs/2001.07676) and [It's Not Just Size That Matters: Small Language Models Are Also Few-Shot Learners](https://arxiv.org/abs/2009.07118). The papers introduce pattern-exploiting training (PET), a semi-supervised training procedure that reformulates input examples as cloze-style phrases. In low-resource settings, PET and iPET significantly outperform regular supervised training, various semi-supervised baselines and even GPT-3 despite requiring 99.9% fewer parameters. The iterative variant of PET (iPET) trains multiple generations of models and can even be used without any training data.
<table>
<tr>
<th>#Examples</th>
<th>Training Mode</th>
<th>Yelp (Full)</th>
<th>AG's News</th>
<th>Yahoo Questions</th>
<th>MNLI</th>
</tr>
<tr>
<td rowspan="2" align="center"><b>0</b></td>
<td>unsupervised</td>
<td align="right">33.8</td>
<td align="right">69.5</td>
<td align="right">44.0</td>
<td align="right">39.1</td>
</tr>
<tr>
<td>iPET</td>
<td align="right"><b>56.7</b></td>
<td align="right"><b>87.5</b></td>
<td align="right"><b>70.7</b></td>
<td align="right"><b>53.6</b></td>
</tr>
<tr>
<td rowspan="3" align="center"><b>100</b></td>
<td>supervised</td>
<td align="right">53.0</td>
<td align="right">86.0</td>
<td align="right">62.9</td>
<td align="right">47.9</td>
</tr>
<tr>
<td>PET</td>
<td align="right">61.9</td>
<td align="right">88.3</td>
<td align="right">69.2</td>
<td align="right">74.7</td>
</tr>
<tr>
<td>iPET</td>
<td align="right"><b>62.9</b></td>
<td align="right"><b>89.6</b></td>
<td align="right"><b>71.2</b></td>
<td align="right"><b>78.4</b></td>
</tr>
</table>
<sup>*Note*: To exactly reproduce the above results, make sure to use v1.1.0 (`--branch v1.1.0`).</sup>
## 📑 Contents
**[🔧 Setup](#-setup)**
**[💬 CLI Usage](#-cli-usage)**
**[💻 API Usage](#-api-usage)**
**[🐶 Train your own PET](#-train-your-own-pet)**
**[📕 Citation](#-citation)**
## 🔧 Setup
All requirements for PET can be found in `requirements.txt`. You can install all required packages with `pip install -r requirements.txt`.
## 💬 CLI Usage
The command line interface `cli.py` in this repository currently supports three different training modes (PET, iPET, supervised training), two additional evaluation methods (unsupervised and priming) and 13 different tasks. For Yelp Reviews, AG's News, Yahoo Questions, MNLI and X-Stance, see [the original paper](https://arxiv.org/abs/2001.07676) for further details. For the 8 SuperGLUE tasks, see [this paper](https://arxiv.org/abs/2009.07118).
### PET Training and Evaluation
To train and evaluate a PET model for one of the supported tasks, simply run the following command:
```sh
python3 cli.py \
    --method pet \
    --pattern_ids $PATTERN_IDS \
    --data_dir $DATA_DIR \
    --model_type $MODEL_TYPE \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --task_name $TASK \
    --output_dir $OUTPUT_DIR \
    --do_train \
    --do_eval
```
where
- `$PATTERN_IDS` specifies the PVPs to use. For example, if you want to use *all* patterns, specify `PATTERN_IDS 0 1 2 3 4` for AG's News and Yahoo Questions or `PATTERN_IDS 0 1 2 3` for Yelp Reviews and MNLI.
- `$DATA_DIR` is the directory containing the train and test files (check `tasks.py` to see how these files should be named and formatted for each task).
- `$MODEL_TYPE` is the name of the model being used, e.g. `albert`, `bert` or `roberta`.
- `$MODEL_NAME_OR_PATH` is the name of a pretrained model (e.g., `roberta-large` or `albert-xxlarge-v2`) or the path to a pretrained model.
- `$TASK` is the name of the task to train and evaluate on.
- `$OUTPUT_DIR` is the name of the directory in which the trained model and evaluation results are saved.
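For example, a complete invocation for PET on MNLI with RoBERTa could look as follows (the data and output paths are placeholders to adapt to your setup):

```sh
python3 cli.py \
    --method pet \
    --pattern_ids 0 1 2 3 \
    --data_dir data/mnli \
    --model_type roberta \
    --model_name_or_path roberta-large \
    --task_name mnli \
    --output_dir output/mnli-pet \
    --do_train \
    --do_eval
```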
You can additionally specify various training parameters for both the ensemble of PET models corresponding to individual PVPs (prefix `--pet_`) and for the final sequence classification model (prefix `--sc_`). For example, the default parameters used for our SuperGLUE evaluation are:
```sh
--pet_per_gpu_eval_batch_size 8 \
--pet_per_gpu_train_batch_size 2 \
--pet_gradient_accumulation_steps 8 \
--pet_max_steps 250 \
--pet_max_seq_length 256 \
--pet_repetitions 3 \
--sc_per_gpu_train_batch_size 2 \
--sc_per_gpu_unlabeled_batch_size 2 \
--sc_gradient_accumulation_steps 8 \
--sc_max_steps 5000 \
--sc_max_seq_length 256 \
--sc_repetitions 1
```
For each pattern `$P` and repetition `$I`, running the above command creates a directory `$OUTPUT_DIR/p$P-i$I` that contains the following files:
- `pytorch_model.bin`: the finetuned model, possibly along with some model-specific files (e.g., `spiece.model`, `special_tokens_map.json`)
- `wrapper_config.json`: the configuration of the model being used
- `train_config.json`: the configuration used for training
- `eval_config.json`: the configuration used for evaluation
- `logits.txt`: the model's predictions on the unlabeled data
- `eval_logits.txt`: the model's predictions on the evaluation data
- `results.json`: a json file containing results such as the model's final accuracy
- `predictions.jsonl`: a prediction file for the evaluation set in the SuperGLUE format
The final (distilled) model for each repetition `$I` can be found in `$OUTPUT_DIR/final/p0-i$I`, which contains the same files as described above.
🚨 If your GPU runs out of memory during training, you can try decreasing both the `pet_per_gpu_train_batch_size` and the `sc_per_gpu_unlabeled_batch_size` while increasing both `pet_gradient_accumulation_steps` and `sc_gradient_accumulation_steps`.
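For example, relative to the SuperGLUE defaults shown above, the following combination keeps the effective batch size unchanged while roughly halving per-step memory (a sketch; the right values depend on your GPU):

```sh
--pet_per_gpu_train_batch_size 1 \
--pet_gradient_accumulation_steps 16 \
--sc_per_gpu_unlabeled_batch_size 1 \
--sc_gradient_accumulation_steps 16
```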
### iPET Training and Evaluation
To train and evaluate an iPET model for one of the supported tasks, simply run the same command as above, but replace `--method pet` with `--method ipet`. There are various additional iPET parameters that you can modify; all of them are prefixed with `--ipet_`.
For each generation `$G`, pattern `$P` and iteration `$I`, this creates a directory `$OUTPUT_DIR/g$G/p$P-i$I` that is structured as for regular PET. The final (distilled) model can again be found in `$OUTPUT_DIR/final/p0-i$I`.
🚨 If you use iPET with zero training examples, you need to specify how many examples for each label should be chosen in the first generation and you need to change the reduction strategy to mean: `--ipet_n_most_likely 100 --reduction mean`.
### Supervised Training and Evaluation
To train and evaluate a regular sequence classifier in a supervised fashion, simply run the same command as above, but replace `--method pet` with `--method sequence_classifier`. There are various additional parameters for the sequence classifier that you can modify; all of them are prefixed with `--sc_`.
### Unsupervised Evaluation
To evaluate a pretrained language model with the default PET patterns and verbalizers, but without fine-tuning, remove the argument `--do_train` and add `--no_distillation` so that no final distillation is performed.
### Priming
If you want to use priming, remove the argument `--do_train` and add the arguments `--priming --no_distillation` so that all training examples are used for priming and no final distillation is performed.
🚨 Remember that you may need to increase the maximum sequence length to a much larger value, e.g. `--pet_max_seq_length 5000`. This only works with language models that support such long sequences, e.g. XLNet. To use XLNet, you can specify `--model_type xlnet --model_name_or_path xlnet-large-cased --wrapper_type plm`.
## 💻 API Usage
Instead of using the command line interface, you can also directly use the PET API, most of which is defined in `pet.modeling`. By including `import pet`, you can access methods such as `train_pet`, `train_ipet` and `train_classifier`. Check out their documentation for more information.
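As a minimal sketch of what such a call can look like (it mirrors the argument structure of `cli.py`, shown later in this commit; the task, model, paths and example counts are illustrative, and the remaining config fields are assumed to keep their defaults):

```python
import torch
import pet
from pet.tasks import PROCESSORS, load_examples, TRAIN_SET, DEV_SET, UNLABELED_SET
from pet.wrapper import WrapperConfig, SEQUENCE_CLASSIFIER_WRAPPER

device = "cuda" if torch.cuda.is_available() else "cpu"
label_list = PROCESSORS["mnli"]().get_labels()

# configs for the ensemble of PET models (one per pattern and repetition) ...
pet_model_cfg = WrapperConfig(model_type="roberta", model_name_or_path="roberta-large",
                              wrapper_type="mlm", task_name="mnli",
                              label_list=label_list, max_seq_length=256)
pet_train_cfg = pet.TrainConfig(device=device, max_steps=250)
pet_eval_cfg = pet.EvalConfig(device=device)

# ... and for the final (distilled) sequence classifier
sc_model_cfg = WrapperConfig(model_type="roberta", model_name_or_path="roberta-large",
                             wrapper_type=SEQUENCE_CLASSIFIER_WRAPPER, task_name="mnli",
                             label_list=label_list, max_seq_length=256)
sc_train_cfg = pet.TrainConfig(device=device, max_steps=5000, use_logits=True)
sc_eval_cfg = pet.EvalConfig(device=device)

train_data = load_examples("mnli", "data/mnli", TRAIN_SET, num_examples=100)
eval_data = load_examples("mnli", "data/mnli", DEV_SET, num_examples=-1)
unlabeled_data = load_examples("mnli", "data/mnli", UNLABELED_SET, num_examples=-1)

pet.train_pet(pet_model_cfg, pet_train_cfg, pet_eval_cfg,
              sc_model_cfg, sc_train_cfg, sc_eval_cfg,
              pattern_ids=[0, 1, 2, 3], output_dir="output/mnli-pet",
              train_data=train_data, unlabeled_data=unlabeled_data,
              eval_data=eval_data, do_train=True, do_eval=True)
```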
## 🐶 Train your own PET
To use PET for custom tasks, you need to define two things:
- a **DataProcessor**, responsible for loading training and test data. See `examples/custom_task_processor.py` for an example.
- a **PVP**, responsible for applying patterns to inputs and mapping labels to natural language verbalizations. See `examples/custom_task_pvp.py` for an example.
After having implemented the DataProcessor and the PVP, you can train a PET model using the command line as [described above](#pet-training-and-evaluation). Below, you can find additional information on how to define the two components of a PVP, *verbalizers* and *patterns*.
### Verbalizers
Verbalizers are used to map task labels to words in natural language. For example, in a binary sentiment classification task, you could map the positive label (`+1`) to the word `good` and the negative label (`-1`) to the word `bad`. Verbalizers are realized through a PVP's `verbalize()` method. The simplest way of defining a verbalizer is to use a dictionary:
```python
VERBALIZER = {"+1": ["good"], "-1": ["bad"]}
def verbalize(self, label) -> List[str]:
return self.VERBALIZER[label]
```
Importantly, in PET's current version, verbalizers are by default restricted to **single tokens** in the underlying LM's vocabulary (for using more than one token, [see below](#pet-with-multiple-masks)). Given a language model's tokenizer, you can easily check whether a word corresponds to a single token by verifying that `len(tokenizer.tokenize(word)) == 1`.
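For example, with a tokenizer from HuggingFace's `transformers` (a quick sketch; the model name is only an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
for word in ["good", "bad", "wonderful"]:
    # verbalizers must correspond to a single token in the LM's vocabulary
    print(word, len(tokenizer.tokenize(word)) == 1)
```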
You can also define multiple verbalizations for a single label. For example, if you are unsure which words best represent the labels in a binary sentiment classification task, you could define your verbalizer as follows:
```python
VERBALIZER = {"+1": ["great", "good", "wonderful", "perfect"], "-1": ["bad", "terrible", "horrible"]}
```
### Patterns
Patterns are used to make the language model understand a given task; they must contain exactly one `<MASK>` token which is to be filled using the verbalizer. For binary sentiment classification based on a review's summary (`<A>`) and body (`<B>`), a suitable pattern may be `<A>. <B>. Overall, it was <MASK>.` Patterns are realized through a PVP's `get_parts()` method, which returns a pair of text sequences (where each sequence is represented by a list of strings):
```python
def get_parts(self, example: InputExample):
    return [example.text_a, '.', example.text_b, '.'], ['Overall, it was ', self.mask]
```
If you do not want to use a pair of sequences, you can simply leave the second sequence empty:
```python
def get_parts(self, example: InputExample):
    return [example.text_a, '.', example.text_b, '. Overall, it was ', self.mask], []
```
If you want to define several patterns, simply use the `PVP`'s `pattern_id` attribute:
```python
def get_parts(self, example: InputExample):
    if self.pattern_id == 1:
        return [example.text_a, '.', example.text_b, '.'], ['Overall, it was ', self.mask]
    elif self.pattern_id == 2:
        return ['It was just ', self.mask, '!', example.text_a, '.', example.text_b, '.'], []
```
When training the model using the command line, specify all patterns to be used (e.g., `--pattern_ids 1 2`).
Importantly, if a sequence is longer than the specified maximum sequence length of the underlying LM, PET must know which parts of the input can be shortened and which ones cannot (for example, the mask token must always be there). Therefore, `PVP` provides a `shortenable()` method to indicate that a piece of text can be shortened:
```python
def get_parts(self, example: InputExample):
    text_a = self.shortenable(example.text_a)
    text_b = self.shortenable(example.text_b)
    return [text_a, '.', text_b, '. Overall, it was ', self.mask], []
```
### PET with Multiple Masks
By default, the current implementation of PET and iPET only supports a fixed set of labels that is shared across all examples and verbalizers that correspond to a single token.
However, for some tasks it may be necessary to use verbalizers that correspond to multiple tokens ([as described here](http://arxiv.org/abs/2009.07118)).
To do so, you simply need the following two modifications:
1) Add the following lines in your task's **DataProcessor** (see `examples/custom_task_processor.py`):
```python
from pet.tasks import TASK_HELPERS
from pet.task_helpers import MultiMaskTaskHelper
TASK_HELPERS['my_task'] = MultiMaskTaskHelper
```
where ```'my_task'``` is the name of your task.
2) In your **PVP**, make sure that the ``get_parts()`` method always inserts **the maximum number of mask tokens** required for any verbalization. For example, if your verbalizer maps ``+1`` to "really awesome" and ``-1`` to "terrible" and if those are tokenized as ``["really", "awe", "##some"]`` and ``["terrible"]``, respectively, your ``get_parts()`` method should always return a sequence that contains exactly 3 mask tokens.
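For the "really awesome"/"terrible" example above, a matching `get_parts()` could look like this (a sketch; the pattern itself is illustrative):

```python
def get_parts(self, example: InputExample):
    text_a = self.shortenable(example.text_a)
    # always insert the maximum number of masks required by any verbalization
    # (here: 3, since "really awesome" is tokenized into three tokens)
    return [text_a, ' It was ', self.mask, self.mask, self.mask, '.'], []
```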
With this modification, you can now use verbalizers consisting of multiple tokens:
```python
VERBALIZER = {"+1": ["really good"], "-1": ["just bad"]}
```
However, there are several limitations to consider:
- When using a ``MultiMaskTaskHelper``, the maximum batch size for evaluation is 1.
- As using multiple masks requires multiple forward passes during evaluation, the time required for evaluation scales about linearly with the length of the longest verbalizer. If you require verbalizers that consist of 10 or more tokens, [using a generative LM](https://arxiv.org/abs/2012.11926) might be a better approach.
- The ``MultiMaskTaskHelper`` class is an experimental feature that is not thoroughly tested. In particular, this feature has only been tested for PET and not for iPET. If you observe something strange, please raise an issue.
For more flexibility, you can also write a custom `TaskHelper`. As a starting point, you can check out the classes `CopaTaskHelper`, `WscTaskHelper` and `RecordTaskHelper` in `pet/task_helpers.py`.
## 📕 Citation
If you make use of the code in this repository, please cite the following papers:
```
@article{schick2020exploiting,
    title={Exploiting Cloze Questions for Few-Shot Text Classification and Natural Language Inference},
    author={Timo Schick and Hinrich Schütze},
    journal={Computing Research Repository},
    volume={arXiv:2001.07676},
    url={http://arxiv.org/abs/2001.07676},
    year={2020}
}

@article{schick2020small,
    title={It's Not Just Size That Matters: Small Language Models Are Also Few-Shot Learners},
    author={Timo Schick and Hinrich Schütze},
    journal={Computing Research Repository},
    volume={arXiv:2009.07118},
    url={http://arxiv.org/abs/2009.07118},
    year={2020}
}
```
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script can be used to train and evaluate either a regular supervised model or a PET/iPET model on
one of the supported tasks and datasets.
"""
import argparse
import os
from typing import Tuple
import torch
from pet.tasks import PROCESSORS, load_examples, UNLABELED_SET, TRAIN_SET, DEV_SET, TEST_SET, METRICS, DEFAULT_METRICS
from pet.utils import eq_div
from pet.wrapper import WRAPPER_TYPES, MODEL_CLASSES, SEQUENCE_CLASSIFIER_WRAPPER, WrapperConfig
import pet
import log
logger = log.get_logger('root')
def load_pet_configs(args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]:
"""
Load the model, training and evaluation configs for PET from the given command line arguments.
"""
model_cfg = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path,
wrapper_type=args.wrapper_type, task_name=args.task_name, label_list=args.label_list,
max_seq_length=args.pet_max_seq_length, verbalizer_file=args.verbalizer_file,
cache_dir=args.cache_dir)
train_cfg = pet.TrainConfig(device=args.device, per_gpu_train_batch_size=args.pet_per_gpu_train_batch_size,
per_gpu_unlabeled_batch_size=args.pet_per_gpu_unlabeled_batch_size, n_gpu=args.n_gpu,
num_train_epochs=args.pet_num_train_epochs, max_steps=args.pet_max_steps,
gradient_accumulation_steps=args.pet_gradient_accumulation_steps,
weight_decay=args.weight_decay, learning_rate=args.learning_rate,
adam_epsilon=args.adam_epsilon, warmup_steps=args.warmup_steps,
max_grad_norm=args.max_grad_norm, lm_training=args.lm_training, alpha=args.alpha)
eval_cfg = pet.EvalConfig(device=args.device, n_gpu=args.n_gpu, metrics=args.metrics,
per_gpu_eval_batch_size=args.pet_per_gpu_eval_batch_size,
decoding_strategy=args.decoding_strategy, priming=args.priming)
return model_cfg, train_cfg, eval_cfg
def load_sequence_classifier_configs(args) -> Tuple[WrapperConfig, pet.TrainConfig, pet.EvalConfig]:
"""
Load the model, training and evaluation configs for a regular sequence classifier from the given command line
arguments. This classifier can either be used as a standalone model or as the final classifier for PET/iPET.
"""
model_cfg = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path,
wrapper_type=SEQUENCE_CLASSIFIER_WRAPPER, task_name=args.task_name,
label_list=args.label_list, max_seq_length=args.sc_max_seq_length,
verbalizer_file=args.verbalizer_file, cache_dir=args.cache_dir)
train_cfg = pet.TrainConfig(device=args.device, per_gpu_train_batch_size=args.sc_per_gpu_train_batch_size,
per_gpu_unlabeled_batch_size=args.sc_per_gpu_unlabeled_batch_size, n_gpu=args.n_gpu,
num_train_epochs=args.sc_num_train_epochs, max_steps=args.sc_max_steps,
temperature=args.temperature,
gradient_accumulation_steps=args.sc_gradient_accumulation_steps,
weight_decay=args.weight_decay, learning_rate=args.learning_rate,
adam_epsilon=args.adam_epsilon, warmup_steps=args.warmup_steps,
max_grad_norm=args.max_grad_norm, use_logits=args.method != 'sequence_classifier')
eval_cfg = pet.EvalConfig(device=args.device, n_gpu=args.n_gpu, metrics=args.metrics,
per_gpu_eval_batch_size=args.sc_per_gpu_eval_batch_size)
return model_cfg, train_cfg, eval_cfg
def load_ipet_config(args) -> pet.IPetConfig:
"""
Load the iPET config from the given command line arguments.
"""
ipet_cfg = pet.IPetConfig(generations=args.ipet_generations, logits_percentage=args.ipet_logits_percentage,
scale_factor=args.ipet_scale_factor, n_most_likely=args.ipet_n_most_likely)
return ipet_cfg
def main():
parser = argparse.ArgumentParser(description="Command line interface for PET/iPET")
# Required parameters
parser.add_argument("--method", required=True, choices=['pet', 'ipet', 'sequence_classifier'],
help="The training method to use. Either regular sequence classification, PET or iPET.")
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the data files for the task.")
parser.add_argument("--model_type", default=None, type=str, required=True, choices=MODEL_CLASSES.keys(),
help="The type of the pretrained language model to use")
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to the pre-trained model or shortcut name")
parser.add_argument("--task_name", default=None, type=str, required=True, choices=PROCESSORS.keys(),
help="The name of the task to train/evaluate on")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written")
# PET-specific optional parameters
parser.add_argument("--wrapper_type", default="mlm", choices=WRAPPER_TYPES,
help="The wrapper type. Set this to 'mlm' for a masked language model like BERT or to 'plm' "
"for a permuted language model like XLNet (only for PET)")
parser.add_argument("--pattern_ids", default=[0], type=int, nargs='+',
help="The ids of the PVPs to be used (only for PET)")
parser.add_argument("--lm_training", action='store_true',
help="Whether to use language modeling as auxiliary task (only for PET)")
parser.add_argument("--alpha", default=0.9999, type=float,
help="Weighting term for the auxiliary language modeling task (only for PET)")
parser.add_argument("--temperature", default=2, type=float,
help="Temperature used for combining PVPs (only for PET)")
parser.add_argument("--verbalizer_file", default=None,
help="The path to a file to override default verbalizers (only for PET)")
parser.add_argument("--reduction", default='wmean', choices=['wmean', 'mean'],
help="Reduction strategy for merging predictions from multiple PET models. Select either "
"uniform weighting (mean) or weighting based on train set accuracy (wmean)")
parser.add_argument("--decoding_strategy", default='default', choices=['default', 'ltr', 'parallel'],
help="The decoding strategy for PET with multiple masks (only for PET)")
parser.add_argument("--no_distillation", action='store_true',
help="If set to true, no distillation is performed (only for PET)")
parser.add_argument("--pet_repetitions", default=3, type=int,
help="The number of times to repeat PET training and testing with different seeds.")
parser.add_argument("--pet_max_seq_length", default=256, type=int,
help="The maximum total input sequence length after tokenization for PET. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--pet_per_gpu_train_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for PET training.")
parser.add_argument("--pet_per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for PET evaluation.")
parser.add_argument("--pet_per_gpu_unlabeled_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for auxiliary language modeling examples in PET.")
parser.add_argument('--pet_gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass in PET.")
parser.add_argument("--pet_num_train_epochs", default=3, type=float,
help="Total number of training epochs to perform in PET.")
parser.add_argument("--pet_max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform in PET. Override num_train_epochs.")
# SequenceClassifier-specific optional parameters (also used for the final PET classifier)
parser.add_argument("--sc_repetitions", default=1, type=int,
help="The number of times to repeat seq. classifier training and testing with different seeds.")
parser.add_argument("--sc_max_seq_length", default=256, type=int,
help="The maximum total input sequence length after tokenization for sequence classification. "
"Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--sc_per_gpu_train_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for sequence classifier training.")
parser.add_argument("--sc_per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for sequence classifier evaluation.")
parser.add_argument("--sc_per_gpu_unlabeled_batch_size", default=4, type=int,
help="Batch size per GPU/CPU for unlabeled examples used for distillation.")
parser.add_argument('--sc_gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass for "
"sequence classifier training.")
parser.add_argument("--sc_num_train_epochs", default=3, type=float,
help="Total number of training epochs to perform for sequence classifier training.")
parser.add_argument("--sc_max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform for sequence classifier training. "
"Override num_train_epochs.")
# iPET-specific optional parameters
parser.add_argument("--ipet_generations", default=3, type=int,
help="The number of generations to train (only for iPET)")
parser.add_argument("--ipet_logits_percentage", default=0.25, type=float,
help="The percentage of models to choose for annotating new training sets (only for iPET)")
parser.add_argument("--ipet_scale_factor", default=5, type=float,
help="The factor by which to increase the training set size per generation (only for iPET)")
parser.add_argument("--ipet_n_most_likely", default=-1, type=int,
help="If >0, in the first generation the n_most_likely examples per label are chosen even "
"if their predicted label is different (only for iPET)")
# Other optional parameters
parser.add_argument("--train_examples", default=-1, type=int,
help="The total number of train examples to use, where -1 equals all examples.")
parser.add_argument("--test_examples", default=-1, type=int,
help="The total number of test examples to use, where -1 equals all examples.")
parser.add_argument("--unlabeled_examples", default=-1, type=int,
help="The total number of unlabeled examples to use, where -1 equals all examples")
parser.add_argument("--split_examples_evenly", action='store_true',
help="If true, train examples are not chosen randomly, but split evenly across all labels.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where to store the pre-trained models downloaded from S3.")
parser.add_argument("--learning_rate", default=1e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.01, type=float,
help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.")
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--do_train', action='store_true',
help="Whether to perform training")
parser.add_argument('--do_eval', action='store_true',
help="Whether to perform evaluation")
parser.add_argument('--priming', action='store_true',
help="Whether to use priming for evaluation")
parser.add_argument("--eval_set", choices=['dev', 'test'], default='dev',
help="Whether to perform evaluation on the dev set or the test set")
args = parser.parse_args()
logger.info("Parameters: {}".format(args))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) \
and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
# Setup CUDA, GPU & distributed training
args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
args.n_gpu = torch.cuda.device_count()
# Prepare task
args.task_name = args.task_name.lower()
if args.task_name not in PROCESSORS:
raise ValueError("Task '{}' not found".format(args.task_name))
processor = PROCESSORS[args.task_name]()
args.label_list = processor.get_labels()
train_ex_per_label, test_ex_per_label = None, None
train_ex, test_ex = args.train_examples, args.test_examples
if args.split_examples_evenly:
train_ex_per_label = eq_div(args.train_examples, len(args.label_list)) if args.train_examples != -1 else -1
test_ex_per_label = eq_div(args.test_examples, len(args.label_list)) if args.test_examples != -1 else -1
train_ex, test_ex = None, None
eval_set = TEST_SET if args.eval_set == 'test' else DEV_SET
train_data = load_examples(
args.task_name, args.data_dir, TRAIN_SET, num_examples=train_ex, num_examples_per_label=train_ex_per_label)
eval_data = load_examples(
args.task_name, args.data_dir, eval_set, num_examples=test_ex, num_examples_per_label=test_ex_per_label)
unlabeled_data = load_examples(
args.task_name, args.data_dir, UNLABELED_SET, num_examples=args.unlabeled_examples)
args.metrics = METRICS.get(args.task_name, DEFAULT_METRICS)
pet_model_cfg, pet_train_cfg, pet_eval_cfg = load_pet_configs(args)
sc_model_cfg, sc_train_cfg, sc_eval_cfg = load_sequence_classifier_configs(args)
ipet_cfg = load_ipet_config(args)
if args.method == 'pet':
pet.train_pet(pet_model_cfg, pet_train_cfg, pet_eval_cfg, sc_model_cfg, sc_train_cfg, sc_eval_cfg,
pattern_ids=args.pattern_ids, output_dir=args.output_dir,
ensemble_repetitions=args.pet_repetitions, final_repetitions=args.sc_repetitions,
reduction=args.reduction, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval,
no_distillation=args.no_distillation, seed=args.seed)
elif args.method == 'ipet':
pet.train_ipet(pet_model_cfg, pet_train_cfg, pet_eval_cfg, ipet_cfg, sc_model_cfg, sc_train_cfg, sc_eval_cfg,
pattern_ids=args.pattern_ids, output_dir=args.output_dir,
ensemble_repetitions=args.pet_repetitions, final_repetitions=args.sc_repetitions,
reduction=args.reduction, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, seed=args.seed)
elif args.method == 'sequence_classifier':
pet.train_classifier(sc_model_cfg, sc_train_cfg, sc_eval_cfg, output_dir=args.output_dir,
repetitions=args.sc_repetitions, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=args.do_train, do_eval=args.do_eval, seed=args.seed)
else:
raise ValueError(f"Training method '{args.method}' not implemented")
if __name__ == "__main__":
main()
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To add a new task to PET, both a DataProcessor and a PVP for this task must
be added. The DataProcessor is responsible for loading training and test data.
This file shows an example of a DataProcessor for a new task.
"""
import csv
import os
from typing import List
from pet.task_helpers import MultiMaskTaskHelper
from pet.tasks import DataProcessor, PROCESSORS, TASK_HELPERS
from pet.utils import InputExample
class MyTaskDataProcessor(DataProcessor):
"""
Example for a data processor.
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the name of the file containing the train examples
TRAIN_FILE_NAME = "train.csv"
# Set this to the name of the file containing the dev examples
DEV_FILE_NAME = "dev.csv"
# Set this to the name of the file containing the test examples
TEST_FILE_NAME = "test.csv"
# Set this to the name of the file containing the unlabeled examples
UNLABELED_FILE_NAME = "unlabeled.csv"
# Set this to a list of all labels in the train + test data
LABELS = ["1", "2", "3", "4"]
# Set this to the column of the train/test csv files containing the input's text a
TEXT_A_COLUMN = 1
# Set this to the column of the train/test csv files containing the input's text b or to -1 if there is no text b
TEXT_B_COLUMN = 2
# Set this to the column of the train/test csv files containing the input's gold label
LABEL_COLUMN = 0
def get_train_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads train examples from a file with name `TRAIN_FILE_NAME` in the given directory.
:param data_dir: the directory in which the training data can be found
:return: a list of train examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TRAIN_FILE_NAME), "train")
def get_dev_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads dev examples from a file with name `DEV_FILE_NAME` in the given directory.
:param data_dir: the directory in which the dev data can be found
:return: a list of dev examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.DEV_FILE_NAME), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
"""
This method loads test examples from a file with name `TEST_FILE_NAME` in the given directory.
:param data_dir: the directory in which the test data can be found
:return: a list of test examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TEST_FILE_NAME), "test")
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""
This method loads unlabeled examples from a file with name `UNLABELED_FILE_NAME` in the given directory.
:param data_dir: the directory in which the unlabeled data can be found
:return: a list of unlabeled examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.UNLABELED_FILE_NAME), "unlabeled")
def get_labels(self) -> List[str]:
"""This method returns all possible labels for the task."""
return MyTaskDataProcessor.LABELS
def _create_examples(self, path, set_type, max_examples=-1, skip_first=0):
"""Creates examples for the training and dev sets."""
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
guid = "%s-%s" % (set_type, idx)
label = row[MyTaskDataProcessor.LABEL_COLUMN]
text_a = row[MyTaskDataProcessor.TEXT_A_COLUMN]
text_b = row[MyTaskDataProcessor.TEXT_B_COLUMN] if MyTaskDataProcessor.TEXT_B_COLUMN >= 0 else None
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
# register the processor for this task with its name
PROCESSORS[MyTaskDataProcessor.TASK_NAME] = MyTaskDataProcessor
# optional: if you have to use verbalizers that correspond to multiple tokens, uncomment the following line
# TASK_HELPERS[MyTaskDataProcessor.TASK_NAME] = MultiMaskTaskHelper
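if __name__ == "__main__":
    # Hedged usage sketch (not part of the original example): importing this module
    # registers the processor above, after which "my-task" behaves like a built-in
    # task. Assumes a directory data/my-task containing train.csv in the format
    # described by the column constants.
    from pet.tasks import load_examples, TRAIN_SET

    train_examples = load_examples(MyTaskDataProcessor.TASK_NAME, "data/my-task",
                                   TRAIN_SET, num_examples=-1)
    print("loaded {} train examples".format(len(train_examples)))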
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To add a new task to PET, both a DataProcessor and a PVP for this task must
be added. The PVP is responsible for applying patterns to inputs and mapping
labels to their verbalizations (see the paper for more details on PVPs).
This file shows an example of a PVP for a new task.
"""
from typing import List
from pet.pvp import PVP, PVPS
from pet.utils import InputExample
class MyTaskPVP(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
VERBALIZER = {
"1": ["World"],
"2": ["Sports"],
"3": ["Business"],
"4": ["Tech"]
}
def get_parts(self, example: InputExample):
"""
This function defines the actual patterns: It takes as input an example and outputs the result of applying a
pattern to it. To allow for multiple patterns, a pattern_id can be passed to the PVP's constructor. This
method must implement the application of all patterns.
"""
# We tell the tokenizer that both text_a and text_b can be truncated if the resulting sequence is longer than
# our language model's max sequence length.
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
# For each pattern_id, we define the corresponding pattern and return a pair of text a and text b (where text b
# can also be empty).
if self.pattern_id == 0:
# this corresponds to the pattern [MASK]: a b
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
# this corresponds to the pattern [MASK] News: a || (b)
return [self.mask, 'News:', text_a], ['(', text_b, ')']
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return MyTaskPVP.VERBALIZER[label]
# register the PVP for this task with its name
PVPS[MyTaskPVP.TASK_NAME] = MyTaskPVP
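# Illustration (added comment): for an InputExample with text_a="Stocks rallied."
# and text_b="Markets rose.", pattern_id 0 produces the sequence
#   <mask> : Stocks rallied. Markets rose.
# and pattern_id 1 produces
#   <mask> News: Stocks rallied. ( Markets rose. )
# where <mask> stands for the underlying LM's mask token; the verbalizer above
# then scores "World", "Sports", "Business" and "Tech" as fillers for the mask.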
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains basic logging logic.
"""
import logging
names = set()
def __setup_custom_logger(name: str) -> logging.Logger:
root_logger = logging.getLogger()
root_logger.handlers.clear()
formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
names.add(name)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
return logger
def get_logger(name: str) -> logging.Logger:
if name in names:
return logging.getLogger(name)
else:
return __setup_custom_logger(name)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the pattern-verbalizer pairs (PVPs) for all tasks.
"""
import random
import string
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Tuple, List, Union, Dict
import torch
from transformers import PreTrainedTokenizer, GPT2Tokenizer
from pet.task_helpers import MultiMaskTaskHelper
from pet.tasks import TASK_HELPERS
from pet.utils import InputExample, get_verbalization_ids
import log
from pet import wrapper as wrp
logger = log.get_logger('root')
FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], List[Union[str, Tuple[str, bool]]]]
class PVP(ABC):
"""
This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its own
custom implementation of a PVP.
"""
def __init__(self, wrapper, pattern_id: int = 0, verbalizer_file: str = None, seed: int = 42):
"""
Create a new PVP.
:param wrapper: the wrapper for the underlying language model
:param pattern_id: the pattern id to use
:param verbalizer_file: an optional file that contains the verbalizer to be used
:param seed: a seed to be used for generating random numbers if necessary
"""
self.wrapper = wrapper
self.pattern_id = pattern_id
self.rng = random.Random(seed)
if verbalizer_file:
self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id)
use_multimask = (self.wrapper.config.task_name in TASK_HELPERS) and (
issubclass(TASK_HELPERS[self.wrapper.config.task_name], MultiMaskTaskHelper)
)
if not use_multimask and self.wrapper.config.wrapper_type in [wrp.MLM_WRAPPER, wrp.PLM_WRAPPER]:
self.mlm_logits_to_cls_logits_tensor = self._build_mlm_logits_to_cls_logits_tensor()
def _build_mlm_logits_to_cls_logits_tensor(self):
label_list = self.wrapper.config.label_list
m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1
for label_idx, label in enumerate(label_list):
verbalizers = self.verbalize(label)
for verbalizer_idx, verbalizer in enumerate(verbalizers):
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, "verbalization was tokenized as <UNK>"
m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id
return m2c_tensor
@property
def mask(self) -> str:
"""Return the underlying LM's mask token"""
return self.wrapper.tokenizer.mask_token
@property
def mask_id(self) -> int:
"""Return the underlying LM's mask id"""
return self.wrapper.tokenizer.mask_token_id
@property
def max_num_verbalizers(self) -> int:
"""Return the maximum number of verbalizers across all labels"""
return max(len(self.verbalize(label)) for label in self.wrapper.config.label_list)
@staticmethod
def shortenable(s):
"""Return an instance of this string that is marked as shortenable"""
return s, True
@staticmethod
def remove_final_punc(s: Union[str, Tuple[str, bool]]):
"""Remove the final punctuation mark"""
if isinstance(s, tuple):
return PVP.remove_final_punc(s[0]), s[1]
return s.rstrip(string.punctuation)
@staticmethod
def lowercase_first(s: Union[str, Tuple[str, bool]]):
"""Lowercase the first character"""
if isinstance(s, tuple):
return PVP.lowercase_first(s[0]), s[1]
return s[0].lower() + s[1:]
def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \
-> Tuple[List[int], List[int]]:
"""
Encode an input example using this pattern-verbalizer pair.
:param example: the input example to encode
:param priming: whether to use this example for priming
:param labeled: if ``priming=True``, whether the label should be appended to this example
:return: A tuple, consisting of a list of input ids and a list of token type ids
"""
if not priming:
assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"
tokenizer = self.wrapper.tokenizer # type: PreTrainedTokenizer
parts_a, parts_b = self.get_parts(example)
kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}
parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]
if parts_b:
parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]
self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)
tokens_a = [token_id for part, _ in parts_a for token_id in part]
tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None
if priming:
input_ids = tokens_a
if tokens_b:
input_ids += tokens_b
if labeled:
mask_idx = input_ids.index(self.mask_id)
assert mask_idx >= 0, 'sequence of input_ids must contain a mask token'
assert len(self.verbalize(example.label)) == 1, 'priming only supports one verbalization per label'
verbalizer = self.verbalize(example.label)[0]
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
input_ids[mask_idx] = verbalizer_id
return input_ids, []
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
return input_ids, token_type_ids
@staticmethod
def _seq_length(parts: List[Tuple[str, bool]], only_shortenable: bool = False):
return sum([len(x) for x, shortenable in parts if not only_shortenable or shortenable]) if parts else 0
@staticmethod
def _remove_last(parts: List[Tuple[str, bool]]):
last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq)
parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1])
def truncate(self, parts_a: List[Tuple[str, bool]], parts_b: List[Tuple[str, bool]], max_length: int):
"""Truncate two sequences of text to a predefined total maximum length"""
total_len = self._seq_length(parts_a) + self._seq_length(parts_b)
total_len += self.wrapper.tokenizer.num_special_tokens_to_add(bool(parts_b))
num_tokens_to_remove = total_len - max_length
if num_tokens_to_remove <= 0:
return parts_a, parts_b
for _ in range(num_tokens_to_remove):
if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True):
self._remove_last(parts_a)
else:
self._remove_last(parts_b)
@abstractmethod
def get_parts(self, example: InputExample) -> FilledPattern:
"""
Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one
mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a
single sequence of text, the second sequence should be an empty list.
:param example: the input example to process
:return: Two sequences of text. All text segments can optionally be marked as being shortenable.
"""
pass
@abstractmethod
def verbalize(self, label) -> List[str]:
"""
Return all verbalizations for a given label.
:param label: the label
:return: the list of verbalizations
"""
pass
def get_mask_positions(self, input_ids: List[int]) -> List[int]:
label_idx = input_ids.index(self.mask_id)
labels = [-1] * len(input_ids)
labels[label_idx] = 1
return labels
def convert_mlm_logits_to_cls_logits(self, mlm_labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
masked_logits = logits[mlm_labels >= 0]
cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(ml) for ml in masked_logits])
return cls_logits
def _convert_single_mlm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
m2c = self.mlm_logits_to_cls_logits_tensor.to(logits.device)
# filler_len.shape() == max_fillers
filler_len = torch.tensor([len(self.verbalize(label)) for label in self.wrapper.config.label_list],
dtype=torch.float)
filler_len = filler_len.to(logits.device)
# cls_logits.shape() == num_labels x max_fillers (and 0 when there are not as many fillers).
cls_logits = logits[torch.max(torch.zeros_like(m2c), m2c)]
cls_logits = cls_logits * (m2c > 0).float()
# cls_logits.shape() == num_labels
cls_logits = cls_logits.sum(axis=1) / filler_len
return cls_logits
def convert_plm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
assert logits.shape[1] == 1
logits = torch.squeeze(logits, 1) # remove second dimension as we always have exactly one <mask> per example
cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(lgt) for lgt in logits])
return cls_logits
@staticmethod
def _load_verbalizer_from_file(path: str, pattern_id: int):
verbalizers = defaultdict(dict) # type: Dict[int, Dict[str, List[str]]]
current_pattern_id = None
with open(path, 'r') as fh:
for line in fh.read().splitlines():
if line.isdigit():
current_pattern_id = int(line)
elif line:
label, *realizations = line.split()
verbalizers[current_pattern_id][label] = realizations
logger.info("Automatically loaded the following verbalizer: \n {}".format(verbalizers[pattern_id]))
def verbalize(label) -> List[str]:
return verbalizers[pattern_id][label]
return verbalize
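# A minimal, self-contained sketch (file name and labels are hypothetical) of the
# plain-text format that _load_verbalizer_from_file expects: a pattern id on its
# own line, followed by one "label verbalization [verbalization ...]" line per
# label, for each pattern.
def _write_demo_verbalizer_file(path='verbalizers_demo.txt'):
    with open(path, 'w', encoding='utf8') as fh:
        fh.write('0\n1 great good\n0 bad terrible\n\n1\n1 fascinating\n0 boring\n')
    return path
# verbalize = PVP._load_verbalizer_from_file(_write_demo_verbalizer_file(), pattern_id=0)
# verbalize('1')  # -> ['great', 'good']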
class AgnewsPVP(PVP):
VERBALIZER = {
"1": ["World"],
"2": ["Sports"],
"3": ["Business"],
"4": ["Tech"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0:
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
return [self.mask, 'News:', text_a, text_b], []
elif self.pattern_id == 2:
return [text_a, '(', self.mask, ')', text_b], []
elif self.pattern_id == 3:
return [text_a, text_b, '(', self.mask, ')'], []
elif self.pattern_id == 4:
return ['[ Category:', self.mask, ']', text_a, text_b], []
elif self.pattern_id == 5:
return [self.mask, '-', text_a, text_b], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return AgnewsPVP.VERBALIZER[label]
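# A rough sketch of what AgnewsPVP pattern 0 renders, with '<mask>' standing in
# for the tokenizer's real mask token (headline and body are made up):
_demo_parts = ['<mask>', ':', 'Stocks rally', 'Markets closed higher today.']
assert ' '.join(_demo_parts) == '<mask> : Stocks rally Markets closed higher today.'
# The model is then asked to fill the mask with a verbalizer such as 'Business'.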
class YahooPVP(PVP):
VERBALIZER = {
"1": ["Society"],
"2": ["Science"],
"3": ["Health"],
"4": ["Education"],
"5": ["Computer"],
"6": ["Sports"],
"7": ["Business"],
"8": ["Entertainment"],
"9": ["Relationship"],
"10": ["Politics"],
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0:
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
return [self.mask, 'Question:', text_a, text_b], []
elif self.pattern_id == 2:
return [text_a, '(', self.mask, ')', text_b], []
elif self.pattern_id == 3:
return [text_a, text_b, '(', self.mask, ')'], []
elif self.pattern_id == 4:
return ['[ Category:', self.mask, ']', text_a, text_b], []
elif self.pattern_id == 5:
return [self.mask, '-', text_a, text_b], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return YahooPVP.VERBALIZER[label]
class MnliPVP(PVP):
VERBALIZER_A = {
"contradiction": ["Wrong"],
"entailment": ["Right"],
"neutral": ["Maybe"]
}
VERBALIZER_B = {
"contradiction": ["No"],
"entailment": ["Yes"],
"neutral": ["Maybe"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(self.remove_final_punc(example.text_a))
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0 or self.pattern_id == 2:
return ['"', text_a, '" ?'], [self.mask, ', "', text_b, '"']
elif self.pattern_id == 1 or self.pattern_id == 3:
return [text_a, '?'], [self.mask, ',', text_b]
def verbalize(self, label) -> List[str]:
if self.pattern_id == 0 or self.pattern_id == 1:
return MnliPVP.VERBALIZER_A[label]
return MnliPVP.VERBALIZER_B[label]
class YelpPolarityPVP(PVP):
VERBALIZER = {
"1": ["bad"],
"2": ["good"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text = self.shortenable(example.text_a)
if self.pattern_id == 0:
return ['It was', self.mask, '.', text], []
elif self.pattern_id == 1:
return [text, '. All in all, it was', self.mask, '.'], []
elif self.pattern_id == 2:
return ['Just', self.mask, "!"], [text]
elif self.pattern_id == 3:
return [text], ['In summary, the restaurant is', self.mask, '.']
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return YelpPolarityPVP.VERBALIZER[label]
class YelpFullPVP(YelpPolarityPVP):
VERBALIZER = {
"1": ["terrible"],
"2": ["bad"],
"3": ["okay"],
"4": ["good"],
"5": ["great"]
}
def verbalize(self, label) -> List[str]:
return YelpFullPVP.VERBALIZER[label]
class XStancePVP(PVP):
VERBALIZERS = {
'en': {"FAVOR": ["Yes"], "AGAINST": ["No"]},
'de': {"FAVOR": ["Ja"], "AGAINST": ["Nein"]},
'fr': {"FAVOR": ["Oui"], "AGAINST": ["Non"]}
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4:
return ['"', text_a, '"'], [self.mask, '. "', text_b, '"']
elif self.pattern_id == 1 or self.pattern_id == 3 or self.pattern_id == 5:
return [text_a], [self.mask, '.', text_b]
def verbalize(self, label) -> List[str]:
lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr'
return XStancePVP.VERBALIZERS[lang][label]
class RtePVP(PVP):
VERBALIZER = {
"not_entailment": ["No"],
"entailment": ["Yes"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
# the patterns below present text_b (the hypothesis) before text_a (the premise)
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b.rstrip(string.punctuation))
if self.pattern_id == 0:
return ['"', text_b, '" ?'], [self.mask, ', "', text_a, '"']
elif self.pattern_id == 1:
return [text_b, '?'], [self.mask, ',', text_a]
if self.pattern_id == 2:
return ['"', text_b, '" ?'], [self.mask, '. "', text_a, '"']
elif self.pattern_id == 3:
return [text_b, '?'], [self.mask, '.', text_a]
elif self.pattern_id == 4:
return [text_a, ' question: ', self.shortenable(example.text_b), ' True or False? answer:', self.mask], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return ['true'] if label == 'entailment' else ['false']
return RtePVP.VERBALIZER[label]
class CbPVP(RtePVP):
VERBALIZER = {
"contradiction": ["No"],
"entailment": ["Yes"],
"neutral": ["Maybe"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
if self.pattern_id == 4:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
return [text_a, ' question: ', text_b, ' true, false or neither? answer:', self.mask], []
return super().get_parts(example)
def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return ['true'] if label == 'entailment' else ['false'] if label == 'contradiction' else ['neither']
return CbPVP.VERBALIZER[label]
class CopaPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
premise = self.remove_final_punc(self.shortenable(example.text_a))
choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))
question = example.meta['question']
assert question in ['cause', 'effect']
example.meta['choice1'], example.meta['choice2'] = choice1, choice2
num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in [choice1, choice2])
if question == 'cause':
if self.pattern_id == 0:
return ['"', choice1, '" or "', choice2, '"?', premise, 'because', self.mask * num_masks, '.'], []
elif self.pattern_id == 1:
return [choice1, 'or', choice2, '?', premise, 'because', self.mask * num_masks, '.'], []
else:
if self.pattern_id == 0:
return ['"', choice1, '" or "', choice2, '"?', premise, ', so', self.mask * num_masks, '.'], []
elif self.pattern_id == 1:
return [choice1, 'or', choice2, '?', premise, ', so', self.mask * num_masks, '.'], []
def verbalize(self, label) -> List[str]:
return []
class WscPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
pronoun = example.meta['span2_text']
target = example.meta['span1_text']
pronoun_idx = example.meta['span2_index']
words_a = example.text_a.split()
words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
text_a = ' '.join(words_a)
text_a = self.shortenable(text_a)
num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)) + num_pad
masks = self.mask * num_masks
if self.pattern_id == 0:
return [text_a, "The pronoun '*" + pronoun + "*' refers to", masks + '.'], []
elif self.pattern_id == 1:
return [text_a, "In the previous sentence, the pronoun '*" + pronoun + "*' refers to", masks + '.'], []
elif self.pattern_id == 2:
return [text_a,
"Question: In the passage above, what does the pronoun '*" + pronoun + "*' refer to? Answer: ",
masks + '.'], []
def verbalize(self, label) -> List[str]:
return []
class BoolQPVP(PVP):
VERBALIZER_A = {
"False": ["No"],
"True": ["Yes"]
}
VERBALIZER_B = {
"False": ["false"],
"True": ["true"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
passage = self.shortenable(example.text_a)
question = self.shortenable(example.text_b)
if self.pattern_id < 2:
return [passage, '. Question: ', question, '? Answer: ', self.mask, '.'], []
elif self.pattern_id < 4:
return [passage, '. Based on the previous passage, ', question, '?', self.mask, '.'], []
else:
return ['Based on the following passage, ', question, '?', self.mask, '.', passage], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4:
return BoolQPVP.VERBALIZER_A[label]
else:
return BoolQPVP.VERBALIZER_B[label]
class MultiRcPVP(PVP):
VERBALIZER = {
"0": ["No"],
"1": ["Yes"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
passage = self.shortenable(example.text_a)
question = example.text_b
answer = example.meta['answer']
if self.pattern_id == 0:
return [passage, '. Question: ', question, '? Is it ', answer, '?', self.mask, '.'], []
if self.pattern_id == 1:
return [passage, '. Question: ', question, '? Is the correct answer "', answer, '"?', self.mask, '.'], []
if self.pattern_id == 2:
return [passage, '. Based on the previous passage, ', question, '? Is "', answer, '" a correct answer?',
self.mask, '.'], []
if self.pattern_id == 3:
return [passage, question, '- [', self.mask, ']', answer], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 3:
return ['False'] if label == "0" else ['True']
return MultiRcPVP.VERBALIZER[label]
class WicPVP(PVP):
VERBALIZER_A = {
"F": ["No"],
"T": ["Yes"]
}
VERBALIZER_B = {
"F": ["2"],
"T": ["b"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
word = example.meta['word']
if self.pattern_id == 0:
return ['"', text_a, '" / "', text_b, '" Similar sense of "' + word + '"?', self.mask, '.'], []
if self.pattern_id == 1:
return [text_a, text_b, 'Does ' + word + ' have the same meaning in both sentences?', self.mask], []
if self.pattern_id == 2:
return [word, ' . Sense (1) (a) "', text_a, '" (', self.mask, ') "', text_b, '"'], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 2:
return WicPVP.VERBALIZER_B[label]
return WicPVP.VERBALIZER_A[label]
class RecordPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
premise = self.shortenable(example.text_a)
choices = example.meta['candidates']
assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token'
num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in choices)
question = example.text_b.replace('@placeholder', self.mask * num_masks)
return [premise, question], []
def verbalize(self, label) -> List[str]:
return []
PVPS = {
'agnews': AgnewsPVP,
'mnli': MnliPVP,
'yelp-polarity': YelpPolarityPVP,
'yelp-full': YelpFullPVP,
'yahoo': YahooPVP,
'xstance': XStancePVP,
'xstance-de': XStancePVP,
'xstance-fr': XStancePVP,
'rte': RtePVP,
'wic': WicPVP,
'cb': CbPVP,
'wsc': WscPVP,
'boolq': BoolQPVP,
'copa': CopaPVP,
'multirc': MultiRcPVP,
'record': RecordPVP,
'ax-b': RtePVP,
'ax-g': RtePVP,
}
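# Usage sketch: PVPS maps a task name to its PVP class. Instantiating a PVP
# requires a TransformerModelWrapper (for the tokenizer and config), so that
# part is only sketched; 'wrapper' and 'example' below are hypothetical.
_demo_pvp_cls = PVPS['rte']
assert _demo_pvp_cls is RtePVP
# pvp = _demo_pvp_cls(wrapper, pattern_id=1)
# parts_a, parts_b = pvp.get_parts(example)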
class MyTaskPVP(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
# VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
VERBALIZER = {
1: ["fascinating", "effective", "irresistible", "thrilling"],  # "breathtaking",
0: ["boring", "embarrassing", "weird", "nothing"],  # "depressing",
}
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
if self.pattern_id == 0:
# *cls*_A*mask*_idea.*+sent_0**sep+* 93 {"0": "boring", "1": "fascinating"}
return ['A', self.mask, 'idea.', text_a], []
if self.pattern_id == 1:
# *cls*_A*mask*_show.*+sent_0**sep+* 95 {"0": "embarrassing", "1": "effective"}
return ['A', self.mask, 'show.', text_a], []
if self.pattern_id == 2:
# *cls*_The_story_is*mask*.*+sent_0**sep+* 81 {"0": "depressing", "1": "irresistible"}
return ['The story is', self.mask, '.', text_a], []
if self.pattern_id == 3:
# *cls**sent_0*_The_result_is*mask*.*sep+* 49 {"0": "weird", "1": "breathtaking"}
return [text_a, 'The result is', self.mask, '.'], []
if self.pattern_id == 4:
# *cls**sent_0*_Very*mask*!*sep+* 30 {"0": "embarrassing", "1": "thrilling"}
return [text_a, 'Very', self.mask, '!'], []
if self.pattern_id == 5:
# *cls**sent_0*_This_is*mask*.*sep+* 7 {"0": "nothing", "1": "thrilling"}
return [text_a, 'This is', self.mask, '.'], []
def verbalize(self, label) -> List[str]:
return MyTaskPVP.VERBALIZER[label]
class MyTaskPVP2(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task2"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
# VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
VERBALIZER = {1: ["Good"], 0: ["Bad"]}
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
return [text_a, 'It is', self.mask, '.'], []
def verbalize(self, label) -> List[str]:
return MyTaskPVP2.VERBALIZER[label]
class AutoBest5(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "autobest5"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
# VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
VERBALIZER = {1: ["good"], 0: ["bad"]}
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
'''
0.89967 *cls*_A*mask*_film!*+sent_0**sep+*
0.89298 *cls**sent_0*_This_is_really*mask*.*sep+*
0.89298 *cls*_It_is*mask*!*+sent_0**sep+*
0.89298 *cls*_This_movie_is*mask*.*+sent_0**sep+*
0.88963 *cls*_Just*mask*.*+sent_0**sep+*
0.88963 *cls*_The_movie_is*mask*.*+sent_0**sep+*
0.88963 *cls*_This_film_is*mask*.*+sent_0**sep+*
0.88629 *cls**sent_0*_It's*mask*.*sep+*
0.88629 *cls**sent_0*_The_movie_is*mask*.*sep+*
0.88629 *cls**sent_0*_This_is_just*mask*.*sep+*
'''
if self.pattern_id == 0:
return ['A', self.mask, 'film.', text_a], []
if self.pattern_id == 1:
return [text_a, 'This is really', self.mask, '.'], []
if self.pattern_id == 2:
return ['It is', self.mask, '!', text_a], []
if self.pattern_id == 3:
return ['The movie is', self.mask, '.', text_a], []
if self.pattern_id == 4:
return ['Just', self.mask, '.', text_a], []
def verbalize(self, label) -> List[str]:
return AutoBest5.VERBALIZER[label]
# register the PVP for this task with its name
PVPS[MyTaskPVP.TASK_NAME] = MyTaskPVP
PVPS[MyTaskPVP2.TASK_NAME] = MyTaskPVP2
PVPS[AutoBest5.TASK_NAME] = AutoBest5
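# Sanity-check sketch: after registration, the custom tasks resolve through
# PVPS exactly like the built-in ones.
assert PVPS[MyTaskPVP.TASK_NAME] is MyTaskPVP and PVPS['autobest5'] is AutoBest5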
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the logic for loading training and test data for all tasks.
"""
import csv
import json
import os
import random
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable
import log
from pet import task_helpers
from pet.utils import InputExample
logger = log.get_logger('root')
def _shuffle_and_restrict(examples: List[InputExample], num_examples: int, seed: int = 42) -> List[InputExample]:
"""
Shuffle a list of examples and restrict it to a given maximum size.
:param examples: the examples to shuffle and restrict
:param num_examples: the maximum number of examples
:param seed: the random seed for shuffling
:return: the first ``num_examples`` elements of the shuffled list
"""
if 0 < num_examples < len(examples):
random.Random(seed).shuffle(examples)
examples = examples[:num_examples]
return examples
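# Quick sketch of _shuffle_and_restrict on toy examples (guids and labels are
# made up): restriction returns the first num_examples items after a seeded
# in-place shuffle, while num_examples <= 0 leaves the list unchanged.
_demo_examples = [InputExample(guid=str(i), text_a='text %d' % i, label='1') for i in range(5)]
assert len(_shuffle_and_restrict(_demo_examples, num_examples=3)) == 3
assert len(_shuffle_and_restrict(_demo_examples, num_examples=0)) == 5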
class LimitedExampleList:
def __init__(self, labels: List[str], max_examples=-1):
"""
Implementation of a list that stores only a limited number of examples per label.
:param labels: the set of all possible labels
:param max_examples: the maximum number of examples per label. This can either be a fixed number,
in which case `max_examples` examples are loaded for every label, or a list with the same size as
`labels`, in which case at most `max_examples[i]` examples are loaded for label `labels[i]`.
"""
self._labels = labels
self._examples = []
self._examples_per_label = defaultdict(int)
if isinstance(max_examples, list):
self._max_examples = dict(zip(self._labels, max_examples))
else:
self._max_examples = {label: max_examples for label in self._labels}
def is_full(self):
"""Return `true` iff no more examples can be added to this list"""
for label in self._labels:
if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0:
return False
return True
def add(self, example: InputExample) -> bool:
"""
Add a new input example to this list.
:param example: the example to add
:returns: ``True`` iff the example was actually added to the list
"""
label = example.label
if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0:
self._examples_per_label[label] += 1
self._examples.append(example)
return True
return False
def to_list(self):
return self._examples
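# Usage sketch: collect an at-most-2-per-label subset from a stream of toy
# examples (texts and labels are made up).
_demo_list = LimitedExampleList(labels=['0', '1'], max_examples=2)
for _i in range(10):
    _demo_list.add(InputExample(guid=str(_i), text_a='t', label=str(_i % 2)))
assert len(_demo_list.to_list()) == 4 and _demo_list.is_full()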
class DataProcessor(ABC):
"""
Abstract class that provides methods for loading training, testing, development and unlabeled examples for a given
task
"""
@abstractmethod
def get_train_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the train set."""
pass
@abstractmethod
def get_dev_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the dev set."""
pass
@abstractmethod
def get_test_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the test set."""
pass
@abstractmethod
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the unlabeled set."""
pass
@abstractmethod
def get_labels(self) -> List[str]:
"""Get the list of labels for this data set."""
pass
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["contradiction", "entailment", "neutral"]
@staticmethod
def _create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[8]
text_b = line[9]
label = line[-1]
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
@staticmethod
def _read_tsv(input_file, quotechar=None):
with open(input_file, "r", encoding="utf-8-sig") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class MnliMismatchedProcessor(MnliProcessor):
"""Processor for the MultiNLI mismatched data set (GLUE version)."""
def get_dev_examples(self, data_dir):
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
class AgnewsProcessor(DataProcessor):
"""Processor for the AG news data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2", "3", "4"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, headline, body = row
guid = "%s-%s" % (set_type, idx)
text_a = headline.replace('\\', ' ')
text_b = body.replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class YahooAnswersProcessor(DataProcessor):
"""Processor for the Yahoo Answers data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, question_title, question_body, answer = row
guid = "%s-%s" % (set_type, idx)
text_a = ' '.join([question_title.replace('\\n', ' ').replace('\\', ' '),
question_body.replace('\\n', ' ').replace('\\', ' ')])
text_b = answer.replace('\\n', ' ').replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class YelpPolarityProcessor(DataProcessor):
"""Processor for the YELP binary classification set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, body = row
guid = "%s-%s" % (set_type, idx)
text_a = body.replace('\\n', ' ').replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, label=label)
examples.append(example)
return examples
class YelpFullProcessor(YelpPolarityProcessor):
"""Processor for the YELP full classification set."""
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_labels(self):
return ["1", "2", "3", "4", "5"]
class XStanceProcessor(DataProcessor):
"""Processor for the X-Stance data set."""
def __init__(self, language: str = None):
if language is not None:
assert language in ['de', 'fr']
self.language = language
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"))
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"))
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["FAVOR", "AGAINST"]
def _create_examples(self, path: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
label = example_json['label']
id_ = example_json['id']
text_a = example_json['question']
text_b = example_json['comment']
language = example_json['language']
if self.language is not None and language != self.language:
continue
example = InputExample(guid=id_, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class RteProcessor(DataProcessor):
"""Processor for the RTE data set."""
def __init__(self):
self.mnli_processor = MnliProcessor()
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["entailment", "not_entailment"]
def _create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis",
premise_name: str = "premise") -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line_idx, line in enumerate(f):
example_json = json.loads(line)
idx = example_json['idx']
if isinstance(idx, str):
try:
idx = int(idx)
except ValueError:
idx = line_idx
label = example_json.get('label')
guid = "%s-%s" % (set_type, idx)
text_a = example_json[premise_name]
text_b = example_json[hypothesis_name]
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx)
examples.append(example)
return examples
class AxGProcessor(RteProcessor):
"""Processor for the AX-G diagnostic data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "train")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "test")
class AxBProcessor(RteProcessor):
"""Processor for the AX-B diagnostic data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "train")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "test")
def _create_examples(self, path, set_type, hypothesis_name="sentence2", premise_name="sentence1"):
return super()._create_examples(path, set_type, hypothesis_name, premise_name)
class CbProcessor(RteProcessor):
"""Processor for the CB data set."""
def get_labels(self):
return ["entailment", "contradiction", "neutral"]
class WicProcessor(DataProcessor):
"""Processor for the WiC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["F", "T"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
if isinstance(idx, str):
idx = int(idx)
label = "T" if example_json.get('label') else "F"
guid = "%s-%s" % (set_type, idx)
text_a = example_json['sentence1']
text_b = example_json['sentence2']
meta = {'word': example_json['word']}
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx, meta=meta)
examples.append(example)
return examples
class WscProcessor(DataProcessor):
"""Processor for the WSC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["False", "True"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
label = str(example_json['label']) if 'label' in example_json else None
guid = "%s-%s" % (set_type, idx)
text_a = example_json['text']
meta = {
'span1_text': example_json['target']['span1_text'],
'span2_text': example_json['target']['span2_text'],
'span1_index': example_json['target']['span1_index'],
'span2_index': example_json['target']['span2_index']
}
# the indices in the dataset are wrong for some examples, so we manually fix them
span1_index, span1_text = meta['span1_index'], meta['span1_text']
span2_index, span2_text = meta['span2_index'], meta['span2_text']
words_a = text_a.split()
words_a_lower = text_a.lower().split()
words_span1_text = span1_text.lower().split()
span1_len = len(words_span1_text)
if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text:
for offset in [-1, +1]:
if words_a_lower[span1_index + offset:span1_index + span1_len + offset] == words_span1_text:
span1_index += offset
if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text:
logger.warning(f"Got '{words_a_lower[span1_index:span1_index + span1_len]}' but expected "
f"'{words_span1_text}' at index {span1_index} for '{words_a}'")
if words_a[span2_index] != span2_text:
for offset in [-1, +1]:
if words_a[span2_index + offset] == span2_text:
span2_index += offset
if words_a[span2_index] != span2_text and words_a[span2_index].startswith(span2_text):
words_a = words_a[:span2_index] \
+ [words_a[span2_index][:len(span2_text)], words_a[span2_index][len(span2_text):]] \
+ words_a[span2_index + 1:]
assert words_a[span2_index] == span2_text, \
f"Got '{words_a[span2_index]}' but expected '{span2_text}' at index {span2_index} for '{words_a}'"
text_a = ' '.join(words_a)
meta['span1_index'], meta['span2_index'] = span1_index, span2_index
example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx)
if set_type == 'train' and label != 'True':
continue
examples.append(example)
return examples
class BoolQProcessor(DataProcessor):
"""Processor for the BoolQ data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["False", "True"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
label = str(example_json['label']) if 'label' in example_json else None
guid = "%s-%s" % (set_type, idx)
text_a = example_json['passage']
text_b = example_json['question']
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx)
examples.append(example)
return examples
class CopaProcessor(DataProcessor):
"""Processor for the COPA data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
label = str(example_json['label']) if 'label' in example_json else None
idx = example_json['idx']
guid = "%s-%s" % (set_type, idx)
text_a = example_json['premise']
meta = {
'choice1': example_json['choice1'],
'choice2': example_json['choice2'],
'question': example_json['question']
}
example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx)
examples.append(example)
if set_type == 'train' or set_type == 'unlabeled':
mirror_examples = []
for ex in examples:
label = "1" if ex.label == "0" else "0"
meta = {
'choice1': ex.meta['choice2'],
'choice2': ex.meta['choice1'],
'question': ex.meta['question']
}
mirror_example = InputExample(guid=ex.guid + 'm', text_a=ex.text_a, label=label, meta=meta)
mirror_examples.append(mirror_example)
examples += mirror_examples
logger.info(f"Added {len(mirror_examples)} mirror examples, total size is {len(examples)}...")
return examples
class MultiRcProcessor(DataProcessor):
"""Processor for the MultiRC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
passage_idx = example_json['idx']
text = example_json['passage']['text']
questions = example_json['passage']['questions']
for question_json in questions:
question = question_json["question"]
question_idx = question_json['idx']
answers = question_json["answers"]
for answer_json in answers:
label = str(answer_json["label"]) if 'label' in answer_json else None
answer_idx = answer_json["idx"]
guid = f'{set_type}-p{passage_idx}-q{question_idx}-a{answer_idx}'
meta = {
'passage_idx': passage_idx,
'question_idx': question_idx,
'answer_idx': answer_idx,
'answer': answer_json["text"]
}
idx = [passage_idx, question_idx, answer_idx]
example = InputExample(guid=guid, text_a=text, text_b=question, label=label, meta=meta, idx=idx)
examples.append(example)
question_indices = list(set(example.meta['question_idx'] for example in examples))
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label "
f"distribution {list(label_distribution.items())}")
return examples
class RecordProcessor(DataProcessor):
"""Processor for the ReCoRD data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path, set_type, seed=42, max_train_candidates_per_question: int = 10) -> List[InputExample]:
examples = []
entity_shuffler = random.Random(seed)
with open(path, encoding='utf8') as f:
for idx, line in enumerate(f):
example_json = json.loads(line)
idx = example_json['idx']
text = example_json['passage']['text']
entities = set()
for entity_json in example_json['passage']['entities']:
start = entity_json['start']
end = entity_json['end']
entity = text[start:end + 1]
entities.add(entity)
entities = list(entities)
text = text.replace("@highlight\n", "- ") # we follow the GPT-3 paper wrt @highlight annotations
questions = example_json['qas']
for question_json in questions:
question = question_json['query']
question_idx = question_json['idx']
answers = set()
for answer_json in question_json.get('answers', []):
answer = answer_json['text']
answers.add(answer)
answers = list(answers)
if set_type == 'train':
# create a single example per *correct* answer
for answer_idx, answer in enumerate(answers):
candidates = [ent for ent in entities if ent not in answers]
if len(candidates) > max_train_candidates_per_question - 1:
entity_shuffler.shuffle(candidates)
candidates = candidates[:max_train_candidates_per_question - 1]
guid = f'{set_type}-p{idx}-q{question_idx}-a{answer_idx}'
meta = {
'passage_idx': idx,
'question_idx': question_idx,
'candidates': [answer] + candidates,
'answers': [answer]
}
ex_idx = [idx, question_idx, answer_idx]
example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta,
idx=ex_idx)
examples.append(example)
else:
# create just one example with *all* correct answers and *all* answer candidates
guid = f'{set_type}-p{idx}-q{question_idx}'
meta = {
'passage_idx': idx,
'question_idx': question_idx,
'candidates': entities,
'answers': answers
}
example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta)
examples.append(example)
question_indices = list(set(example.meta['question_idx'] for example in examples))
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label "
f"distribution {list(label_distribution.items())}")
return examples
PROCESSORS = {
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"agnews": AgnewsProcessor,
"yahoo": YahooAnswersProcessor,
"yelp-polarity": YelpPolarityProcessor,
"yelp-full": YelpFullProcessor,
"xstance-de": lambda: XStanceProcessor("de"),
"xstance-fr": lambda: XStanceProcessor("fr"),
"xstance": XStanceProcessor,
"wic": WicProcessor,
"rte": RteProcessor,
"cb": CbProcessor,
"wsc": WscProcessor,
"boolq": BoolQProcessor,
"copa": CopaProcessor,
"multirc": MultiRcProcessor,
"record": RecordProcessor,
"ax-g": AxGProcessor,
"ax-b": AxBProcessor,
} # type: Dict[str,Callable[[],DataProcessor]]
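# Usage sketch: PROCESSORS entries are zero-argument callables, which is why
# lambdas are used above to pre-bind the XStanceProcessor language argument.
_demo_processor = PROCESSORS['rte']()
assert _demo_processor.get_labels() == ['entailment', 'not_entailment']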
class MyTaskDataProcessor(DataProcessor):
"""
Example for a data processor.
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the name of the file containing the train examples
TRAIN_FILE_NAME = "train.tsv"
# Set this to the name of the file containing the dev examples
DEV_FILE_NAME = "dev.tsv"
# Set this to the name of the file containing the test examples
TEST_FILE_NAME = "test.tsv"
# Set this to the name of the file containing the unlabeled examples
UNLABELED_FILE_NAME = "unlabeled.tsv"
# Set this to a list of all labels in the train + test data
# LABELS = ["+1", "-1"]
LABELS = [1, 0]
# Set this to the column of the train/test csv files containing the input's text a
TEXT_COLUMN = 0
# Set this to the column of the train/test csv files containing the input's gold label
LABEL_COLUMN = 1
def get_train_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads train examples from a file with name `TRAIN_FILE_NAME` in the given directory.
:param data_dir: the directory in which the training data can be found
:return: a list of train examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TRAIN_FILE_NAME), "train")
def get_dev_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads dev examples from a file with name `DEV_FILE_NAME` in the given directory.
:param data_dir: the directory in which the dev data can be found
:return: a list of dev examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.DEV_FILE_NAME), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
"""
This method loads test examples from a file with name `TEST_FILE_NAME` in the given directory.
:param data_dir: the directory in which the test data can be found
:return: a list of test examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TEST_FILE_NAME), "test")
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""
This method loads unlabeled examples from a file with name `UNLABELED_FILE_NAME` in the given directory.
:param data_dir: the directory in which the unlabeled data can be found
:return: a list of unlabeled examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.UNLABELED_FILE_NAME), "unlabeled")
def get_labels(self) -> List[str]:
"""This method returns all possible labels for the task."""
return MyTaskDataProcessor.LABELS
def _create_examples(self, path, set_type, max_examples=-1, skip_first=0):
"""Creates examples for the training and dev sets."""
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter='\t')
for idx, row in enumerate(reader):
if idx != 0:  # skip the header row
guid = "%s-%s" % (set_type, idx)
label = int(row[MyTaskDataProcessor.LABEL_COLUMN])
text = row[MyTaskDataProcessor.TEXT_COLUMN]
# text_b = row[MyTaskDataProcessor.TEXT_B_COLUMN] if MyTaskDataProcessor.TEXT_B_COLUMN >= 0 else None
example = InputExample(guid=guid, text_a=text, label=label)
examples.append(example)
return examples
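# A minimal sketch of the TSV layout this processor expects: a header row
# (skipped via idx != 0), the text in column TEXT_COLUMN and an integer label
# in column LABEL_COLUMN. The file written below is hypothetical demo data.
def _write_demo_train_tsv(path='train_demo.tsv'):
    with open(path, 'w') as f:
        f.write('sentence\tlabel\n')
        f.write('a thrilling ride from start to finish\t1\n')
        f.write('two hours I will never get back\t0\n')
    return path
# demo_examples = MyTaskDataProcessor()._create_examples(_write_demo_train_tsv(), 'train')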
class MyTaskDataProcessor2(DataProcessor):
"""
Example for a data processor.
"""
# Set this to the name of the task
TASK_NAME = "my-task2"
# Set this to the name of the file containing the train examples
TRAIN_FILE_NAME = "train.tsv"
# Set this to the name of the file containing the dev examples
DEV_FILE_NAME = "dev.tsv"
# Set this to the name of the file containing the test examples
TEST_FILE_NAME = "test.tsv"
# Set this to the name of the file containing the unlabeled examples
UNLABELED_FILE_NAME = "unlabeled.tsv"
# Set this to a list of all labels in the train + test data
# LABELS = ["+1", "-1"]
LABELS = [1, 0]
# Set this to the column of the train/test csv files containing the input's text a
TEXT_COLUMN = 0
# Set this to the column of the train/test csv files containing the input's gold label
LABEL_COLUMN = 1
def get_train_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads train examples from a file with name `TRAIN_FILE_NAME` in the given directory.
:param data_dir: the directory in which the training data can be found
:return: a list of train examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor2.TRAIN_FILE_NAME), "train")
def get_dev_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads dev examples from a file with name `DEV_FILE_NAME` in the given directory.
:param data_dir: the directory in which the dev data can be found
:return: a list of dev examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor2.DEV_FILE_NAME), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
"""
This method loads test examples from a file with name `TEST_FILE_NAME` in the given directory.
:param data_dir: the directory in which the test data can be found
:return: a list of test examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor2.TEST_FILE_NAME), "test")
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""
This method loads unlabeled examples from a file with name `UNLABELED_FILE_NAME` in the given directory.
:param data_dir: the directory in which the unlabeled data can be found
:return: a list of unlabeled examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor2.UNLABELED_FILE_NAME), "unlabeled")
def get_labels(self) -> List[str]:
"""This method returns all possible labels for the task."""
return MyTaskDataProcessor2.LABELS
def _create_examples(self, path, set_type, max_examples=-1, skip_first=0):
"""Creates examples for the training and dev sets."""
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter='\t')
for idx, row in enumerate(reader):
if idx != 0:  # skip the header row
guid = "%s-%s" % (set_type, idx)
label = int(row[MyTaskDataProcessor2.LABEL_COLUMN])
text = row[MyTaskDataProcessor2.TEXT_COLUMN]
# text_b = row[MyTaskDataProcessor2.TEXT_B_COLUMN] if MyTaskDataProcessor2.TEXT_B_COLUMN >= 0 else None
example = InputExample(guid=guid, text_a=text, label=label)
examples.append(example)
return examples
# register the processor for this task with its name
PROCESSORS[MyTaskDataProcessor.TASK_NAME] = MyTaskDataProcessor
PROCESSORS[MyTaskDataProcessor2.TASK_NAME] = MyTaskDataProcessor2
PROCESSORS['autobest5'] = MyTaskDataProcessor2
TASK_HELPERS = {
"wsc": task_helpers.WscTaskHelper,
"multirc": task_helpers.MultiRcTaskHelper,
"copa": task_helpers.CopaTaskHelper,
"record": task_helpers.RecordTaskHelper,
}
METRICS = {
"cb": ["acc", "f1-macro"],
"multirc": ["acc", "f1", "em"]
}
DEFAULT_METRICS = ["acc"]
TRAIN_SET = "train"
DEV_SET = "dev"
TEST_SET = "test"
UNLABELED_SET = "unlabeled"
SET_TYPES = [TRAIN_SET, DEV_SET, TEST_SET, UNLABELED_SET]
def load_examples(task, data_dir: str, set_type: str, *_, num_examples: int = None,
num_examples_per_label: int = None, seed: int = 42) -> List[InputExample]:
"""Load examples for a given task."""
assert (num_examples is not None) ^ (num_examples_per_label is not None), \
"Exactly one of 'num_examples' and 'num_examples_per_label' must be set."
assert (not set_type == UNLABELED_SET) or (num_examples is not None), \
"For unlabeled data, 'num_examples_per_label' is not allowed"
processor = PROCESSORS[task]()
ex_str = f"num_examples={num_examples}" if num_examples is not None \
else f"num_examples_per_label={num_examples_per_label}"
logger.info(
f"Creating features from dataset file at {data_dir} ({ex_str}, set_type={set_type})"
)
if set_type == DEV_SET:
examples = processor.get_dev_examples(data_dir)
elif set_type == TEST_SET:
examples = processor.get_test_examples(data_dir)
elif set_type == TRAIN_SET:
examples = processor.get_train_examples(data_dir)
elif set_type == UNLABELED_SET:
examples = processor.get_unlabeled_examples(data_dir)
for example in examples:
example.label = processor.get_labels()[0]
else:
raise ValueError(f"'set_type' must be one of {SET_TYPES}, got '{set_type}' instead")
if num_examples is not None:
examples = _shuffle_and_restrict(examples, num_examples, seed)
elif num_examples_per_label is not None:
limited_examples = LimitedExampleList(processor.get_labels(), num_examples_per_label)
for example in examples:
limited_examples.add(example)
examples = limited_examples.to_list()
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} {set_type} examples with label dist.: {list(label_distribution.items())}")
return examples
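# Usage sketch (task name and data paths are hypothetical); exactly one of
# num_examples and num_examples_per_label may be given:
#
# train_examples = load_examples('my-task', 'data/my-task', TRAIN_SET, num_examples_per_label=16)
# unlabeled_examples = load_examples('my-task', 'data/my-task', UNLABELED_SET, num_examples=1000)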
from pet.modeling import *
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import os
import random
import statistics
from abc import ABC
from collections import defaultdict
from copy import deepcopy
from typing import List, Dict
import numpy as np
import torch
from sklearn.metrics import f1_score
from transformers.data.metrics import simple_accuracy
import log
from pet.utils import InputExample, exact_match, save_logits, save_predictions, softmax, LogitsList, set_seed, eq_div
from pet.wrapper import TransformerModelWrapper, SEQUENCE_CLASSIFIER_WRAPPER, WrapperConfig
logger = log.get_logger('root')
class PetConfig(ABC):
"""Abstract class for a PET configuration that can be saved to and loaded from a json file."""
def __repr__(self):
return repr(self.__dict__)
def save(self, path: str):
"""Save this config to a file."""
with open(path, 'w', encoding='utf8') as fh:
json.dump(self.__dict__, fh)
@classmethod
def load(cls, path: str):
"""Load a config from a file."""
cfg = cls.__new__(cls)
with open(path, 'r', encoding='utf8') as fh:
cfg.__dict__ = json.load(fh)
return cfg
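# Round-trip sketch: any PetConfig subclass serializes its __dict__ to JSON and
# is restored via __new__ without re-running __init__ (the demo subclass and
# file name are hypothetical).
class _DemoConfig(PetConfig):
    def __init__(self, learning_rate=5e-5):
        self.learning_rate = learning_rate
_DemoConfig(learning_rate=1e-4).save('demo_config.json')
assert _DemoConfig.load('demo_config.json').learning_rate == 1e-4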
class TrainConfig(PetConfig):
"""Configuration for training a model."""
def __init__(self, device: str = None, per_gpu_train_batch_size: int = 8, per_gpu_unlabeled_batch_size: int = 8,
n_gpu: int = 1, num_train_epochs: int = 3, max_steps: int = -1, gradient_accumulation_steps: int = 1,
weight_decay: float = 0.0, learning_rate: float = 5e-5, adam_epsilon: float = 1e-8,
warmup_steps: int = 0, max_grad_norm: float = 1, lm_training: bool = False, use_logits: bool = False,
alpha: float = 0.9999, temperature: float = 1):
"""
Create a new training config.
:param device: the device to use ('cpu' or 'gpu')
:param per_gpu_train_batch_size: the number of labeled training examples per batch and gpu
:param per_gpu_unlabeled_batch_size: the number of unlabeled examples per batch and gpu
:param n_gpu: the number of gpus to use
:param num_train_epochs: the number of epochs to train for
:param max_steps: the maximum number of steps to train for (overrides ``num_train_epochs``)
:param gradient_accumulation_steps: the number of steps to accumulate gradients for before performing an update
:param weight_decay: the weight decay to use
:param learning_rate: the maximum learning rate to use
:param adam_epsilon: the epsilon value for Adam
:param warmup_steps: the number of warmup steps to perform before reaching the maximum learning rate
:param max_grad_norm: the maximum norm for the gradient
:param lm_training: whether to perform auxiliary language modeling (only for MLMs)
:param use_logits: whether to use each training example's logits instead of its label (used for distillation)
:param alpha: the alpha parameter for auxiliary language modeling
:param temperature: the temperature for distillation
"""
self.device = device
self.per_gpu_train_batch_size = per_gpu_train_batch_size
self.per_gpu_unlabeled_batch_size = per_gpu_unlabeled_batch_size
self.n_gpu = n_gpu
self.num_train_epochs = num_train_epochs
self.max_steps = max_steps
self.gradient_accumulation_steps = gradient_accumulation_steps
self.weight_decay = weight_decay
self.learning_rate = learning_rate
self.adam_epsilon = adam_epsilon
self.warmup_steps = warmup_steps
self.max_grad_norm = max_grad_norm
self.lm_training = lm_training
self.use_logits = use_logits
self.alpha = alpha
self.temperature = temperature
class EvalConfig(PetConfig):
"""Configuration for evaluating a model."""
def __init__(self, device: str = None, n_gpu: int = 1, per_gpu_eval_batch_size: int = 8,
metrics: List[str] = None, decoding_strategy: str = 'default', priming: bool = False):
"""
Create a new evaluation config.
:param device: the device to use ('cpu' or 'gpu')
:param n_gpu: the number of gpus to use
:param per_gpu_eval_batch_size: the number of evaluation examples per batch and gpu
:param metrics: the evaluation metrics to use (default: accuracy only)
:param decoding_strategy: the decoding strategy for PET with multiple masks ('default', 'ltr', or 'parallel')
:param priming: whether to use priming
"""
self.device = device
self.n_gpu = n_gpu
self.per_gpu_eval_batch_size = per_gpu_eval_batch_size
self.metrics = metrics
self.decoding_strategy = decoding_strategy
self.priming = priming
class IPetConfig(PetConfig):
"""Configuration for iterative PET training."""
def __init__(self, generations: int = 3, logits_percentage: float = 0.25, scale_factor: float = 5,
n_most_likely: int = -1):
"""
Create a new iPET config.
:param generations: the number of generations to train
:param logits_percentage: the percentage of models to use for annotating training sets for the next generation
:param scale_factor: the factor by which the training set is increased for each generation
:param n_most_likely: If >0, in the first generation the n_most_likely examples per label are chosen even
if their predicted label is different
"""
self.generations = generations
self.logits_percentage = logits_percentage
self.scale_factor = scale_factor
self.n_most_likely = n_most_likely
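# Arithmetic sketch mirroring the num_new_examples computation in train_ipet
# below: with 10 training examples and scale_factor = 5, each generation is
# annotated with 10 * 5**(gen + 1) - 10 new examples.
assert [int(10 * 5 ** (g + 1) - 10) for g in range(3)] == [40, 240, 1240]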
def init_model(config: WrapperConfig) -> TransformerModelWrapper:
"""Initialize a new model from the given config."""
assert config.pattern_id is not None, 'A pattern_id must be set for initializing a new PET model'
model = TransformerModelWrapper(config)
return model
def train_ipet(ensemble_model_config: WrapperConfig, ensemble_train_config: TrainConfig,
ensemble_eval_config: EvalConfig, ipet_config: IPetConfig, final_model_config: WrapperConfig,
final_train_config: TrainConfig, final_eval_config: EvalConfig, pattern_ids: List[int], output_dir: str,
ensemble_repetitions: int = 3, final_repetitions: int = 1, reduction: str = 'wmean',
train_data: List[InputExample] = None, unlabeled_data: List[InputExample] = None,
eval_data: List[InputExample] = None, do_train: bool = True, do_eval: bool = True, seed: int = 42):
"""
Train and evaluate a new iPET model for a given task.
:param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
:param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
:param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
:param ipet_config: the iPET training configuration
:param final_model_config: the model configuration for the final distilled sequence classifier
:param final_train_config: the training configuration for the final distilled sequence classifier
:param final_eval_config: the evaluation configuration for the final distilled sequence classifier
:param pattern_ids: the ids of all PVPs to use
:param output_dir: the output directory
:param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
:param final_repetitions: the number of training repetitions for the final distilled sequence classifier
:param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
:param train_data: the training examples to use
:param unlabeled_data: the unlabeled examples to use
:param eval_data: the evaluation examples to use
:param do_train: whether to perform training
:param do_eval: whether to perform evaluation
:param seed: the random seed to use
"""
for gen in range(ipet_config.generations):
gen_output_dir = os.path.join(output_dir, f'g{gen}')
# Step 1: Train an ensemble of models corresponding to individual patterns
ipet_data_dir = os.path.join(output_dir, f'g{gen - 1}', 'next-gen-train-data') if gen > 0 else None
train_pet_ensemble(ensemble_model_config, ensemble_train_config, ensemble_eval_config, pattern_ids,
gen_output_dir, ipet_data_dir=ipet_data_dir,
repetitions=ensemble_repetitions, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=do_train, do_eval=do_eval, save_unlabeled_logits=True)
# Step 2: Use the model to annotate examples for the next generation
original_data_size = len(train_data) if train_data else 10 / ipet_config.scale_factor
num_new_examples = int(original_data_size * (ipet_config.scale_factor ** (gen + 1))
- (len(train_data) if train_data else 0))
generate_ipet_train_sets(train_data=train_data, unlabeled_data=unlabeled_data,
labels=ensemble_model_config.label_list, logits_dir=gen_output_dir,
output_dir=os.path.join(gen_output_dir, 'next-gen-train-data'), reduction=reduction,
num_new_examples=num_new_examples, logits_percentage=ipet_config.logits_percentage,
n_most_likely=ipet_config.n_most_likely if gen == 0 else -1, seed=seed)
# Step 3: Merge the annotations created by each individual model
logits_dir = os.path.join(output_dir, f'g{ipet_config.generations - 1}')
logits_file = os.path.join(logits_dir, 'unlabeled_logits.txt')
merge_logits(logits_dir, logits_file, reduction)
logits = LogitsList.load(logits_file).logits
assert len(logits) == len(unlabeled_data)
logger.info("Got {} logits from file {}".format(len(logits), logits_file))
for example, example_logits in zip(unlabeled_data, logits):
example.logits = example_logits
# Step 4: Train the final sequence classifier model
final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
final_train_config.use_logits = True
train_classifier(final_model_config, final_train_config, final_eval_config, os.path.join(output_dir, 'final'),
repetitions=final_repetitions, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=do_train, do_eval=do_eval, seed=seed)
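# Usage sketch (the config objects below are hypothetical; only the call shape
# follows the signature above):
#
#   ipet_cfg = IPetConfig(generations=3)
#   train_ipet(ensemble_model_cfg, ensemble_train_cfg, ensemble_eval_cfg, ipet_cfg,
#              final_model_cfg, final_train_cfg, final_eval_cfg,
#              pattern_ids=[0, 1], output_dir='output/ipet',
#              train_data=train_examples, unlabeled_data=unlabeled_examples,
#              eval_data=eval_examples)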
def train_pet(ensemble_model_config: WrapperConfig, ensemble_train_config: TrainConfig,
ensemble_eval_config: EvalConfig, final_model_config: WrapperConfig, final_train_config: TrainConfig,
final_eval_config: EvalConfig, pattern_ids: List[int], output_dir: str, ensemble_repetitions: int = 3,
final_repetitions: int = 1, reduction: str = 'wmean', train_data: List[InputExample] = None,
unlabeled_data: List[InputExample] = None, eval_data: List[InputExample] = None, do_train: bool = True,
do_eval: bool = True, no_distillation: bool = False, seed: int = 42):
"""
Train and evaluate a new PET model for a given task.
:param ensemble_model_config: the model configuration for each model corresponding to an individual PVP
:param ensemble_train_config: the training configuration for each model corresponding to an individual PVP
:param ensemble_eval_config: the evaluation configuration for each model corresponding to an individual PVP
:param final_model_config: the model configuration for the final distilled sequence classifier
:param final_train_config: the training configuration for the final distilled sequence classifier
:param final_eval_config: the evaluation configuration for the final distilled sequence classifier
:param pattern_ids: the ids of all PVPs to use
:param output_dir: the output directory
:param ensemble_repetitions: the number of training repetitions for each model corresponding to an individual PVP
:param final_repetitions: the number of training repetitions for the final distilled sequence classifier
:param reduction: the reduction strategy for merging predictions, either 'mean' or 'wmean'
:param train_data: the training examples to use
:param unlabeled_data: the unlabeled examples to use
:param eval_data: the evaluation examples to use
:param do_train: whether to perform training
:param do_eval: whether to perform evaluation
:param no_distillation: if true, no distillation is performed
:param seed: the random seed to use
"""
# Step 1: Train an ensemble of models corresponding to individual patterns
train_pet_ensemble(ensemble_model_config, ensemble_train_config, ensemble_eval_config, pattern_ids, output_dir,
repetitions=ensemble_repetitions, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=do_train, do_eval=do_eval,
save_unlabeled_logits=not no_distillation, seed=seed)
if no_distillation:
return
# Step 2: Merge the annotations created by each individual model
logits_file = os.path.join(output_dir, 'unlabeled_logits.txt')
merge_logits(output_dir, logits_file, reduction)
logits = LogitsList.load(logits_file).logits
assert len(logits) == len(unlabeled_data)
logger.info("Got {} logits from file {}".format(len(logits), logits_file))
for example, example_logits in zip(unlabeled_data, logits):
example.logits = example_logits
# Step 3: Train the final sequence classifier model
final_model_config.wrapper_type = SEQUENCE_CLASSIFIER_WRAPPER
final_train_config.use_logits = True
train_classifier(final_model_config, final_train_config, final_eval_config, os.path.join(output_dir, 'final'),
repetitions=final_repetitions, train_data=train_data, unlabeled_data=unlabeled_data,
eval_data=eval_data, do_train=do_train, do_eval=do_eval, seed=seed)
def train_classifier(model_config: WrapperConfig, train_config: TrainConfig, eval_config: EvalConfig, output_dir: str,
repetitions: int = 3, train_data: List[InputExample] = None,
unlabeled_data: List[InputExample] = None, eval_data: List[InputExample] = None,
do_train: bool = True, do_eval: bool = True, seed: int = 42):
"""
Train and evaluate a sequence classification model.
:param model_config: the model configuration to use
:param train_config: the training configuration to use
:param eval_config: the evaluation configuration to use
:param output_dir: the output directory
:param repetitions: the number of training repetitions
:param train_data: the training examples to use
:param unlabeled_data: the unlabeled examples to use
:param eval_data: the evaluation examples to use
:param do_train: whether to perform training
:param do_eval: whether to perform evaluation
:param seed: the random seed to use
"""
train_pet_ensemble(model_config, train_config, eval_config, pattern_ids=[0], output_dir=output_dir,
repetitions=repetitions,
train_data=train_data, unlabeled_data=unlabeled_data, eval_data=eval_data, do_train=do_train,
do_eval=do_eval, seed=seed)
def train_pet_ensemble(model_config: WrapperConfig, train_config: TrainConfig, eval_config: EvalConfig,
pattern_ids: List[int], output_dir: str, ipet_data_dir: str = None, repetitions: int = 3,
train_data: List[InputExample] = None, unlabeled_data: List[InputExample] = None,
eval_data: List[InputExample] = None, do_train: bool = True, do_eval: bool = True,
save_unlabeled_logits: bool = False, seed: int = 42):
"""
Train and evaluate an ensemble of PET models without knowledge distillation.
:param model_config: the model configuration to use
:param train_config: the training configuration to use
:param eval_config: the evaluation configuration to use
:param pattern_ids: the ids of all PVPs to use
:param output_dir: the output directory
:param ipet_data_dir: optional directory containing additional training data for iPET
:param repetitions: the number of training repetitions
:param train_data: the training examples to use
:param unlabeled_data: the unlabeled examples to use
:param eval_data: the evaluation examples to use
:param do_train: whether to perform training
:param do_eval: whether to perform evaluation
:param save_unlabeled_logits: whether logits for unlabeled examples should be saved in a file ``logits.txt``. This
is required for both iPET and knowledge distillation.
:param seed: the random seed to use
"""
results = defaultdict(lambda: defaultdict(list))
set_seed(seed)
for pattern_id in pattern_ids:
for iteration in range(repetitions):
model_config.pattern_id = pattern_id
results_dict = {}
pattern_iter_output_dir = "{}/p{}-i{}".format(output_dir, pattern_id, iteration)
if os.path.exists(pattern_iter_output_dir):
logger.warning(f"Path {pattern_iter_output_dir} already exists, skipping it...")
continue
os.makedirs(pattern_iter_output_dir)
wrapper = init_model(model_config)
# Training
if do_train:
if ipet_data_dir:
p = os.path.join(ipet_data_dir, 'p{}-i{}-train.bin'.format(pattern_id, iteration))
ipet_train_data = InputExample.load_examples(p)
for example in ipet_train_data:
example.logits = None
else:
ipet_train_data = None
results_dict.update(train_single_model(wrapper, train_data, train_config, eval_config,
ipet_train_data=ipet_train_data,
unlabeled_data=unlabeled_data))
with open(os.path.join(pattern_iter_output_dir, 'results.txt'), 'w') as fh:
fh.write(str(results_dict))
logger.info("Saving trained model at {}...".format(pattern_iter_output_dir))
wrapper.save(pattern_iter_output_dir)
train_config.save(os.path.join(pattern_iter_output_dir, 'train_config.json'))
eval_config.save(os.path.join(pattern_iter_output_dir, 'eval_config.json'))
logger.info("Saving complete")
if save_unlabeled_logits:
logits = evaluate(wrapper, unlabeled_data, eval_config)['logits']
save_logits(os.path.join(pattern_iter_output_dir, 'logits.txt'), logits)
if not do_eval:
wrapper.model = None
wrapper = None
torch.cuda.empty_cache()
# Evaluation
if do_eval:
logger.info("Starting evaluation...")
if not wrapper:
wrapper = TransformerModelWrapper.from_pretrained(pattern_iter_output_dir)
eval_result = evaluate(wrapper, eval_data, eval_config, priming_data=train_data)
save_predictions(os.path.join(pattern_iter_output_dir, 'predictions.jsonl'), wrapper, eval_result)
save_logits(os.path.join(pattern_iter_output_dir, 'eval_logits.txt'), eval_result['logits'])
scores = eval_result['scores']
logger.info("--- RESULT (pattern_id={}, iteration={}) ---".format(pattern_id, iteration))
logger.info(scores)
results_dict['test_set_after_training'] = scores
with open(os.path.join(pattern_iter_output_dir, 'results.json'), 'w') as fh:
json.dump(results_dict, fh)
for metric, value in scores.items():
results[metric][pattern_id].append(value)
wrapper.model = None
wrapper = None
torch.cuda.empty_cache()
if do_eval:
logger.info("=== OVERALL RESULTS ===")
_write_results(os.path.join(output_dir, 'result_test.txt'), results)
else:
logger.info("=== ENSEMBLE TRAINING COMPLETE ===")
def train_single_model(model: TransformerModelWrapper, train_data: List[InputExample], config: TrainConfig,
eval_config: EvalConfig = None, ipet_train_data: List[InputExample] = None,
unlabeled_data: List[InputExample] = None, return_train_set_results: bool = True):
"""
Train a single model.
:param model: the model to train
:param train_data: the training examples to use
:param config: the training config
:param eval_config: the evaluation config
:param ipet_train_data: an optional list of iPET training examples to use
:param unlabeled_data: an optional list of unlabeled examples to use
:param return_train_set_results: whether results on the train set before and after training should be computed and
returned
:return: a dictionary containing the global step, average loss and (optionally) results on the train set
"""
device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
if not ipet_train_data:
ipet_train_data = []
results_dict = {}
model.model.to(device)
if train_data and return_train_set_results:
results_dict['train_set_before_training'] = evaluate(model, train_data, eval_config)['scores']['acc']
all_train_data = train_data + ipet_train_data
if not all_train_data and not config.use_logits:
logger.warning('Training method was called without training examples')
else:
global_step, tr_loss = model.train(
all_train_data, device,
per_gpu_train_batch_size=config.per_gpu_train_batch_size,
per_gpu_unlabeled_batch_size=config.per_gpu_unlabeled_batch_size,
n_gpu=config.n_gpu,
num_train_epochs=config.num_train_epochs,
max_steps=config.max_steps,
gradient_accumulation_steps=config.gradient_accumulation_steps,
weight_decay=config.weight_decay,
learning_rate=config.learning_rate,
adam_epsilon=config.adam_epsilon,
warmup_steps=config.warmup_steps,
max_grad_norm=config.max_grad_norm,
unlabeled_data=unlabeled_data if config.lm_training or config.use_logits else None,
lm_training=config.lm_training,
use_logits=config.use_logits,
alpha=config.alpha,
temperature=config.temperature
)
results_dict['global_step'] = global_step
results_dict['average_loss'] = tr_loss
if train_data and return_train_set_results:
results_dict['train_set_after_training'] = evaluate(model, train_data, eval_config)['scores']['acc']
return results_dict
def evaluate(model: TransformerModelWrapper, eval_data: List[InputExample], config: EvalConfig,
priming_data: List[InputExample] = None) -> Dict:
"""
Evaluate a model.
:param model: the model to evaluate
:param eval_data: the examples for evaluation
:param config: the evaluation config
:param priming_data: an optional list of priming data to use
:return: a dictionary containing the model's logits, predictions and (if any metrics are given) scores
"""
if config.priming:
for example in eval_data:
example.meta['priming_data'] = priming_data
metrics = config.metrics if config.metrics else ['acc']
device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
model.model.to(device)
results = model.eval(eval_data, device, per_gpu_eval_batch_size=config.per_gpu_eval_batch_size,
n_gpu=config.n_gpu, decoding_strategy=config.decoding_strategy, priming=config.priming)
predictions = np.argmax(results['logits'], axis=1)
scores = {}
for metric in metrics:
if metric == 'acc':
scores[metric] = simple_accuracy(predictions, results['labels'])
elif metric == 'f1':
scores[metric] = f1_score(results['labels'], predictions)
elif metric == 'f1-macro':
scores[metric] = f1_score(results['labels'], predictions, average='macro')
elif metric == 'em':
scores[metric] = exact_match(predictions, results['labels'], results['question_ids'])
else:
raise ValueError(f"Metric '{metric}' not implemented")
results['scores'] = scores
results['predictions'] = predictions
return results
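# Usage sketch (hypothetical objects; `evaluate` only requires that the config
# exposes the fields read above, so the constructor arguments shown here are an
# assumption):
#
#   eval_cfg = EvalConfig(device='cuda', n_gpu=1, per_gpu_eval_batch_size=8,
#                         metrics=['acc', 'f1-macro'])
#   scores = evaluate(wrapper, dev_examples, eval_cfg)['scores']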
def _write_results(path: str, results: Dict):
with open(path, 'w') as fh:
for metric in results.keys():
for pattern_id, values in results[metric].items():
mean = statistics.mean(values)
stdev = statistics.stdev(values) if len(values) > 1 else 0
result_str = "{}-p{}: {} +- {}".format(metric, pattern_id, mean, stdev)
logger.info(result_str)
fh.write(result_str + '\n')
for metric in results.keys():
all_results = [result for pattern_results in results[metric].values() for result in pattern_results]
all_mean = statistics.mean(all_results)
all_stdev = statistics.stdev(all_results) if len(all_results) > 1 else 0
result_str = "{}-all-p: {} +- {}".format(metric, all_mean, all_stdev)
logger.info(result_str)
fh.write(result_str + '\n')
def merge_logits(logits_dir: str, output_file: str, reduction: str):
"""
Merge the logits predicted for unlabeled examples by multiple models.
:param logits_dir: a directory for which each sub-directory corresponds to a pretrained model and contains
both a file ``results.txt`` containing that model's results on the training set and a file ``logits.txt``
containing that model's predictions for the unlabeled data.
:param output_file: the file to which the merged logits for all unlabeled examples are written.
:param reduction: the strategy for merging logits, either 'mean' or 'wmean'. For 'mean', all models contribute
equally, for 'wmean', each model's contribution is proportional to its accuracy on the training set before
training.
"""
subdirs = next(os.walk(logits_dir))[1]
logger.info("Found the following {} subdirectories: {}".format(len(subdirs), subdirs))
all_logits_lists = []
for subdir in subdirs:
results_file = os.path.join(logits_dir, subdir, 'results.txt')
logits_file = os.path.join(logits_dir, subdir, 'logits.txt')
logits = []
if not os.path.exists(results_file) or not os.path.exists(logits_file):
logger.warning(f"Skipping subdir '{subdir}' because 'results.txt' or 'logits.txt' not found")
continue
if reduction == 'mean':
result_train = 1
else:
with open(results_file, 'r') as fh:
results = ast.literal_eval(fh.read())
result_train = results['train_set_before_training']
with open(logits_file, 'r') as fh:
for line in fh.read().splitlines():
example_logits = [float(x) for x in line.split()]
logits.append(example_logits)
logger.info("File {}: Score = {}, #Logits = {}, #Labels = {}".format(
results_file, result_train, len(logits), len(logits[0])))
loglist = LogitsList(score=result_train, logits=logits)
all_logits_lists.append(loglist)
merged_loglist = merge_logits_lists(all_logits_lists, reduction=reduction)
merged_loglist.save(output_file)
def merge_logits_lists(logits_lists: List[LogitsList], reduction: str = 'mean') -> LogitsList:
"""
Merge a list of :class:`LogitsList` objects.
:param logits_lists: the lists to merge
:param reduction: the strategy for merging logits, either 'mean' or 'wmean'. For 'mean', all models contribute
equally, for 'wmean', each model's contribution is proportional to its accuracy on the training set before
training.
:return: the merged list
"""
assert len(set(len(ll.logits) for ll in logits_lists)) == 1
logits = np.array([ll.logits for ll in logits_lists])
weights = np.array([ll.score for ll in logits_lists])
if reduction == 'mean':
logits = np.mean(logits, axis=0).tolist()
elif reduction == 'wmean':
logits = np.average(logits, axis=0, weights=weights).tolist()
else:
raise ValueError("Reduction strategy '{}' not implemented".format(reduction))
return LogitsList(score=-1, logits=logits)
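# Worked example of the two reduction strategies (toy numbers): two models with
# training-set scores 0.5 and 1.0 predict logits [1, 3] and [3, 1] for the same
# example. 'mean' yields [2.0, 2.0], while 'wmean' weights by score:
# np.average([[1, 3], [3, 1]], axis=0, weights=[0.5, 1.0]) -> [2.33..., 1.66...]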
def generate_ipet_train_sets(train_data: List[InputExample], unlabeled_data: List[InputExample], labels: List[str],
logits_dir: str, output_dir: str, reduction: str, num_new_examples: int,
logits_percentage: float, n_most_likely: int = -1, seed: int = 42):
"""
Generate training sets for the next generation of iPET models.
:param train_data: the training examples
:param unlabeled_data: the unlabeled examples
:param labels: the list of all possible labels
:param logits_dir: the directory that contains the predictions of all models in the current generation for the
unlabeled data.
:param output_dir: the output directory
:param reduction: the strategy for merging logits, either 'mean' or 'wmean'. For 'mean', all models contribute
equally, for 'wmean', each model's contribution is proportional to its accuracy on the training set before
training.
:param num_new_examples: the number of new examples to create
:param logits_percentage: the percentage of models to use for annotating training sets for the next generation
:param n_most_likely: if >0, in the first generation the n_most_likely examples with the highest logits are chosen
for each label, even if their predicted label is different
:param seed: the random seed to use
"""
subdirs = next(os.walk(logits_dir))[1]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info("Found the following {} subdirectories: {}".format(len(subdirs), subdirs))
if train_data:
train_examples_per_label = [sum(1 for ex in train_data if ex.label == label) for label in labels]
multiplier = num_new_examples / len(train_data)
examples_per_label = [int(epl * multiplier) for epl in train_examples_per_label]
logger.info(f"Example distribution in the original dataset: {train_examples_per_label}")
else:
examples_per_label = eq_div(num_new_examples, len(labels))
logger.info(f"Target distribution for the new dataset: {examples_per_label}")
for example in unlabeled_data:
example.label, example.logits = None, None
logits_lists = {}
rng = random.Random(seed)
rng_np = np.random.RandomState(seed)
for subdir in subdirs:
results_file = os.path.join(logits_dir, subdir, 'results.txt')
logits_file = os.path.join(logits_dir, subdir, 'logits.txt')
logits = []
if not os.path.exists(results_file) or not os.path.exists(logits_file):
logger.warning(f"Skipping subdir '{subdir}' because 'results.txt' or 'logits.txt' not found")
continue
if reduction == 'mean':
result_train = 1
else:
with open(results_file, 'r') as fh:
results = ast.literal_eval(fh.read())
result_train = results['train_set_before_training']
with open(logits_file, 'r') as fh:
for line in fh.read().splitlines():
example_logits = [float(x) for x in line.split()]
logits.append(example_logits)
logger.info("File {}: Score = {}, #Logits = {}, #Labels = {}".format(
results_file, result_train, len(logits), len(logits[0])))
loglist = LogitsList(score=result_train, logits=logits)
logits_lists[subdir] = loglist
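# Each model's next-generation training set is built only from the predictions
# of the *other* models (leave-one-out), so no model is trained on examples
# that it annotated itself.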
for subdir in subdirs:
other_logits_lists = [ll for sd, ll in logits_lists.items() if sd != subdir]
subdir_train_set = generate_ipet_train_set(
other_logits_lists, labels=labels, original_data=unlabeled_data, examples_per_label=examples_per_label,
logits_percentage=logits_percentage, reduction=reduction, n_most_likely=n_most_likely, rng=rng,
rng_np=rng_np
)
InputExample.save_examples(subdir_train_set,
os.path.join(output_dir, subdir + '-train.bin'))
def generate_ipet_train_set(logits_lists: List[LogitsList], labels: List[str], original_data: List[InputExample],
examples_per_label: List[int], logits_percentage: float, reduction: str = 'mean',
n_most_likely: int = -1, rng=None, rng_np=None) -> List[InputExample]:
"""
Generate a single training set for the next generation of iPET models.
:param logits_lists: predictions from the previous generation of models
:param labels: all task labels
:param original_data: the unlabeled examples that the logits_lists were computed for
:param examples_per_label: the number of examples per label to create
:param logits_percentage: the percentage of models/logits to choose
:param reduction: the reduction strategy ('wmean' or 'mean')
:param n_most_likely: if >0, for each label the n_most_likely examples with the highest logits are chosen
:param rng: the random number generator to use for non-numpy operations
:param rng_np: the random number generator to use for numpy operations
:return: a list of input examples that serves as training set for the next generation
"""
assert len(set(len(ll.logits) for ll in logits_lists)) == 1
if not rng:
rng = random.Random()
if not rng_np:
rng_np = np.random.RandomState()
num_logits_lists = round(len(logits_lists) * logits_percentage)
logits_lists = rng.sample(logits_lists, k=num_logits_lists)
logits = np.array([ll.logits for ll in logits_lists])
weights = np.array([ll.score for ll in logits_lists])
if reduction == 'mean':
logits = np.mean(logits, axis=0)
logits = softmax(logits, axis=1).tolist()
elif reduction == 'wmean':
logits = np.average(logits, axis=0, weights=weights)
logits = softmax(logits, axis=1).tolist()
else:
raise ValueError("Reduction strategy '{}' not implemented".format(reduction))
assert len(logits) == len(original_data)
for lgs, example in zip(logits, original_data):
example.logits = lgs
example.label = labels[np.argmax(example.logits).item()]
train_set = []
for idx, label in enumerate(labels):
if n_most_likely <= 0:
examples = [ex for ex in original_data if ex.label == label]
logger.info("There are {} examples for label {}".format(len(examples), label))
if not examples:
logger.warning("Found no examples with label {}, skipping this label".format(label))
continue
while len(examples) < examples_per_label[idx]:
# upsample examples if there are too few
examples.extend(ex for ex in original_data if ex.label == label)
else:
examples = [(ex.logits[idx], ex_idx, ex) for ex_idx, ex in enumerate(original_data)]
examples.sort(reverse=True)
examples = [ex for score, ex_idx, ex in examples[:n_most_likely]]
examples = [deepcopy(ex) for ex in examples]
for example in examples:
example.logits = [example.logits[idx]]
example.label = label
label_examples = _draw_examples_by_label_probability(
examples=examples, num_examples=examples_per_label[idx], rng=rng_np)
train_set.extend(label_examples)
return train_set
def _draw_examples_by_label_probability(examples: List[InputExample], num_examples: int, rng) -> List[InputExample]:
label_probabilities = [max(example.logits) for example in examples]
sum_label_probabilities = sum(label_probabilities)
label_probabilities = [p / sum_label_probabilities for p in label_probabilities]
return rng.choice(examples, size=num_examples, replace=False, p=label_probabilities).tolist()
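# Sampling sketch (toy values): for three candidate examples whose winning-label
# probabilities are [0.8, 0.1, 0.1], the values are normalized to a distribution
# and examples are drawn without replacement, so confident predictions are
# selected more often:
#
#   rng = np.random.RandomState(42)
#   picked = rng.choice(examples, size=2, replace=False, p=[0.8, 0.1, 0.1])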
import argparse
import os
import json
from collections import Counter
from typing import Dict, List
import numpy as np
import random
import torch
from transformers import PreTrainedTokenizer, RobertaTokenizer
from pet.tasks import PROCESSORS, load_examples, TRAIN_SET
from pet.utils import InputExample, eq_div
from pet.wrapper import TransformerModelWrapper, MODEL_CLASSES, WrapperConfig
import log
logger = log.get_logger('root')
def filter_words(tokens: List[str], word_counts=None, max_words: int = -1):
"""
Given a list of tokens, return a reduced list that contains only tokens corresponding to actual words; if word
counts are given, only the most frequent of these words are kept.
:param tokens: the list of tokens to filter
:param word_counts: a dictionary mapping words to their number of occurrences
:param max_words: if set to a value >0, only the `max_words` most frequent words according to `word_counts` are kept
:return: the filtered list of tokens
"""
tokens = [word for word in tokens if word[0] == 'Ġ' and len([char for char in word[1:] if char.isalpha()]) >= 2]
if word_counts and max_words > 0:
tokens = sorted(tokens, key=lambda word: word_counts[word[1:]], reverse=True)[:max_words]
return tokens
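# In RoBERTa's byte-level BPE vocabulary, tokens that start a new word carry a
# leading 'Ġ' (the encoded space), so the filter above keeps word-initial tokens
# with at least two alphabetic characters (toy example):
#
#   filter_words(['Ġgood', 'Ġbad', 'ing', 'Ġa'])  # -> ['Ġgood', 'Ġbad']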
def get_word_to_id_map(tokenizer: PreTrainedTokenizer, word_counts=None, max_words: int = -1):
"""
Return a mapping from all tokens to their internal ids for a given tokenizer
:param tokenizer: the tokenizer
:param word_counts: a dictionary mapping words to their number of occurrences
:param max_words: if set to a value >0, only the `max_words` most frequent words according to `word_counts` are kept
:return: the mapping from words (with the leading 'Ġ' stripped) to their token ids
"""
if not isinstance(tokenizer, RobertaTokenizer):
raise ValueError("this function currently only supports instances of 'RobertaTokenizer'")
words = filter_words(tokenizer.encoder.keys(), word_counts, max_words)
word2id = {word[1:]: tokenizer.convert_tokens_to_ids(word) for word in words}
logger.info(f"There are {len(word2id)} words left after filtering non-word tokens")
return word2id
class AutomaticVerbalizerSearch:
def __init__(self, word2idx: Dict[str, int], labels: List[str], logits_list: List[np.ndarray],
expected: Dict[str, np.ndarray]):
self.word2idx = word2idx
self.labels = labels
self.expected = expected
logits_list = [np.exp(logits) for logits in logits_list]
self.probs_list = [logits / np.expand_dims(np.sum(logits, axis=1), axis=1) for logits in logits_list]
def _get_candidates(self, num_candidates: int) -> Dict[str, List[str]]:
if num_candidates <= 0:
return {label: self.word2idx.keys() for label in self.labels}
scores = {label: Counter() for label in self.labels}
for label in self.labels:
for probs in self.probs_list:
for word, idx in self.word2idx.items():
score = np.sum(np.log(probs[:, idx]) * self.expected[label])
scores[label][word] += score
return {label: [w for w, _ in scores[label].most_common(num_candidates)] for label in self.labels}
def _get_top_words(self, candidates: Dict[str, List[str]], normalize: bool = True, words_per_label: int = 10,
score_fct: str = 'llr') -> Dict[str, List[str]]:
scores = {label: Counter() for label in self.labels}
for label in self.labels:
for probs in self.probs_list:
for word in candidates[label]:
idx = self.word2idx[word]
if score_fct == 'llr':
scores[label][word] += self.log_likelihood_ratio(probs[:, idx], self.expected[label], normalize)
elif score_fct == 'ce':
scores[label][word] += self.cross_entropy(probs[:, idx], self.expected[label], normalize)
else:
raise ValueError(f"Score function '{score_fct}' not implemented")
return {label: [w for w, _ in scores[label].most_common(words_per_label)] for label in self.labels}
@staticmethod
def log_likelihood_ratio(predictions: np.ndarray, expected: np.ndarray, normalize: bool) -> float:
scale_factor = sum(1 - expected) / sum(expected) if normalize else 1
pos_score = scale_factor * (np.sum(np.log(predictions) * expected) - np.sum(np.log(1 - predictions) * expected))
neg_score = np.sum(np.log(1 - predictions) * (1 - expected)) - np.sum(np.log(predictions) * (1 - expected))
return pos_score + neg_score
@staticmethod
def cross_entropy(predictions: np.ndarray, expected: np.ndarray, normalize: bool) -> float:
scale_factor = sum(1 - expected) / sum(expected) if normalize else 1
pos_score = scale_factor * np.sum(np.log(predictions) * expected)
neg_score = np.sum(np.log(1 - predictions) * (1 - expected))
return pos_score + neg_score
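# Both scores reward words whose predicted probability is high exactly on the
# examples that carry the target label (expected == 1) and low elsewhere; the
# optional scale factor balances the two terms when labels are imbalanced, e.g.
# for expected = [1, 0, 0] it is sum(1 - expected) / sum(expected) = 2.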
def find_verbalizer(self, words_per_label: int = 10, num_candidates: int = 1000, normalize: bool = True,
score_fct: str = 'llr'):
if score_fct == 'random':
return {label: random.sample(self.word2idx.keys(), words_per_label) for label in self.labels}
candidates = self._get_candidates(num_candidates=num_candidates)
return self._get_top_words(candidates=candidates, normalize=normalize, words_per_label=words_per_label,
score_fct=score_fct)
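# Usage sketch (hypothetical inputs; shapes follow the constructor above):
# word2idx maps candidate words to vocabulary ids, logits_list holds one
# [num_examples x vocab_size] array per pattern, and expected maps each label
# to a 0/1 indicator vector over the training examples:
#
#   avs = AutomaticVerbalizerSearch(word2idx, ['0', '1'], [pattern_logits], expected)
#   verbalizer = avs.find_verbalizer(words_per_label=10, score_fct='llr')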
def main():
parser = argparse.ArgumentParser()
# required parameters
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory. The verbalizers are written to a file 'verbalizer.json' in this directory.")
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the data files for the task.")
parser.add_argument("--model_type", default=None, type=str, required=True,
help="The model type")
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name")
parser.add_argument("--task_name", default=None, type=str, required=True,
help="The name of the task to train selected in the list: " + ", ".join(PROCESSORS.keys()))
# verbalizer search hyperparameters
parser.add_argument("--normalize", action='store_true',
help="Whether to normalize the loss as proposed in the paper. It is recommended to set this to 'true'.")
parser.add_argument("--combine_patterns", action='store_true',
help="If set to true, a single joint verbalizer is searched for all patterns")
parser.add_argument("--num_candidates", default=1000, type=int,
help="The number of candidate tokens to consider as verbalizers (see Section 4.1 of the paper)")
parser.add_argument("--words_per_label", default=10, type=int,
help="The number of verbalizer tokens to assign to each label")
parser.add_argument("--score_fct", default='llr', choices=['llr', 'ce', 'random'],
help="The function used to score verbalizers. Choices are: the log-likelihood ratio loss proposed in the paper "
"('llr'), cross-entropy loss ('ce') and 'random', which assigns random tokens to each label.")
# other optional parameters
parser.add_argument("--train_examples", default=50, type=int,
help="The total number of train examples to use, where -1 equals all examples.")
parser.add_argument("--pattern_ids", default=[0], type=int, nargs='+',
help="The ids of the PVPs to be used")
parser.add_argument("--max_seq_length", default=256, type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument("--words_file", default=None, type=str,
help="Path to a file containing (unlabeled) texts from the task's domain. This text is used to compute "
"verbalization candidates by selecting the most frequent words.")
parser.add_argument("--max_words", default=10000, type=int,
help="Only the 10,000 tokens that occur most frequently in the task’s unlabeled data (see --words_file) are "
"considered as verbalization candidates")
parser.add_argument("--additional_input_examples", type=str,
help="An optional path to an additional set of input examples (e.g., obtained using iPET)")
parser.add_argument("--seed", default=42, type=int,
help="random seed for initialization")
args = parser.parse_args()
random.seed(args.seed)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(os.path.join(args.output_dir, 'config.txt'), 'w', encoding='utf8') as fh:
json.dump(args.__dict__, fh, indent=2)
# setup gpu/cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
# prepare task
args.task_name = args.task_name.lower()
if args.task_name not in PROCESSORS:
raise ValueError("Task not found: {}".format(args.task_name))
processor = PROCESSORS[args.task_name]()
args.label_list = processor.get_labels()
args.cache_dir = ""
args.do_lower_case = False
args.verbalizer_file = None
args.wrapper_type = 'mlm'
# get training data
train_examples_per_label = eq_div(args.train_examples, len(args.label_list)) if args.train_examples != -1 else -1
train_data = load_examples(args.task_name, args.data_dir, set_type=TRAIN_SET, num_examples_per_label=train_examples_per_label)
if args.additional_input_examples:
additional_data = InputExample.load_examples(args.additional_input_examples)
train_data += additional_data
logger.info(f"Loaded {len(additional_data)} additional examples from {args.additional_input_examples}, total"
f"training set size is now {len(train_data)}")
expected = {label: np.array([1 if x.label == label else 0 for x in train_data]) for label in args.label_list}
if args.words_file:
with open(args.words_file, 'r', encoding='utf8') as fh:
word_counts = Counter(fh.read().split())
else:
word_counts = None
tokenizer_class = MODEL_CLASSES[args.model_type]['tokenizer']
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
word2idx = get_word_to_id_map(tokenizer, word_counts=word_counts, max_words=args.max_words)
logits = []
for pattern_id in args.pattern_ids:
logger.info(f"Processing examples with pattern id {pattern_id}...")
args.pattern_id = pattern_id
config = WrapperConfig(model_type=args.model_type, model_name_or_path=args.model_name_or_path, wrapper_type='mlm',
task_name=args.task_name, max_seq_length=args.max_seq_length, label_list=args.label_list,
pattern_id=args.pattern_id)
wrapper = TransformerModelWrapper(config)
wrapper.model.to(device)
# modify all patterns so that they return a single text segment instead of two segments
get_parts = wrapper.preprocessor.pvp.get_parts
wrapper.preprocessor.pvp.get_parts = lambda example: (get_parts(example)[0] + get_parts(example)[1], [])
wrapper.preprocessor.pvp.convert_mlm_logits_to_cls_logits = lambda mask, x, _=None: x[mask >= 0]
pattern_logits = wrapper.eval(train_data, device, per_gpu_eval_batch_size=args.per_gpu_eval_batch_size, n_gpu=args.n_gpu)['logits']
pattern_logits = pattern_logits - np.expand_dims(np.max(pattern_logits, axis=1), axis=1)
logits.append(pattern_logits)
logger.info("Starting verbalizer search...")
if args.combine_patterns:
avs = AutomaticVerbalizerSearch(word2idx, args.label_list, logits, expected)
verbalizer = avs.find_verbalizer(
num_candidates=args.num_candidates,
words_per_label=args.words_per_label,
normalize=args.normalize,
score_fct=args.score_fct
)
verbalizers = {pattern_id: verbalizer for pattern_id in args.pattern_ids}
else:
verbalizers = {}
for idx, pattern_id in enumerate(args.pattern_ids):
avs = AutomaticVerbalizerSearch(word2idx, args.label_list, [logits[idx]], expected)
verbalizers[pattern_id] = avs.find_verbalizer(
num_candidates=args.num_candidates,
words_per_label=args.words_per_label,
normalize=args.normalize,
score_fct=args.score_fct
)
print(json.dumps(verbalizers, indent=2))
logger.info("Verbalizer search complete, writing output...")
with open(os.path.join(args.output_dir, 'verbalizers.json'), 'w', encoding='utf8') as fh:
json.dump(verbalizers, fh, indent=2)
logger.info("Done")
if __name__ == "__main__":
main()
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List
import numpy as np
from pet.utils import InputFeatures, InputExample, PLMInputFeatures
from pet.pvp import PVP, PVPS
class Preprocessor(ABC):
"""
A preprocessor that transforms an :class:`InputExample` into an :class:`InputFeatures` object so that it can be
processed by the model being used.
"""
def __init__(self, wrapper, task_name, pattern_id: int = 0, verbalizer_file: str = None):
"""
Create a new preprocessor.
:param wrapper: the wrapper for the language model to use
:param task_name: the name of the task
:param pattern_id: the id of the PVP to be used
:param verbalizer_file: path to a file containing a verbalizer that overrides the default verbalizer
"""
self.wrapper = wrapper
self.pvp = PVPS[task_name](self.wrapper, pattern_id, verbalizer_file) # type: PVP
self.label_map = {label: i for i, label in enumerate(self.wrapper.config.label_list)}
@abstractmethod
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> InputFeatures:
"""Convert the given example into a set of input features"""
pass
class MLMPreprocessor(Preprocessor):
"""Preprocessor for models pretrained using a masked language modeling objective (e.g., BERT)."""
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> InputFeatures:
if priming:
input_ids, token_type_ids = self.pvp.encode(example, priming=True)
priming_data = example.meta['priming_data'] # type: List[InputExample]
priming_input_ids = []
for priming_example in priming_data:
pe_input_ids, _ = self.pvp.encode(priming_example, priming=True, labeled=True)
priming_input_ids += pe_input_ids
input_ids = priming_input_ids + input_ids
token_type_ids = self.wrapper.tokenizer.create_token_type_ids_from_sequences(input_ids)
input_ids = self.wrapper.tokenizer.build_inputs_with_special_tokens(input_ids)
else:
input_ids, token_type_ids = self.pvp.encode(example)
attention_mask = [1] * len(input_ids)
padding_length = self.wrapper.config.max_seq_length - len(input_ids)
if padding_length < 0:
raise ValueError(f"Maximum sequence length is too small, got {len(input_ids)} input ids")
input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
token_type_ids = token_type_ids + ([0] * padding_length)
assert len(input_ids) == self.wrapper.config.max_seq_length
assert len(attention_mask) == self.wrapper.config.max_seq_length
assert len(token_type_ids) == self.wrapper.config.max_seq_length
label = self.label_map[example.label] if example.label is not None else -100
logits = example.logits if example.logits else [-1]
if labelled:
mlm_labels = self.pvp.get_mask_positions(input_ids)
if self.wrapper.config.model_type == 'gpt2':
# shift labels to the left by one
mlm_labels.append(mlm_labels.pop(0))
else:
mlm_labels = [-1] * self.wrapper.config.max_seq_length
return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx)
class PLMPreprocessor(MLMPreprocessor):
"""Preprocessor for models pretrained using a permuted language modeling objective (e.g., XLNet)."""
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> PLMInputFeatures:
input_features = super().get_input_features(example, labelled, priming, **kwargs)
input_ids = input_features.input_ids
num_masks = 1 # currently, PLMPreprocessor supports only replacements that require exactly one mask
perm_mask = np.zeros((len(input_ids), len(input_ids)), dtype=float)
label_idx = input_ids.index(self.pvp.mask_id)
perm_mask[:, label_idx] = 1 # the masked token is not seen by any other token
target_mapping = np.zeros((num_masks, len(input_ids)), dtype=float)
target_mapping[0, label_idx] = 1.0
return PLMInputFeatures(perm_mask=perm_mask, target_mapping=target_mapping, **input_features.__dict__)
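# Shape sketch for the XLNet-style inputs built above (illustrative): for a
# sequence of length L with one mask at label_idx, perm_mask is an L x L matrix
# whose column label_idx is 1 (no token may attend to the masked position), and
# target_mapping is a 1 x L matrix with a single 1.0 selecting the position
# whose token should be predicted.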
class SequenceClassifierPreprocessor(Preprocessor):
"""Preprocessor for a regular sequence classification model."""
def get_input_features(self, example: InputExample, **kwargs) -> InputFeatures:
inputs = self.wrapper.task_helper.get_sequence_classifier_inputs(example) if self.wrapper.task_helper else None
if inputs is None:
inputs = self.wrapper.tokenizer.encode_plus(
example.text_a if example.text_a else None,
example.text_b if example.text_b else None,
add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length,
)
input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids")
attention_mask = [1] * len(input_ids)
padding_length = self.wrapper.config.max_seq_length - len(input_ids)
input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
if not token_type_ids:
token_type_ids = [0] * self.wrapper.config.max_seq_length
else:
token_type_ids = token_type_ids + ([0] * padding_length)
mlm_labels = [-1] * len(input_ids)
assert len(input_ids) == self.wrapper.config.max_seq_length
assert len(attention_mask) == self.wrapper.config.max_seq_length
assert len(token_type_ids) == self.wrapper.config.max_seq_length
label = self.label_map[example.label] if example.label is not None else -100
logits = example.logits if example.logits else [-1]
return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the pattern-verbalizer pairs (PVPs) for all tasks.
"""
import random
import string
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Tuple, List, Union, Dict
import torch
from transformers import PreTrainedTokenizer, GPT2Tokenizer
from pet.task_helpers import MultiMaskTaskHelper
from pet.tasks import TASK_HELPERS
from pet.utils import InputExample, get_verbalization_ids
import log
from pet import wrapper as wrp
logger = log.get_logger('root')
FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], List[Union[str, Tuple[str, bool]]]]
class PVP(ABC):
"""
This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its own
custom implementation of a PVP.
"""
def __init__(self, wrapper, pattern_id: int = 0, verbalizer_file: str = None, seed: int = 42):
"""
Create a new PVP.
:param wrapper: the wrapper for the underlying language model
:param pattern_id: the pattern id to use
:param verbalizer_file: an optional file that contains the verbalizer to be used
:param seed: a seed to be used for generating random numbers if necessary
"""
self.wrapper = wrapper
self.pattern_id = pattern_id
self.rng = random.Random(seed)
if verbalizer_file:
self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id)
use_multimask = (self.wrapper.config.task_name in TASK_HELPERS) and (
issubclass(TASK_HELPERS[self.wrapper.config.task_name], MultiMaskTaskHelper)
)
if not use_multimask and self.wrapper.config.wrapper_type in [wrp.MLM_WRAPPER, wrp.PLM_WRAPPER]:
self.mlm_logits_to_cls_logits_tensor = self._build_mlm_logits_to_cls_logits_tensor()
def _build_mlm_logits_to_cls_logits_tensor(self):
label_list = self.wrapper.config.label_list
m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1
for label_idx, label in enumerate(label_list):
verbalizers = self.verbalize(label)
for verbalizer_idx, verbalizer in enumerate(verbalizers):
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, "verbalization was tokenized as <UNK>"
m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id
return m2c_tensor
@property
def mask(self) -> str:
"""Return the underlying LM's mask token"""
return self.wrapper.tokenizer.mask_token
@property
def mask_id(self) -> int:
"""Return the underlying LM's mask id"""
return self.wrapper.tokenizer.mask_token_id
@property
def max_num_verbalizers(self) -> int:
"""Return the maximum number of verbalizers across all labels"""
return max(len(self.verbalize(label)) for label in self.wrapper.config.label_list)
@staticmethod
def shortenable(s):
"""Return an instance of this string that is marked as shortenable"""
return s, True
@staticmethod
def remove_final_punc(s: Union[str, Tuple[str, bool]]):
"""Remove the final punctuation mark"""
if isinstance(s, tuple):
return PVP.remove_final_punc(s[0]), s[1]
return s.rstrip(string.punctuation)
@staticmethod
def lowercase_first(s: Union[str, Tuple[str, bool]]):
"""Lowercase the first character"""
if isinstance(s, tuple):
return PVP.lowercase_first(s[0]), s[1]
return s[0].lower() + s[1:]
def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \
-> Tuple[List[int], List[int]]:
"""
Encode an input example using this pattern-verbalizer pair.
:param example: the input example to encode
:param priming: whether to use this example for priming
:param labeled: if ``priming=True``, whether the label should be appended to this example
:return: A tuple, consisting of a list of input ids and a list of token type ids
"""
if not priming:
assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"
tokenizer = self.wrapper.tokenizer # type: PreTrainedTokenizer
parts_a, parts_b = self.get_parts(example)
kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}
parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]
if parts_b:
parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]
self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)
tokens_a = [token_id for part, _ in parts_a for token_id in part]
tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None
if priming:
input_ids = tokens_a
if tokens_b:
input_ids += tokens_b
if labeled:
mask_idx = input_ids.index(self.mask_id)
assert mask_idx >= 0, 'sequence of input_ids must contain a mask token'
assert len(self.verbalize(example.label)) == 1, 'priming only supports one verbalization per label'
verbalizer = self.verbalize(example.label)[0]
verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
input_ids[mask_idx] = verbalizer_id
return input_ids, []
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
return input_ids, token_type_ids
@staticmethod
def _seq_length(parts: List[Tuple[str, bool]], only_shortenable: bool = False):
return sum([len(x) for x, shortenable in parts if not only_shortenable or shortenable]) if parts else 0
@staticmethod
def _remove_last(parts: List[Tuple[str, bool]]):
last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq)
parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1])
def truncate(self, parts_a: List[Tuple[str, bool]], parts_b: List[Tuple[str, bool]], max_length: int):
"""Truncate two sequences of text to a predefined total maximum length"""
total_len = self._seq_length(parts_a) + self._seq_length(parts_b)
total_len += self.wrapper.tokenizer.num_special_tokens_to_add(bool(parts_b))
num_tokens_to_remove = total_len - max_length
if num_tokens_to_remove <= 0:
return parts_a, parts_b
for _ in range(num_tokens_to_remove):
if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True):
self._remove_last(parts_a)
else:
self._remove_last(parts_b)
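# Truncation sketch (toy token counts): if parts_a contains a shortenable
# segment of 10 tokens, parts_b one of 6 tokens, and 4 tokens must be removed,
# the loop repeatedly trims one token from whichever side currently has more
# shortenable tokens, ending with 6 tokens on each side.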
@abstractmethod
def get_parts(self, example: InputExample) -> FilledPattern:
"""
Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one
mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a
single sequence of text, the second sequence should be an empty list.
:param example: the input example to process
:return: Two sequences of text. All text segments can optionally be marked as being shortenable.
"""
pass
@abstractmethod
def verbalize(self, label) -> List[str]:
"""
Return all verbalizations for a given label.
:param label: the label
:return: the list of verbalizations
"""
pass
def get_mask_positions(self, input_ids: List[int]) -> List[int]:
label_idx = input_ids.index(self.mask_id)
labels = [-1] * len(input_ids)
labels[label_idx] = 1
return labels
def convert_mlm_logits_to_cls_logits(self, mlm_labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
masked_logits = logits[mlm_labels >= 0]
cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(ml) for ml in masked_logits])
return cls_logits
def _convert_single_mlm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
m2c = self.mlm_logits_to_cls_logits_tensor.to(logits.device)
# filler_len.shape() == num_labels
filler_len = torch.tensor([len(self.verbalize(label)) for label in self.wrapper.config.label_list],
dtype=torch.float)
filler_len = filler_len.to(logits.device)
# cls_logits.shape() == num_labels x max_fillers (entries for labels with fewer verbalizers are zeroed out below).
cls_logits = logits[torch.max(torch.zeros_like(m2c), m2c)]
cls_logits = cls_logits * (m2c > 0).float()
# cls_logits.shape() == num_labels
cls_logits = cls_logits.sum(axis=1) / filler_len
return cls_logits
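# Worked example (toy setup): for labels ['yes', 'no'] with verbalizers ['Yes']
# and ['No', 'Nope'], m2c is a 2 x 2 tensor like [[id_Yes, -1], [id_No, id_Nope]];
# the -1 entries are zeroed out via (m2c > 0), and dividing by
# filler_len = [1, 2] averages the logits over each label's verbalizers.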
def convert_plm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
assert logits.shape[1] == 1
logits = torch.squeeze(logits, 1) # remove second dimension as we always have exactly one <mask> per example
cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(lgt) for lgt in logits])
return cls_logits
@staticmethod
def _load_verbalizer_from_file(path: str, pattern_id: int):
verbalizers = defaultdict(dict) # type: Dict[int, Dict[str, List[str]]]
current_pattern_id = None
with open(path, 'r') as fh:
for line in fh.read().splitlines():
if line.isdigit():
current_pattern_id = int(line)
elif line:
label, *realizations = line.split()
verbalizers[current_pattern_id][label] = realizations
logger.info("Automatically loaded the following verbalizer: \n {}".format(verbalizers[pattern_id]))
def verbalize(label) -> List[str]:
return verbalizers[pattern_id][label]
return verbalize
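# Expected file format for _load_verbalizer_from_file (illustrative content):
# a line holding only a pattern id, followed by one line per label listing the
# label and its verbalizations, e.g.
#
#   0
#   entailment Yes
#   not_entailment No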
class AgnewsPVP(PVP):
VERBALIZER = {
"1": ["World"],
"2": ["Sports"],
"3": ["Business"],
"4": ["Tech"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0:
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
return [self.mask, 'News:', text_a, text_b], []
elif self.pattern_id == 2:
return [text_a, '(', self.mask, ')', text_b], []
elif self.pattern_id == 3:
return [text_a, text_b, '(', self.mask, ')'], []
elif self.pattern_id == 4:
return ['[ Category:', self.mask, ']', text_a, text_b], []
elif self.pattern_id == 5:
return [self.mask, '-', text_a, text_b], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return AgnewsPVP.VERBALIZER[label]
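# Rendered example (hypothetical input): for pattern_id 1 and an example with
# text_a = "Stocks rally" and text_b = "Markets closed higher today.", the
# filled pattern reads
#
#   <mask> News: Stocks rally Markets closed higher today.
#
# and the mask position is scored against the verbalizers World / Sports /
# Business / Tech.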
class YahooPVP(PVP):
VERBALIZER = {
"1": ["Society"],
"2": ["Science"],
"3": ["Health"],
"4": ["Education"],
"5": ["Computer"],
"6": ["Sports"],
"7": ["Business"],
"8": ["Entertainment"],
"9": ["Relationship"],
"10": ["Politics"],
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0:
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
return [self.mask, 'Question:', text_a, text_b], []
elif self.pattern_id == 2:
return [text_a, '(', self.mask, ')', text_b], []
elif self.pattern_id == 3:
return [text_a, text_b, '(', self.mask, ')'], []
elif self.pattern_id == 4:
return ['[ Category:', self.mask, ']', text_a, text_b], []
elif self.pattern_id == 5:
return [self.mask, '-', text_a, text_b], []
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return YahooPVP.VERBALIZER[label]
class MnliPVP(PVP):
VERBALIZER_A = {
"contradiction": ["Wrong"],
"entailment": ["Right"],
"neutral": ["Maybe"]
}
VERBALIZER_B = {
"contradiction": ["No"],
"entailment": ["Yes"],
"neutral": ["Maybe"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(self.remove_final_punc(example.text_a))
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0 or self.pattern_id == 2:
return ['"', text_a, '" ?'], [self.mask, ', "', text_b, '"']
elif self.pattern_id == 1 or self.pattern_id == 3:
return [text_a, '?'], [self.mask, ',', text_b]
def verbalize(self, label) -> List[str]:
if self.pattern_id == 0 or self.pattern_id == 1:
return MnliPVP.VERBALIZER_A[label]
return MnliPVP.VERBALIZER_B[label]
class YelpPolarityPVP(PVP):
VERBALIZER = {
"1": ["bad"],
"2": ["good"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text = self.shortenable(example.text_a)
if self.pattern_id == 0:
return ['It was', self.mask, '.', text], []
elif self.pattern_id == 1:
return [text, '. All in all, it was', self.mask, '.'], []
elif self.pattern_id == 2:
return ['Just', self.mask, "!"], [text]
elif self.pattern_id == 3:
return [text], ['In summary, the restaurant is', self.mask, '.']
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return YelpPolarityPVP.VERBALIZER[label]
class YelpFullPVP(YelpPolarityPVP):
VERBALIZER = {
"1": ["terrible"],
"2": ["bad"],
"3": ["okay"],
"4": ["good"],
"5": ["great"]
}
def verbalize(self, label) -> List[str]:
return YelpFullPVP.VERBALIZER[label]
class XStancePVP(PVP):
VERBALIZERS = {
'en': {"FAVOR": ["Yes"], "AGAINST": ["No"]},
'de': {"FAVOR": ["Ja"], "AGAINST": ["Nein"]},
'fr': {"FAVOR": ["Oui"], "AGAINST": ["Non"]}
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4:
return ['"', text_a, '"'], [self.mask, '. "', text_b, '"']
elif self.pattern_id == 1 or self.pattern_id == 3 or self.pattern_id == 5:
return [text_a], [self.mask, '.', text_b]
def verbalize(self, label) -> List[str]:
lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr'
return XStancePVP.VERBALIZERS[lang][label]
class RtePVP(PVP):
VERBALIZER = {
"not_entailment": ["No"],
"entailment": ["Yes"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
# in the patterns below, the hypothesis (text_b) is placed before the premise (text_a)
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b.rstrip(string.punctuation))
if self.pattern_id == 0:
return ['"', text_b, '" ?'], [self.mask, ', "', text_a, '"']
elif self.pattern_id == 1:
return [text_b, '?'], [self.mask, ',', text_a]
elif self.pattern_id == 2:
return ['"', text_b, '" ?'], [self.mask, '. "', text_a, '"']
elif self.pattern_id == 3:
return [text_b, '?'], [self.mask, '.', text_a]
elif self.pattern_id == 4:
return [text_a, ' question: ', self.shortenable(example.text_b), ' True or False? answer:', self.mask], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return ['true'] if label == 'entailment' else ['false']
return RtePVP.VERBALIZER[label]
class CbPVP(RtePVP):
VERBALIZER = {
"contradiction": ["No"],
"entailment": ["Yes"],
"neutral": ["Maybe"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
if self.pattern_id == 4:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
return [text_a, ' question: ', text_b, ' true, false or neither? answer:', self.mask], []
return super().get_parts(example)
def verbalize(self, label) -> List[str]:
if self.pattern_id == 4:
return ['true'] if label == 'entailment' else ['false'] if label == 'contradiction' else ['neither']
return CbPVP.VERBALIZER[label]
class CopaPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
premise = self.remove_final_punc(self.shortenable(example.text_a))
choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))
question = example.meta['question']
assert question in ['cause', 'effect']
example.meta['choice1'], example.meta['choice2'] = choice1, choice2
num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in [choice1, choice2])
if question == 'cause':
if self.pattern_id == 0:
return ['"', choice1, '" or "', choice2, '"?', premise, 'because', self.mask * num_masks, '.'], []
elif self.pattern_id == 1:
return [choice1, 'or', choice2, '?', premise, 'because', self.mask * num_masks, '.'], []
else:
if self.pattern_id == 0:
return ['"', choice1, '" or "', choice2, '"?', premise, ', so', self.mask * num_masks, '.'], []
elif self.pattern_id == 1:
return [choice1, 'or', choice2, '?', premise, ', so', self.mask * num_masks, '.'], []
def verbalize(self, label) -> List[str]:
# COPA choices are scored directly at the mask positions, so no label verbalizer is needed
return []
class WscPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
pronoun = example.meta['span2_text']
target = example.meta['span1_text']
pronoun_idx = example.meta['span2_index']
words_a = example.text_a.split()
words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
text_a = ' '.join(words_a)
text_a = self.shortenable(text_a)
num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)) + num_pad
masks = self.mask * num_masks
if self.pattern_id == 0:
return [text_a, "The pronoun '*" + pronoun + "*' refers to", masks + '.'], []
elif self.pattern_id == 1:
return [text_a, "In the previous sentence, the pronoun '*" + pronoun + "*' refers to", masks + '.'], []
elif self.pattern_id == 2:
return [text_a,
"Question: In the passage above, what does the pronoun '*" + pronoun + "*' refer to? Answer: ",
masks + '.'], []
def verbalize(self, label) -> List[str]:
# for WSC, the target span itself is predicted at the mask positions, so no label verbalizer is needed
return []
class BoolQPVP(PVP):
VERBALIZER_A = {
"False": ["No"],
"True": ["Yes"]
}
VERBALIZER_B = {
"False": ["false"],
"True": ["true"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
passage = self.shortenable(example.text_a)
question = self.shortenable(example.text_b)
if self.pattern_id < 2:
return [passage, '. Question: ', question, '? Answer: ', self.mask, '.'], []
elif self.pattern_id < 4:
return [passage, '. Based on the previous passage, ', question, '?', self.mask, '.'], []
else:
return ['Based on the following passage, ', question, '?', self.mask, '.', passage], []
def verbalize(self, label) -> List[str]:
if self.pattern_id in [0, 2, 4]:
return BoolQPVP.VERBALIZER_A[label]
else:
return BoolQPVP.VERBALIZER_B[label]
class MultiRcPVP(PVP):
VERBALIZER = {
"0": ["No"],
"1": ["Yes"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
passage = self.shortenable(example.text_a)
question = example.text_b
answer = example.meta['answer']
if self.pattern_id == 0:
return [passage, '. Question: ', question, '? Is it ', answer, '?', self.mask, '.'], []
if self.pattern_id == 1:
return [passage, '. Question: ', question, '? Is the correct answer "', answer, '"?', self.mask, '.'], []
if self.pattern_id == 2:
return [passage, '. Based on the previous passage, ', question, '? Is "', answer, '" a correct answer?',
self.mask, '.'], []
if self.pattern_id == 3:
return [passage, question, '- [', self.mask, ']', answer], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 3:
return ['False'] if label == "0" else ['True']
return MultiRcPVP.VERBALIZER[label]
class WicPVP(PVP):
VERBALIZER_A = {
"F": ["No"],
"T": ["Yes"]
}
VERBALIZER_B = {
"F": ["2"],
"T": ["b"]
}
def get_parts(self, example: InputExample) -> FilledPattern:
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
word = example.meta['word']
if self.pattern_id == 0:
return ['"', text_a, '" / "', text_b, '" Similar sense of "' + word + '"?', self.mask, '.'], []
if self.pattern_id == 1:
return [text_a, text_b, 'Does ' + word + ' have the same meaning in both sentences?', self.mask], []
if self.pattern_id == 2:
return [word, ' . Sense (1) (a) "', text_a, '" (', self.mask, ') "', text_b, '"'], []
def verbalize(self, label) -> List[str]:
if self.pattern_id == 2:
return WicPVP.VERBALIZER_B[label]
return WicPVP.VERBALIZER_A[label]
class RecordPVP(PVP):
def get_parts(self, example: InputExample) -> FilledPattern:
premise = self.shortenable(example.text_a)
choices = example.meta['candidates']
assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token'
num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in choices)
question = example.text_b.replace('@placeholder', self.mask * num_masks)
return [premise, question], []
def verbalize(self, label) -> List[str]:
# ReCoRD candidates are taken from the example's meta dictionary, so no label verbalizer is needed
return []
PVPS = {
'agnews': AgnewsPVP,
'mnli': MnliPVP,
'yelp-polarity': YelpPolarityPVP,
'yelp-full': YelpFullPVP,
'yahoo': YahooPVP,
'xstance': XStancePVP,
'xstance-de': XStancePVP,
'xstance-fr': XStancePVP,
'rte': RtePVP,
'wic': WicPVP,
'cb': CbPVP,
'wsc': WscPVP,
'boolq': BoolQPVP,
'copa': CopaPVP,
'multirc': MultiRcPVP,
'record': RecordPVP,
'ax-b': RtePVP,
'ax-g': RtePVP,
}
class MyTaskPVP(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
# (template default: VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]})
VERBALIZER = {
1: ["fascinating", "effective", "irresistible", "thrilling"],  # also considered: "breathtaking"
0: ["boring", "embarrassing", "weird", "nothing"],  # also considered: "depressing"
}
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
if self.pattern_id == 0:
# *cls*_A*mask*_idea.*+sent_0**sep+* 93 {"0": "boring", "1": "fascinating"}
return ['A', self.mask, 'idea.', text_a], []
if self.pattern_id == 1:
# *cls*_A*mask*_show.*+sent_0**sep+* 95 {"0": "embarrassing", "1": "effective"}
return ['A', self.mask, 'show.', text_a], []
if self.pattern_id == 2:
# *cls*_The_story_is*mask*.*+sent_0**sep+* 81 {"0": "depressing", "1": "irresistible"}
return ['The story is', self.mask, '.', text_a], []
if self.pattern_id == 3:
# *cls**sent_0*_The_result_is*mask*.*sep+* 49 {"0": "weird", "1": "breathtaking"}
return [text_a, 'The result is', self.mask, '.'], []
if self.pattern_id == 4:
# *cls**sent_0*_Very*mask*!*sep+* 30 {"0": "embarrassing", "1": "thrilling"}
return [text_a, 'Very', self.mask, '!'], []
if self.pattern_id == 5:
# *cls**sent_0*_This_is*mask*.*sep+* 7 {"0": "nothing", "1": "thrilling"}
return [text_a, 'This is', self.mask, '.'], []
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return MyTaskPVP.VERBALIZER[label]
class MyTaskPVP2(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task2"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
VERBALIZER = {
1: ["Good"],
0: ["Bad"],
}
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
return [text_a, 'It is', self.mask, '.'], []
def verbalize(self, label) -> List[str]:
return MyTaskPVP2.VERBALIZER[label]
class AutoBest5(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "autobest5"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
VERBALIZER = {
1: ["good"],
0: ["bad"],
}
# Scores of automatically searched patterns; the five implemented below are drawn from the top of this list:
# 0.89967 *cls*_A*mask*_film!*+sent_0**sep+*
# 0.89298 *cls**sent_0*_This_is_really*mask*.*sep+*
# 0.89298 *cls*_It_is*mask*!*+sent_0**sep+*
# 0.89298 *cls*_This_movie_is*mask*.*+sent_0**sep+*
# 0.88963 *cls*_Just*mask*.*+sent_0**sep+*
# 0.88963 *cls*_The_movie_is*mask*.*+sent_0**sep+*
# 0.88963 *cls*_This_film_is*mask*.*+sent_0**sep+*
# 0.88629 *cls**sent_0*_It's*mask*.*sep+*
# 0.88629 *cls**sent_0*_The_movie_is*mask*.*sep+*
# 0.88629 *cls**sent_0*_This_is_just*mask*.*sep+*
def get_parts(self, example: InputExample):
text_a = self.shortenable(example.text_a)
if self.pattern_id == 0:
return ['A', self.mask, 'film.', text_a], []
if self.pattern_id == 1:
return [text_a, 'This is really', self.mask, '.'], []
if self.pattern_id == 2:
return ['It is', self.mask, '!', text_a], []
if self.pattern_id == 3:
return ['The movie is', self.mask, '.', text_a], []
if self.pattern_id == 4:
return ['Just', self.mask, '.', text_a], []
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return AutoBest5.VERBALIZER[label]
# register the PVP for this task with its name
PVPS[MyTaskPVP.TASK_NAME] = MyTaskPVP
PVPS[MyTaskPVP2.TASK_NAME] = MyTaskPVP2
PVPS[AutoBest5.TASK_NAME] = AutoBest5
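# Usage sketch (illustrative, not part of the original file): PVPS maps a task
# name to its PVP class; PET instantiates the class with a model wrapper and a
# pattern id. `wrapper` below stands for a hypothetical, already-initialized
# model wrapper, and `example` for an InputExample instance.
#
#   pvp = PVPS[MyTaskPVP.TASK_NAME](wrapper, pattern_id=0)
#   parts_a, parts_b = pvp.get_parts(example)  # the two pattern halves around the mask
#   tokens = pvp.verbalize(1)                  # e.g. ["fascinating", ...]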
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from abc import ABC
from collections import defaultdict
from typing import Dict, List, Optional, Any
import torch
import re
import numpy as np
from torch.nn import CrossEntropyLoss
from pet.utils import InputFeatures, InputExample, get_verbalization_ids, chunks, trim_input_ids, remove_final_punc, \
lowercase_first
class TaskHelper(ABC):
"""
A helper class that provides custom training and evaluation methods for tasks that do not fit in PET's default
schema, for example because they require more than two sequences of text, different evaluation metrics, or
verbalizers consisting of multiple tokens.
"""
def __init__(self, wrapper):
"""
Create a new task helper.
:param wrapper: The wrapper for the language model being used.
"""
self.wrapper = wrapper
self.output = None
def train_step(self, batch: Dict[str, torch.Tensor], **kwargs) -> Optional[torch.Tensor]:
"""
Custom implementation of the train step for this task.
:param batch: a batch of examples
:return: a scalar loss tensor
"""
pass
def eval_step(self, batch: Dict[str, torch.Tensor], **kwargs) -> Optional[torch.Tensor]:
"""
Custom implementation of the eval step for this task.
:param batch: a batch of examples
:return: a tensor of logits
"""
pass
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
"""
Add special features to the ``meta`` dictionary of a feature set
:param input_example: the input example considered
:param input_features: the set of features corresponding to this example
"""
pass
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
"""
Add special features from the ``meta`` dictionary of a sequence of features to the corresponding dictionary
:param features: the sequence of features
:param feature_dict: the dictionary that stores aggregated feature views as tensors
"""
pass
def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
"""
Get the inputs for sequence classification. Override this method if the input for the task considered is of a
more complicated form than `text_a` or `text_a [SEP] text_b`.
:param example: the input example
:return: the dictionary of inputs
"""
pass
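# A minimal illustrative TaskHelper subclass (an assumption for demonstration,
# not part of the original PET code): it only overrides the sequence classifier
# inputs, prepending a keyword taken from the example's meta dictionary.
#
#   class KeywordTaskHelper(TaskHelper):
#       def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
#           text_a = example.meta['keyword'] + ': ' + example.text_a
#           return self.wrapper.tokenizer.encode_plus(text_a, example.text_b, add_special_tokens=True,
#                                                     max_length=self.wrapper.config.max_seq_length)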
class MultiMaskTaskHelper(TaskHelper):
"""A custom task helper for classification datasets where multiple masks are required for one or more verbalizers."""
def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for MultiMaskTaskHelper is only implemented for MLM models'
inputs = self.wrapper.generate_default_inputs(batch)
loss_fct = CrossEntropyLoss(reduction='none')
# prediction_scores.shape == max_seq_len x vocab_size x batch_size
prediction_scores = self.wrapper.model(**inputs)[0].permute(1, 2, 0)
# all_choice_token_ids.shape == batch_size x num_choices x max_seq_len
all_choice_token_ids = batch['choice_token_ids']
batch_size, num_choices, max_seq_len = all_choice_token_ids.shape
# all_labels.shape == batch_size
all_labels = batch['labels']
# correct_choice_token_ids.shape == max_seq_len x batch_size
correct_choice_token_ids = all_choice_token_ids[torch.arange(batch_size), all_labels].permute(1, 0)
wrong_choices_mask = torch.ones_like(all_choice_token_ids)
wrong_choices_mask.scatter_(1, all_labels.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, max_seq_len), 0)
wrong_choices_token_ids = all_choice_token_ids[wrong_choices_mask.bool()].view(batch_size, num_choices - 1, max_seq_len)
# wrong_choices_token_ids.shape == (num_choices - 1) x max_seq_len x batch_size
wrong_choices_token_ids = wrong_choices_token_ids.permute(1, 2, 0)
total_loss = 0
# loss_correct_choice.shape == batch_size
loss_correct_choice = loss_fct(prediction_scores, correct_choice_token_ids).sum(dim=0)
# compute hinge loss
for wrong_choice_token_ids in wrong_choices_token_ids:
loss_wrong_choice = loss_fct(prediction_scores, wrong_choice_token_ids).sum(dim=0)
hinge_loss = 1 + loss_correct_choice - loss_wrong_choice
hinge_loss[hinge_loss < 0] = 0
total_loss += hinge_loss
return total_loss.mean()
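# Worked illustration of the hinge term above (values invented): with
# loss_correct_choice = 0.2 and loss_wrong_choice = 1.5 the hinge is
# max(0, 1 + 0.2 - 1.5) = 0, so a wrong choice that is already one unit harder
# to generate contributes nothing; with loss_wrong_choice = 0.8 the hinge is
# 0.4 and the gradient pushes the margin further open.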
def eval_step(self, batch: Dict[str, torch.Tensor], batch_size: int = 8, decoding_strategy: str = 'default'):
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for MultiMaskTaskHelper is only implemented for MLM models'
assert batch['input_ids'].shape[0] == 1, "eval_step() for MultiMaskTaskHelper is only implemented for batch_size=1"
all_choice_token_ids = batch['choice_token_ids'][0]
log_probabilities = torch.tensor([[-math.inf] * len(all_choice_token_ids)], dtype=torch.float, device=all_choice_token_ids.device)
# group choices by length to speed up decoding
choices_grouped_by_length = defaultdict(list)
for idx, choice_token_ids in enumerate(all_choice_token_ids):
num_masks = sum(1 for x in choice_token_ids if x != -100)
choices_grouped_by_length[num_masks].append((idx, choice_token_ids))
input_ids = {}
initial_outputs = {}
for num_masks in choices_grouped_by_length.keys():
# modify the input ids to contain the correct number of masks
input_ids[num_masks] = trim_input_ids(batch['input_ids'], num_masks=num_masks,
pad_token_id=self.wrapper.tokenizer.pad_token_id,
mask_token_id=self.wrapper.tokenizer.mask_token_id)
initial_outputs[num_masks] = self.wrapper.model(input_ids[num_masks])
for num_masks, choices_with_labels in choices_grouped_by_length.items():
# score the choices in fixed-size chunks (`chunk`), keeping the outer eval batch intact
for chunk in chunks(choices_with_labels, batch_size):
batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
choice_token_ids = torch.stack([choice_token_ids for idx, choice_token_ids in chunk])
batch_probabilities = self._get_choice_probabilities_batched(choice_token_ids, batch_input_ids, initial_outputs[num_masks],
decoding_strategy=decoding_strategy)
for batch_idx, (idx, choice_token_ids) in enumerate(chunk):
log_probabilities[0][idx] = batch_probabilities[batch_idx]
return log_probabilities
def _get_choice_probabilities_batched(self, target_sequences, input_ids, initial_output, decoding_strategy):
log_probabilities = defaultdict(list)
first_call = True
while True:
masks = {batch_idx: [(idx, tok) for idx, tok in enumerate(target_sequences[batch_idx]) if tok >= 0] for
batch_idx in range(len(target_sequences))}
if not masks[0]: # there are no masks left to process, we are done
break
if first_call:
outputs = initial_output
else:
outputs = self.wrapper.model(input_ids)
next_token_logits = outputs[0]
next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)
if decoding_strategy == 'ltr':
masks = {batch_idx: [batch_masks[0]] for batch_idx, batch_masks in masks.items()}
for batch_idx in range(len(target_sequences)):
ntl = next_token_logits[batch_idx] if not first_call else next_token_logits[0]
if decoding_strategy == 'parallel':
for m_pos, m_id in masks[batch_idx]:
log_probabilities[batch_idx].append(math.log(ntl[m_pos][m_id].item()))
target_sequences[batch_idx][m_pos] = -100
else:
mask_pos, masked_id = None, None
highest_prob = None
for m_pos, m_id in masks[batch_idx]:
m_prob = ntl[m_pos][m_id]
if highest_prob is None or m_prob > highest_prob:
highest_prob = m_prob
mask_pos, masked_id = m_pos, m_id
log_probabilities[batch_idx].append(math.log(ntl[mask_pos][masked_id].item()))
input_ids[batch_idx][mask_pos] = masked_id
target_sequences[batch_idx][mask_pos] = -100
first_call = False
return {batch_idx: sum(log_prob for log_prob in log_probabilities[batch_idx]) for batch_idx in
range(len(target_sequences))}
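# The three decoding strategies above differ only in how many masks are
# committed per forward pass: 'parallel' scores every remaining mask from a
# single pass, 'ltr' always commits the leftmost mask, and the default greedily
# commits the single most confident mask and re-runs the model, so scoring a
# choice with k mask tokens costs up to k forward passes.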
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
if 'choices' in input_example.meta:
choices = [choice for choice in input_example.meta['choices']]
else:
label_list = self.wrapper.config.label_list
choices = [self.wrapper.preprocessor.pvp.verbalize(label)[0] for label in label_list]
input_features.meta['choice_token_ids'] = []
for idx, choice_text in enumerate(choices):
choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
mask_end = mask_start + len(choice_token_ids)
candidate_token_ids = [-100] * len(input_features.input_ids)
candidate_token_ids[mask_start:mask_end] = choice_token_ids
input_features.meta['choice_token_ids'].append(candidate_token_ids)
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
max_num_choices = max(len(f.meta['choice_token_ids']) for f in features)
for feature in features:
if len(feature.meta['choice_token_ids']) != max_num_choices:
raise ValueError(f"The number of output choices must be identical for all examples, got "
f"{len(feature.meta['choice_token_ids'])} and {max_num_choices}")
feature_dict['choice_token_ids'] = torch.tensor([f.meta['choice_token_ids'] for f in features], dtype=torch.long)
class WicTaskHelper(TaskHelper):
"""A custom task helper for the WiC dataset."""
def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
text_a = example.meta['word'] + ': ' + example.text_a
return self.wrapper.tokenizer.encode_plus(text_a, example.text_b, add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length)
class MultiRcTaskHelper(TaskHelper):
"""A custom task helper for the MultiRC dataset."""
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
input_features.meta['question_idx'] = input_example.meta['question_idx']
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
feature_dict['question_idx'] = torch.tensor([f.meta['question_idx'] for f in features], dtype=torch.long)
def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
text_a = example.text_a
text_b = ' '.join([example.text_b, self.wrapper.tokenizer.sep_token, example.meta['answer']])
return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length)
class CopaTaskHelper(TaskHelper):
"""A custom task helper for the COPA dataset."""
def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
premise = remove_final_punc(example.text_a)
choice1, choice2 = lowercase_first(example.meta['choice1']), lowercase_first(example.meta['choice2'])
question = example.meta['question']
joiner = 'because' if question == 'cause' else 'so'
text_a, text_b = ' '.join([premise, joiner, choice1]), ' '.join([premise, joiner, choice2])
return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length)
def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for COPA is only implemented for MLM models'
inputs = self.wrapper.generate_default_inputs(batch)
mask = batch['labels'].unsqueeze(1)
correct_targets = batch['choice1_token_ids'] * (1 - mask) + batch['choice2_token_ids'] * mask
wrong_targets = batch['choice1_token_ids'] * mask + batch['choice2_token_ids'] * (1 - mask)
prediction_scores = self.wrapper.model(**inputs)[0].view(-1, self.wrapper.model.config.vocab_size)
loss_fct = CrossEntropyLoss()
loss_correct_label = loss_fct(prediction_scores, correct_targets.view(-1))
loss_wrong_label = loss_fct(prediction_scores, wrong_targets.view(-1))
loss = 1 + loss_correct_label - loss_wrong_label
loss[loss < 0] = 0
return loss
def eval_step(self, batch: Dict[str, torch.Tensor], decoding_strategy: str = 'default', **kwargs):
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for COPA is only implemented for MLM models'
assert batch['input_ids'].shape[0] == 1, 'eval_step() for COPA is only implemented for batch_size=1'
log_probs = []
for choice in ['choice1', 'choice2']:
labels = batch[f'{choice}_token_ids']
log_prob = self._get_choice_log_probability(batch, labels, decoding_strategy=decoding_strategy)
log_probs.append(log_prob)
return torch.tensor([log_probs])
def _get_choice_log_probability(self, batch, target_sequence, decoding_strategy: str = 'default'):
# adjust the number of masks
num_masks = sum(1 for tok_id in target_sequence[0] if tok_id != -100)
input_ids = trim_input_ids(batch['input_ids'], num_masks=num_masks,
pad_token_id=self.wrapper.tokenizer.pad_token_id,
mask_token_id=self.wrapper.tokenizer.mask_token_id)
log_probabilities = []
while True:
masks = [(idx, tok_id) for idx, tok_id in enumerate(target_sequence[0]) if tok_id != -100]
if not masks: # there are no masks left to process, we are done
break
outputs = self.wrapper.model(input_ids)
next_token_logits = torch.nn.Softmax(dim=2)(outputs[0])[0]
if decoding_strategy == 'ltr':
mask_pos, masked_id = masks[0]
max_prob = next_token_logits[mask_pos][masked_id].item()
elif decoding_strategy == 'parallel':
for m_pos, m_id in masks:
log_probabilities.append(math.log(next_token_logits[m_pos][m_id].item()))
break
else:
mask_pos, masked_id = None, None
max_prob = None
for m_pos, m_id in masks:
m_prob = next_token_logits[m_pos][m_id].item()
if max_prob is None or m_prob > max_prob:
max_prob = m_prob
mask_pos, masked_id = m_pos, m_id
log_probabilities.append(math.log(max_prob))
input_ids[0][mask_pos] = masked_id
target_sequence[0][mask_pos] = -100
return sum(log_probabilities)
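# Worked illustration of the loop above (values invented): for a two-token
# choice with masks at positions 4 and 5, the default strategy first fills
# whichever position the model is more confident about, say p = 0.6 at
# position 5, re-runs the model, and reads p = 0.5 at position 4; the returned
# score is log(0.6) + log(0.5). 'parallel' reads both probabilities from one
# forward pass, and 'ltr' always fills position 4 first.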
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
for choice in ['choice1', 'choice2']:
choice_text = input_example.meta[choice]
choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
mask_end = mask_start + len(choice_token_ids)
input_features.meta[f'{choice}_token_ids'] = [-100] * len(input_features.input_ids)
input_features.meta[f'{choice}_token_ids'][mask_start:mask_end] = choice_token_ids
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
for choice in ['choice1', 'choice2']:
feature_dict[f'{choice}_token_ids'] = torch.tensor(
[f.meta[f'{choice}_token_ids'] for f in features], dtype=torch.long)
class WscTaskHelper(TaskHelper):
"""A custom task helper for the Wsc dataset."""
def __init__(self, wrapper):
super().__init__(wrapper)
self.id_to_target = []
def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
target = example.meta['span1_text']
pronoun_idx = example.meta['span2_index']
# mark the pronoun with asterisks
words_a = example.text_a.split()
words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
text_a = ' '.join(words_a)
text_b = target
return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length)
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
num_masks = input_features.input_ids.count(self.wrapper.tokenizer.mask_token_id)
mask_end = mask_start + num_masks
target = input_example.meta['span1_text']
input_features.meta['target'] = target
target_token_ids = get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)
input_features.meta['target_token_ids'] = [-100] * len(input_features.input_ids)
# we also predict <pad> tokens at the missing positions
target_token_ids += [self.wrapper.tokenizer.pad_token_id] * (num_masks - len(target_token_ids))
input_features.meta['target_token_ids'][mask_start:mask_end] = target_token_ids
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
feature_dict['target_id'] = torch.tensor([len(self.id_to_target) + idx for idx, f in enumerate(features)],
dtype=torch.long)
self.id_to_target += [f.meta['target'] for f in features]
feature_dict['target_token_ids'] = torch.tensor([f.meta['target_token_ids'] for f in features],
dtype=torch.long)
def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
if self.wrapper.config.wrapper_type == 'sequence_classifier':
return
assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for WSC is only implemented for MLM models'
inputs = self.wrapper.generate_default_inputs(batch)
inputs['labels'] = batch['target_token_ids']
loss = self.wrapper.model(**inputs)[0]
return loss
def eval_step(self, batch: Dict[str, torch.Tensor], decoding_strategy: str = 'default', **kwargs):
if self.wrapper.config.wrapper_type in ['sequence_classifier', 'span_pair_classifier']:
return
assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for WSC is only implemented for MLM models'
assert batch['input_ids'].shape[0] == 1, 'eval_step() for WSC is only implemented for batch_size=1'
inputs = self.wrapper.generate_default_inputs(batch)
input_ids = inputs['input_ids']
orig_mask_positions = [
idx for idx, input_id in enumerate(input_ids[0]) if input_id == self.wrapper.tokenizer.mask_token_id
]
while True:
mask_positions = [
idx for idx, input_id in enumerate(input_ids[0]) if input_id == self.wrapper.tokenizer.mask_token_id
]
if not mask_positions: # there are no masks left to process, we are done
input_ids = input_ids[0].detach().cpu().tolist()
output_actual = self.wrapper.tokenizer.decode([
input_id for idx, input_id in enumerate(input_ids)
if idx in orig_mask_positions and input_id not in self.wrapper.tokenizer.all_special_ids
])
output_expected = self.id_to_target[batch["target_id"][0].item()]
# transform both outputs as described in the T5 paper
output_actual = output_actual.lower().strip()
output_actual = [w for w in re.split('[^a-zA-Z]', output_actual) if w]
output_expected = output_expected.lower().strip()
output_expected = [w for w in re.split('[^a-zA-Z]', output_expected) if w]
# compare outputs
if all(x in output_expected for x in output_actual) or all(
x in output_actual for x in output_expected):
return torch.tensor([[0, 1]])
return torch.tensor([[1, 0]])
outputs = self.wrapper.model(**inputs)
next_token_logits = outputs[0]
next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)
next_token_logits = next_token_logits[0].detach().cpu().numpy()
most_confident = ()
most_confident_score = -1
if decoding_strategy == 'ltr':
mask_positions = [mask_positions[0]]
for mask_position in mask_positions:
ntl = next_token_logits[mask_position]
top_token_id = np.argmax(ntl)
top_score = ntl[top_token_id]
if decoding_strategy == 'parallel':
input_ids[0][mask_position] = top_token_id
elif top_score > most_confident_score:
most_confident_score = top_score
most_confident = (mask_position, top_token_id)
if decoding_strategy == 'parallel':
continue
input_ids[0][most_confident[0]] = most_confident[1]
class RecordTaskHelper(TaskHelper):
"""A custom task helper for the ReCoRD dataset."""
def __init__(self, wrapper):
super().__init__(wrapper)
self.output = []
self.original_choices = {}
def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for ReCoRD is only implemented for MLM models'
inputs = self.wrapper.generate_default_inputs(batch)
prediction_scores = self.wrapper.model(**inputs)[0].view(-1, self.wrapper.model.config.vocab_size)
loss_fct = CrossEntropyLoss()
# all_candidate_token_ids.shape == batch_size x max_num_candidates x max_seq_len
all_candidate_token_ids = batch['candidate_token_ids']
# all_candidate_labels.shape == batch_size x max_num_candidates
all_candidate_labels = batch['candidate_labels']
all_candidate_token_ids = all_candidate_token_ids.permute(1, 0, 2)
all_candidate_labels = all_candidate_labels.permute(1, 0)
total_loss = 0
loss_correct_label = loss_fct(prediction_scores, all_candidate_token_ids[0].view(-1))
# compute hinge loss
for candidate_token_ids, candidate_labels in zip(all_candidate_token_ids[1:], all_candidate_labels[1:]):
loss_wrong_label = loss_fct(prediction_scores, candidate_token_ids.view(-1))
hinge_loss = 1 + loss_correct_label - loss_wrong_label
hinge_loss[hinge_loss < 0] = 0
total_loss += hinge_loss
return total_loss
def eval_step(self, batch: Dict[str, torch.Tensor], batch_size: int = 8, decoding_strategy: str = 'default'):
assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for ReCoRD is only implemented for MLM models'
assert batch['input_ids'].shape[0] == 1, "eval_step() for ReCoRD is only implemented for batch_size=1"
best_choice_correct, best_choice, max_prob = False, None, None
question_idx = batch['question_idx'][0].item()
output_line = {'idx': question_idx, 'choices': {}}
# group choices by length to speed up decoding
choices_grouped_by_length = defaultdict(list)
for idx, (choice_ids, label) in enumerate(zip(batch['candidate_token_ids'][0], batch['candidate_labels'][0])):
if label < 0:
continue
num_masks = sum(1 for x in choice_ids if x != -100)
choice = self.original_choices[question_idx][idx]
choices_grouped_by_length[num_masks].append((choice, choice_ids, label))
input_ids = {}
initial_outputs = {}
for num_masks in choices_grouped_by_length.keys():
# modify the input ids to contain the correct number of masks
input_ids[num_masks] = trim_input_ids(batch['input_ids'], num_masks=num_masks,
pad_token_id=self.wrapper.tokenizer.pad_token_id,
mask_token_id=self.wrapper.tokenizer.mask_token_id)
initial_outputs[num_masks] = self.wrapper.model(input_ids[num_masks])
for num_masks, choices_with_labels in choices_grouped_by_length.items():
# score the candidates in fixed-size chunks (`chunk`), keeping the outer eval batch intact
for chunk in chunks(choices_with_labels, batch_size):
batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
choice_ids = torch.stack([choice_id for choice, choice_id, label in chunk])
probs = self._get_choice_probabilities_batched(choice_ids, batch_input_ids, initial_outputs[num_masks],
decoding_strategy=decoding_strategy)
for idx, (choice, choice_ids, label) in enumerate(chunk):
prob = probs[idx]
output_line['choices'][choice] = prob
if max_prob is None or prob > max_prob:
best_choice_correct, max_prob = (label == 1), prob
self.output.append(output_line)
if best_choice_correct:
return torch.tensor([[0, 1]])
return torch.tensor([[1, 0]])
def _get_choice_probabilities_batched(self, target_sequences, input_ids, initial_output, decoding_strategy):
log_probabilities = defaultdict(list)
first_call = True
while True:
masks = {batch_idx: [(idx, tok) for idx, tok in enumerate(target_sequences[batch_idx]) if tok >= 0] for
batch_idx in range(len(target_sequences))}
if not masks[0]: # there are no masks left to process, we are done
break
if first_call:
outputs = initial_output
else:
outputs = self.wrapper.model(input_ids)
next_token_logits = outputs[0]
next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)
if decoding_strategy == 'ltr':
masks = {batch_idx: [batch_masks[0]] for batch_idx, batch_masks in masks.items()}
for batch_idx in range(len(target_sequences)):
ntl = next_token_logits[batch_idx] if not first_call else next_token_logits[0]
if decoding_strategy == 'parallel':
for m_pos, m_id in masks[batch_idx]:
log_probabilities[batch_idx].append(math.log(ntl[m_pos][m_id].item()))
target_sequences[batch_idx][m_pos] = -100
else:
mask_pos, masked_id = None, None
highest_prob = None
for m_pos, m_id in masks[batch_idx]:
m_prob = ntl[m_pos][m_id]
if highest_prob is None or m_prob > highest_prob:
highest_prob = m_prob
mask_pos, masked_id = m_pos, m_id
log_probabilities[batch_idx].append(math.log(ntl[mask_pos][masked_id].item()))
input_ids[batch_idx][mask_pos] = masked_id
target_sequences[batch_idx][mask_pos] = -100
first_call = False
return {batch_idx: sum(log_prob for log_prob in log_probabilities[batch_idx]) for batch_idx in
range(len(target_sequences))}
def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
choices = input_example.meta['candidates']
question_idx = input_example.meta['question_idx']
input_features.meta['candidate_token_ids'] = []
input_features.meta['candidate_labels'] = []
input_features.meta['question_idx'] = question_idx
self.original_choices[question_idx] = []
for idx, choice_text in enumerate(choices):
choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
choice_label = 1 if choice_text in input_example.meta['answers'] else 0
mask_end = mask_start + len(choice_token_ids)
candidate_token_ids = [-100] * len(input_features.input_ids)
candidate_token_ids[mask_start:mask_end] = choice_token_ids
input_features.meta['candidate_token_ids'].append(candidate_token_ids)
input_features.meta['candidate_labels'].append(choice_label)
self.original_choices[question_idx].append(choice_text)
def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
# apply padding if necessary
max_num_candidates = max(len(f.meta['candidate_token_ids']) for f in features)
for feature in features:
while len(feature.meta['candidate_token_ids']) < max_num_candidates:
feature.meta['candidate_token_ids'].append([-100] * len(feature.input_ids))
feature.meta['candidate_labels'].append(-100)
feature_dict['candidate_token_ids'] = \
torch.tensor([f.meta['candidate_token_ids'] for f in features], dtype=torch.long)
feature_dict['candidate_labels'] = \
torch.tensor([f.meta['candidate_labels'] for f in features], dtype=torch.long)
feature_dict['question_idx'] = torch.tensor([f.meta['question_idx'] for f in features], dtype=torch.long)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the logic for loading training and test data for all tasks.
"""
import csv
import json
import os
import random
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable
import log
from pet import task_helpers
from pet.utils import InputExample
logger = log.get_logger('root')
def _shuffle_and_restrict(examples: List[InputExample], num_examples: int, seed: int = 42) -> List[InputExample]:
"""
Shuffle a list of examples and restrict it to a given maximum size.
:param examples: the examples to shuffle and restrict
:param num_examples: the maximum number of examples
:param seed: the random seed for shuffling
:return: the first ``num_examples`` elements of the shuffled list
"""
if 0 < num_examples < len(examples):
random.Random(seed).shuffle(examples)
examples = examples[:num_examples]
return examples
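# Usage sketch (illustrative, names invented): build a deterministic 32-example
# training subset for a few-shot run.
#
#   few_shot_train = _shuffle_and_restrict(train_examples, num_examples=32, seed=13)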
class LimitedExampleList:
def __init__(self, labels: List[str], max_examples=-1):
"""
Implementation of a list that stores only a limited amount of examples per label.
:param labels: the set of all possible labels
:param max_examples: the maximum number of examples per label. This can either be a fixed number,
in which case `max_examples` examples are loaded for every label, or a list with the same size as
`labels`, in which case at most `max_examples[i]` examples are loaded for label `labels[i]`.
"""
self._labels = labels
self._examples = []
self._examples_per_label = defaultdict(int)
if isinstance(max_examples, list):
self._max_examples = dict(zip(self._labels, max_examples))
else:
self._max_examples = {label: max_examples for label in self._labels}
def is_full(self):
"""Return `true` iff no more examples can be added to this list"""
for label in self._labels:
if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0:
return False
return True
def add(self, example: InputExample) -> bool:
"""
Add a new input example to this list.
:param example: the example to add
:returns: `True` iff the example was actually added to the list
"""
label = example.label
if self._examples_per_label[label] < self._max_examples[label] or self._max_examples[label] < 0:
self._examples_per_label[label] += 1
self._examples.append(example)
return True
return False
def to_list(self):
return self._examples
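# Usage sketch (illustrative): collect a label-balanced subset of at most 16
# examples per label, stopping early once every cap is reached.
#
#   limited = LimitedExampleList(["entailment", "not_entailment"], max_examples=16)
#   for ex in examples:
#       limited.add(ex)
#       if limited.is_full():
#           break
#   few_shot = limited.to_list()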
class DataProcessor(ABC):
"""
Abstract class that provides methods for loading training, testing, development and unlabeled examples for a given
task
"""
@abstractmethod
def get_train_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the train set."""
pass
@abstractmethod
def get_dev_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the dev set."""
pass
@abstractmethod
def get_test_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the test set."""
pass
@abstractmethod
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""Get a collection of `InputExample`s for the unlabeled set."""
pass
@abstractmethod
def get_labels(self) -> List[str]:
"""Get the list of labels for this data set."""
pass
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(MnliProcessor._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["contradiction", "entailment", "neutral"]
@staticmethod
def _create_examples(lines: List[List[str]], set_type: str) -> List[InputExample]:
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, line[0])
text_a = line[8]
text_b = line[9]
label = line[-1]
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
@staticmethod
def _read_tsv(input_file, quotechar=None):
with open(input_file, "r", encoding="utf-8-sig") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class MnliMismatchedProcessor(MnliProcessor):
"""Processor for the MultiNLI mismatched data set (GLUE version)."""
def get_dev_examples(self, data_dir):
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
class AgnewsProcessor(DataProcessor):
"""Processor for the AG news data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2", "3", "4"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, headline, body = row
guid = "%s-%s" % (set_type, idx)
text_a = headline.replace('\\', ' ')
text_b = body.replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class YahooAnswersProcessor(DataProcessor):
"""Processor for the Yahoo Answers data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, question_title, question_body, answer = row
guid = "%s-%s" % (set_type, idx)
text_a = ' '.join([question_title.replace('\\n', ' ').replace('\\', ' '),
question_body.replace('\\n', ' ').replace('\\', ' ')])
text_b = answer.replace('\\n', ' ').replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class YelpPolarityProcessor(DataProcessor):
"""Processor for the YELP binary classification set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.csv"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.csv"), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["1", "2"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
label, body = row
guid = "%s-%s" % (set_type, idx)
text_a = body.replace('\\n', ' ').replace('\\', ' ')
example = InputExample(guid=guid, text_a=text_a, label=label)
examples.append(example)
return examples
class YelpFullProcessor(YelpPolarityProcessor):
"""Processor for the YELP full classification set."""
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_labels(self):
return ["1", "2", "3", "4", "5"]
class XStanceProcessor(DataProcessor):
"""Processor for the X-Stance data set."""
def __init__(self, language: str = None):
if language is not None:
assert language in ['de', 'fr']
self.language = language
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"))
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"))
def get_test_examples(self, data_dir) -> List[InputExample]:
raise NotImplementedError()
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
return self.get_train_examples(data_dir)
def get_labels(self):
return ["FAVOR", "AGAINST"]
def _create_examples(self, path: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
label = example_json['label']
id_ = example_json['id']
text_a = example_json['question']
text_b = example_json['comment']
language = example_json['language']
if self.language is not None and language != self.language:
continue
example = InputExample(guid=id_, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
class RteProcessor(DataProcessor):
"""Processor for the RTE data set."""
def __init__(self):
self.mnli_processor = MnliProcessor()
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["entailment", "not_entailment"]
def _create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis",
premise_name: str = "premise") -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line_idx, line in enumerate(f):
example_json = json.loads(line)
idx = example_json['idx']
if isinstance(idx, str):
try:
idx = int(idx)
except ValueError:
idx = line_idx
label = example_json.get('label')
guid = "%s-%s" % (set_type, idx)
text_a = example_json[premise_name]
text_b = example_json[hypothesis_name]
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx)
examples.append(example)
return examples
class AxGProcessor(RteProcessor):
"""Processor for the AX-G diagnostic data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "train")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-g.jsonl"), "test")
class AxBProcessor(RteProcessor):
"""Processor for the AX-B diagnostic data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "train")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "AX-b.jsonl"), "test")
def _create_examples(self, path, set_type, hypothesis_name="sentence2", premise_name="sentence1"):
return super()._create_examples(path, set_type, hypothesis_name, premise_name)
class CbProcessor(RteProcessor):
"""Processor for the CB data set."""
def get_labels(self):
return ["entailment", "contradiction", "neutral"]
class WicProcessor(DataProcessor):
"""Processor for the WiC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["F", "T"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
if isinstance(idx, str):
idx = int(idx)
label = "T" if example_json.get('label') else "F"
guid = "%s-%s" % (set_type, idx)
text_a = example_json['sentence1']
text_b = example_json['sentence2']
meta = {'word': example_json['word']}
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx, meta=meta)
examples.append(example)
return examples
class WscProcessor(DataProcessor):
"""Processor for the WSC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["False", "True"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
label = str(example_json['label']) if 'label' in example_json else None
guid = "%s-%s" % (set_type, idx)
text_a = example_json['text']
meta = {
'span1_text': example_json['target']['span1_text'],
'span2_text': example_json['target']['span2_text'],
'span1_index': example_json['target']['span1_index'],
'span2_index': example_json['target']['span2_index']
}
# the indices in the dataset are wrong for some examples, so we manually fix them
span1_index, span1_text = meta['span1_index'], meta['span1_text']
span2_index, span2_text = meta['span2_index'], meta['span2_text']
words_a = text_a.split()
words_a_lower = text_a.lower().split()
words_span1_text = span1_text.lower().split()
span1_len = len(words_span1_text)
if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text:
for offset in [-1, +1]:
if words_a_lower[span1_index + offset:span1_index + span1_len + offset] == words_span1_text:
span1_index += offset
if words_a_lower[span1_index:span1_index + span1_len] != words_span1_text:
logger.warning(f"Got '{words_a_lower[span1_index:span1_index + span1_len]}' but expected "
f"'{words_span1_text}' at index {span1_index} for '{words_a}'")
if words_a[span2_index] != span2_text:
for offset in [-1, +1]:
if words_a[span2_index + offset] == span2_text:
span2_index += offset
if words_a[span2_index] != span2_text and words_a[span2_index].startswith(span2_text):
words_a = words_a[:span2_index] \
+ [words_a[span2_index][:len(span2_text)], words_a[span2_index][len(span2_text):]] \
+ words_a[span2_index + 1:]
assert words_a[span2_index] == span2_text, \
f"Got '{words_a[span2_index]}' but expected '{span2_text}' at index {span2_index} for '{words_a}'"
text_a = ' '.join(words_a)
meta['span1_index'], meta['span2_index'] = span1_index, span2_index
example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx)
if set_type == 'train' and label != 'True':
continue
examples.append(example)
return examples
class BoolQProcessor(DataProcessor):
"""Processor for the BoolQ data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["False", "True"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
idx = example_json['idx']
label = str(example_json['label']) if 'label' in example_json else None
guid = "%s-%s" % (set_type, idx)
text_a = example_json['passage']
text_b = example_json['question']
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx)
examples.append(example)
return examples
class CopaProcessor(DataProcessor):
"""Processor for the COPA data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
label = str(example_json['label']) if 'label' in example_json else None
idx = example_json['idx']
guid = "%s-%s" % (set_type, idx)
text_a = example_json['premise']
meta = {
'choice1': example_json['choice1'],
'choice2': example_json['choice2'],
'question': example_json['question']
}
example = InputExample(guid=guid, text_a=text_a, label=label, meta=meta, idx=idx)
examples.append(example)
if set_type in ['train', 'unlabeled']:
mirror_examples = []
for ex in examples:
label = "1" if ex.label == "0" else "0"
meta = {
'choice1': ex.meta['choice2'],
'choice2': ex.meta['choice1'],
'question': ex.meta['question']
}
mirror_example = InputExample(guid=ex.guid + 'm', text_a=ex.text_a, label=label, meta=meta)
mirror_examples.append(mirror_example)
examples += mirror_examples
logger.info(f"Added {len(mirror_examples)} mirror examples, total size is {len(examples)}...")
return examples
class MultiRcProcessor(DataProcessor):
"""Processor for the MultiRC data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path: str, set_type: str) -> List[InputExample]:
examples = []
with open(path, encoding='utf8') as f:
for line in f:
example_json = json.loads(line)
passage_idx = example_json['idx']
text = example_json['passage']['text']
questions = example_json['passage']['questions']
for question_json in questions:
question = question_json["question"]
question_idx = question_json['idx']
answers = question_json["answers"]
for answer_json in answers:
label = str(answer_json["label"]) if 'label' in answer_json else None
answer_idx = answer_json["idx"]
guid = f'{set_type}-p{passage_idx}-q{question_idx}-a{answer_idx}'
meta = {
'passage_idx': passage_idx,
'question_idx': question_idx,
'answer_idx': answer_idx,
'answer': answer_json["text"]
}
idx = [passage_idx, question_idx, answer_idx]
example = InputExample(guid=guid, text_a=text, text_b=question, label=label, meta=meta, idx=idx)
examples.append(example)
question_indices = list(set(example.meta['question_idx'] for example in examples))
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label "
f"distribution {list(label_distribution.items())}")
return examples
class RecordProcessor(DataProcessor):
"""Processor for the ReCoRD data set."""
def get_train_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test")
def get_unlabeled_examples(self, data_dir):
return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled")
def get_labels(self):
return ["0", "1"]
@staticmethod
def _create_examples(path, set_type, seed=42, max_train_candidates_per_question: int = 10) -> List[InputExample]:
examples = []
entity_shuffler = random.Random(seed)
with open(path, encoding='utf8') as f:
            for line in f:
                example_json = json.loads(line)
                idx = example_json['idx']  # the index comes from the JSON record itself
text = example_json['passage']['text']
entities = set()
for entity_json in example_json['passage']['entities']:
start = entity_json['start']
end = entity_json['end']
entity = text[start:end + 1]
entities.add(entity)
entities = list(entities)
text = text.replace("@highlight\n", "- ") # we follow the GPT-3 paper wrt @highlight annotations
questions = example_json['qas']
for question_json in questions:
question = question_json['query']
question_idx = question_json['idx']
answers = set()
for answer_json in question_json.get('answers', []):
answer = answer_json['text']
answers.add(answer)
answers = list(answers)
if set_type == 'train':
# create a single example per *correct* answer
for answer_idx, answer in enumerate(answers):
candidates = [ent for ent in entities if ent not in answers]
if len(candidates) > max_train_candidates_per_question - 1:
entity_shuffler.shuffle(candidates)
candidates = candidates[:max_train_candidates_per_question - 1]
guid = f'{set_type}-p{idx}-q{question_idx}-a{answer_idx}'
meta = {
'passage_idx': idx,
'question_idx': question_idx,
'candidates': [answer] + candidates,
'answers': [answer]
}
ex_idx = [idx, question_idx, answer_idx]
example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta,
idx=ex_idx)
examples.append(example)
else:
# create just one example with *all* correct answers and *all* answer candidates
guid = f'{set_type}-p{idx}-q{question_idx}'
meta = {
'passage_idx': idx,
'question_idx': question_idx,
'candidates': entities,
'answers': answers
}
example = InputExample(guid=guid, text_a=text, text_b=question, label="1", meta=meta)
examples.append(example)
question_indices = list(set(example.meta['question_idx'] for example in examples))
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} examples corresponding to {len(question_indices)} questions with label "
f"distribution {list(label_distribution.items())}")
return examples
PROCESSORS = {
"mnli": MnliProcessor,
"mnli-mm": MnliMismatchedProcessor,
"agnews": AgnewsProcessor,
"yahoo": YahooAnswersProcessor,
"yelp-polarity": YelpPolarityProcessor,
"yelp-full": YelpFullProcessor,
"xstance-de": lambda: XStanceProcessor("de"),
"xstance-fr": lambda: XStanceProcessor("fr"),
"xstance": XStanceProcessor,
"wic": WicProcessor,
"rte": RteProcessor,
"cb": CbProcessor,
"wsc": WscProcessor,
"boolq": BoolQProcessor,
"copa": CopaProcessor,
"multirc": MultiRcProcessor,
"record": RecordProcessor,
"ax-g": AxGProcessor,
"ax-b": AxBProcessor,
} # type: Dict[str,Callable[[],DataProcessor]]
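# Hedged usage sketch (added for illustration; not part of the original file).
# Every value in PROCESSORS is a zero-argument callable that returns a
# DataProcessor, so a task can be resolved by name without knowing its class:
def _demo_processor_lookup(task_name: str = "copa") -> List[str]:
    processor = PROCESSORS[task_name]()
    return processor.get_labels()  # e.g., ["0", "1"] for COPA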
class MyTaskDataProcessor(DataProcessor):
"""
Example for a data processor.
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the name of the file containing the train examples
TRAIN_FILE_NAME = "train.tsv"
# Set this to the name of the file containing the dev examples
DEV_FILE_NAME = "dev.tsv"
#DEV_FILE_NAME = "test.tsv"
# Set this to the name of the file containing the test examples
TEST_FILE_NAME = "test.tsv"
#TEST_FILE_NAME = "dev.tsv"
# Set this to the name of the file containing the unlabeled examples
UNLABELED_FILE_NAME = "unlabeled.tsv"
    # Set this to a list of all labels in the train + test data. Note that the
    # labels here are ints, matching the int(...) cast in _create_examples below.
    LABELS = [1, 0]
# Set this to the column of the train/test csv files containing the input's text a
TEXT_COLUMN = 0
# Set this to the column of the train/test csv files containing the input's gold label
LABEL_COLUMN = 1
def get_train_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads train examples from a file with name `TRAIN_FILE_NAME` in the given directory.
:param data_dir: the directory in which the training data can be found
:return: a list of train examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TRAIN_FILE_NAME), "train")
def get_dev_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads dev examples from a file with name `DEV_FILE_NAME` in the given directory.
:param data_dir: the directory in which the dev data can be found
:return: a list of dev examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.DEV_FILE_NAME), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
"""
This method loads test examples from a file with name `TEST_FILE_NAME` in the given directory.
:param data_dir: the directory in which the test data can be found
:return: a list of test examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TEST_FILE_NAME), "test")
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""
This method loads unlabeled examples from a file with name `UNLABELED_FILE_NAME` in the given directory.
:param data_dir: the directory in which the unlabeled data can be found
:return: a list of unlabeled examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.UNLABELED_FILE_NAME), "unlabeled")
def get_labels(self) -> List[str]:
"""This method returns all possible labels for the task."""
return MyTaskDataProcessor.LABELS
def _create_examples(self, path, set_type, max_examples=-1, skip_first=0):
"""Creates examples for the training and dev sets."""
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter='\t')
for idx, row in enumerate(reader):
                if idx != 0:  # skip the header row
                    guid = "%s-%s" % (set_type, idx)
                    label = int(row[MyTaskDataProcessor.LABEL_COLUMN])
                    text = row[MyTaskDataProcessor.TEXT_COLUMN]
# text_b = row[MyTaskDataProcessor.TEXT_B_COLUMN] if MyTaskDataProcessor.TEXT_B_COLUMN >= 0 else None
example = InputExample(guid=guid, text_a=text, label=label)
examples.append(example)
return examples
# register the processor for this task with its name
PROCESSORS[MyTaskDataProcessor.TASK_NAME] = MyTaskDataProcessor
# additional aliases for the same processor, used by the experiment scripts below
PROCESSORS['my-task2'] = MyTaskDataProcessor
PROCESSORS['autobest5'] = MyTaskDataProcessor
TASK_HELPERS = {
"wsc": task_helpers.WscTaskHelper,
"multirc": task_helpers.MultiRcTaskHelper,
"copa": task_helpers.CopaTaskHelper,
"record": task_helpers.RecordTaskHelper,
}
METRICS = {
"cb": ["acc", "f1-macro"],
"multirc": ["acc", "f1", "em"]
}
DEFAULT_METRICS = ["acc"]
TRAIN_SET = "train"
DEV_SET = "dev"
TEST_SET = "test"
UNLABELED_SET = "unlabeled"
SET_TYPES = [TRAIN_SET, DEV_SET, TEST_SET, UNLABELED_SET]
def load_examples(task, data_dir: str, set_type: str, *_, num_examples: int = None,
num_examples_per_label: int = None, seed: int = 42) -> List[InputExample]:
"""Load examples for a given task."""
assert (num_examples is not None) ^ (num_examples_per_label is not None), \
"Exactly one of 'num_examples' and 'num_examples_per_label' must be set."
assert (not set_type == UNLABELED_SET) or (num_examples is not None), \
"For unlabeled data, 'num_examples_per_label' is not allowed"
processor = PROCESSORS[task]()
ex_str = f"num_examples={num_examples}" if num_examples is not None \
else f"num_examples_per_label={num_examples_per_label}"
logger.info(
f"Creating features from dataset file at {data_dir} ({ex_str}, set_type={set_type})"
)
if set_type == DEV_SET:
examples = processor.get_dev_examples(data_dir)
elif set_type == TEST_SET:
examples = processor.get_test_examples(data_dir)
elif set_type == TRAIN_SET:
examples = processor.get_train_examples(data_dir)
elif set_type == UNLABELED_SET:
examples = processor.get_unlabeled_examples(data_dir)
for example in examples:
example.label = processor.get_labels()[0]
else:
raise ValueError(f"'set_type' must be one of {SET_TYPES}, got '{set_type}' instead")
if num_examples is not None:
examples = _shuffle_and_restrict(examples, num_examples, seed)
elif num_examples_per_label is not None:
limited_examples = LimitedExampleList(processor.get_labels(), num_examples_per_label)
for example in examples:
limited_examples.add(example)
examples = limited_examples.to_list()
label_distribution = Counter(example.label for example in examples)
logger.info(f"Returning {len(examples)} {set_type} examples with label dist.: {list(label_distribution.items())}")
return examples
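# Hedged usage sketch (illustration only; "path/to/data" is a placeholder and the
# example count is arbitrary): exactly one of 'num_examples' and
# 'num_examples_per_label' may be given, matching the assertions above.
def _demo_load_examples(data_dir: str = "path/to/data") -> List[InputExample]:
    return load_examples(MyTaskDataProcessor.TASK_NAME, data_dir, TRAIN_SET,
                         num_examples=32, seed=42)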
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import pickle
import random
import string
from collections import defaultdict
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, GPT2Tokenizer
class LogitsList:
"""A list of logits obtained from a finetuned PET model"""
def __init__(self, score: float, logits: List[List[float]]):
"""
Create a new LogitsList.
:param score: the corresponding PET model's score on the training set
:param logits: the list of logits, where ``logits[i][j]`` is the score for label ``j`` at example ``i``
"""
self.score = score
self.logits = logits
def __repr__(self):
return 'LogitsList(score={}, logits[:2]={})'.format(self.score, self.logits[:2])
def save(self, path: str) -> None:
"""Save this list to a file."""
with open(path, 'w') as fh:
fh.write(str(self.score) + '\n')
for example_logits in self.logits:
fh.write(' '.join(str(logit) for logit in example_logits) + '\n')
@staticmethod
def load(path: str, with_score: bool = True) -> 'LogitsList':
"""Load a list from a file"""
score = -1
logits = []
with open(path, 'r') as fh:
for line_idx, line in enumerate(fh.readlines()):
line = line.rstrip('\n')
if line_idx == 0 and with_score:
score = float(line)
else:
logits.append([float(x) for x in line.split()])
return LogitsList(score=score, logits=logits)
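# Hedged round-trip sketch (illustration only): save a LogitsList to a temporary
# file and restore it using only the save/load methods defined above.
def _demo_logits_list_roundtrip() -> 'LogitsList':
    import os
    import tempfile
    lst = LogitsList(score=0.9, logits=[[0.1, 0.9], [0.8, 0.2]])
    path = os.path.join(tempfile.mkdtemp(), 'logits.txt')
    lst.save(path)
    return LogitsList.load(path)  # score and logits match the saved values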
class InputExample(object):
"""A raw input example consisting of one or two segments of text and a label"""
def __init__(self, guid, text_a, text_b=None, label=None, logits=None, meta: Optional[Dict] = None, idx=-1):
"""
Create a new InputExample.
:param guid: a unique textual identifier
:param text_a: the sequence of text
:param text_b: an optional, second sequence of text
:param label: an optional label
:param logits: an optional list of per-class logits
:param meta: an optional dictionary to store arbitrary meta information
:param idx: an optional numeric index
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.logits = logits
self.idx = idx
self.meta = meta if meta else {}
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serialize this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serialize this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
@staticmethod
def load_examples(path: str) -> List['InputExample']:
"""Load a set of input examples from a file"""
with open(path, 'rb') as fh:
return pickle.load(fh)
@staticmethod
def save_examples(examples: List['InputExample'], path: str) -> None:
"""Save a set of input examples to a file"""
with open(path, 'wb') as fh:
pickle.dump(examples, fh)
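# Hedged round-trip sketch (illustration only): pickle a small list of examples
# and restore it with the static helpers above.
def _demo_example_roundtrip() -> List['InputExample']:
    import os
    import tempfile
    examples = [InputExample(guid='demo-0', text_a='a placeholder sentence', label='1')]
    path = os.path.join(tempfile.mkdtemp(), 'examples.bin')
    InputExample.save_examples(examples, path)
    return InputExample.load_examples(path)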
class InputFeatures(object):
"""A set of numeric features obtained from an :class:`InputExample`"""
def __init__(self, input_ids, attention_mask, token_type_ids, label, mlm_labels=None, logits=None,
meta: Optional[Dict] = None, idx=-1):
"""
Create new InputFeatures.
:param input_ids: the input ids corresponding to the original text or text sequence
:param attention_mask: an attention mask, with 0 = no attention, 1 = attention
:param token_type_ids: segment ids as used by BERT
:param label: the label
:param mlm_labels: an optional sequence of labels used for auxiliary language modeling
:param logits: an optional sequence of per-class logits
:param meta: an optional dictionary to store arbitrary meta information
:param idx: an optional numeric index
"""
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.label = label
self.mlm_labels = mlm_labels
self.logits = logits
self.idx = idx
self.meta = meta if meta else {}
def __repr__(self):
return str(self.to_json_string())
def pretty_print(self, tokenizer):
return f'input_ids = {tokenizer.convert_ids_to_tokens(self.input_ids)}\n' + \
f'attention_mask = {self.attention_mask}\n' + \
f'token_type_ids = {self.token_type_ids}\n' + \
f'mlm_labels = {self.mlm_labels}\n' + \
f'logits = {self.logits}\n' + \
f'label = {self.label}'
def to_dict(self):
"""Serialize this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serialize this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class PLMInputFeatures(InputFeatures):
"""A set of numeric input features for a model pretrained with a permuted language modeling objective."""
def __init__(self, *_, perm_mask, target_mapping, **kwargs):
super().__init__(**kwargs)
self.perm_mask = perm_mask
self.target_mapping = target_mapping
def pretty_print(self, tokenizer):
return super().pretty_print(tokenizer) + '\n' + \
f'perm_mask = {self.perm_mask}\n' + \
f'target_mapping = {self.target_mapping}'
class DictDataset(Dataset):
"""A dataset of tensors that uses a dictionary for key-value mappings"""
def __init__(self, **tensors):
        # all tensors must share the same first (batch) dimension
        assert all(next(iter(tensors.values())).size(0) == tensor.size(0) for tensor in tensors.values())
self.tensors = tensors
def __getitem__(self, index):
return {key: tensor[index] for key, tensor in self.tensors.items()}
def __len__(self):
return next(iter(self.tensors.values())).size(0)
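# Hedged sketch (illustration only): all tensors passed to a DictDataset must
# share the same first (batch) dimension; indexing returns one slice per key.
def _demo_dict_dataset() -> Dict[str, torch.Tensor]:
    ds = DictDataset(input_ids=torch.zeros(4, 8, dtype=torch.long),
                     labels=torch.ones(4, dtype=torch.long))
    assert len(ds) == 4
    return ds[0]  # {'input_ids': shape (8,) tensor, 'labels': tensor(1)}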
def set_seed(seed: int):
""" Set RNG seeds for python's `random` module, numpy and torch"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def eq_div(N, i):
""" Equally divide N examples among i buckets. For example, `eq_div(12,3) = [4,4,4]`. """
return [] if i <= 0 else [N // i + 1] * (N % i) + [N // i] * (i - N % i)
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
def remove_final_punc(s: str):
"""Remove the last character from a string if it is some form of punctuation"""
return s.rstrip(string.punctuation)
def lowercase_first(s: str):
"""Lowercase the first letter of a string"""
return s[0].lower() + s[1:]
def save_logits(path: str, logits: np.ndarray):
"""Save an array of logits to a file"""
with open(path, 'w') as fh:
for example_logits in logits:
fh.write(' '.join(str(logit) for logit in example_logits) + '\n')
def save_predictions(path: str, wrapper, results: Dict):
"""Save a sequence of predictions to a file"""
predictions_with_idx = []
if wrapper.task_helper and wrapper.task_helper.output:
predictions_with_idx = wrapper.task_helper.output
else:
inv_label_map = {idx: label for label, idx in wrapper.preprocessor.label_map.items()}
for idx, prediction_idx in zip(results['indices'], results['predictions']):
prediction = inv_label_map[prediction_idx]
idx = idx.tolist() if isinstance(idx, np.ndarray) else int(idx)
predictions_with_idx.append({'idx': idx, 'label': prediction})
with open(path, 'w', encoding='utf8') as fh:
for line in predictions_with_idx:
fh.write(json.dumps(line) + '\n')
def softmax(x, temperature=1.0, axis=None):
"""Custom softmax implementation"""
y = np.atleast_2d(x)
if axis is None:
axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)
    y = y * float(temperature)  # note: applied as a multiplier, i.e. an inverse temperature
y = y - np.expand_dims(np.max(y, axis=axis), axis)
y = np.exp(y)
ax_sum = np.expand_dims(np.sum(y, axis=axis), axis)
p = y / ax_sum
if len(x.shape) == 1:
p = p.flatten()
return p
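# Hedged sketch (illustration only): with the default temperature of 1.0 this
# matches a standard softmax; values above 1 sharpen the distribution because
# the factor is applied as a multiplier.
def _demo_softmax() -> np.ndarray:
    return softmax(np.array([1.0, 2.0, 3.0]))  # ~[0.090, 0.245, 0.665]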
def get_verbalization_ids(word: str, tokenizer: PreTrainedTokenizer, force_single_token: bool) -> Union[int, List[int]]:
"""
Get the token ids corresponding to a verbalization
:param word: the verbalization
:param tokenizer: the tokenizer to use
:param force_single_token: whether it should be enforced that the verbalization corresponds to a single token.
If set to true, this method returns a single int instead of a list and throws an error if the word
corresponds to multiple tokens.
:return: either the list of token ids or the single token id corresponding to this word
"""
kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}
ids = tokenizer.encode(word, add_special_tokens=False, **kwargs)
if not force_single_token:
return ids
assert len(ids) == 1, \
f'Verbalization "{word}" does not correspond to a single token, got {tokenizer.convert_ids_to_tokens(ids)}'
verbalization_id = ids[0]
assert verbalization_id not in tokenizer.all_special_ids, \
f'Verbalization {word} is mapped to a special token {tokenizer.convert_ids_to_tokens(verbalization_id)}'
return verbalization_id
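# Hedged usage sketch (illustration only; loading the tokenizer downloads it if
# it is not cached locally):
#
#     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#     get_verbalization_ids('great', tokenizer, force_single_token=True)  # one id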
def trim_input_ids(input_ids: torch.Tensor, pad_token_id, mask_token_id, num_masks: int):
"""
Trim a sequence of input ids by removing all padding tokens and keeping at most a specific number of mask tokens.
:param input_ids: the sequence of input token ids
:param pad_token_id: the id of the pad token
:param mask_token_id: the id of the mask tokens
    :param num_masks: the number of masks to keep
:return: the trimmed sequence of input ids
"""
assert input_ids.shape[0] == 1
input_ids_without_pad = [x for x in input_ids[0] if x != pad_token_id]
trimmed_input_ids = []
mask_count = 0
for input_id in input_ids_without_pad:
if input_id == mask_token_id:
if mask_count >= num_masks:
continue
mask_count += 1
trimmed_input_ids.append(input_id)
return torch.tensor([trimmed_input_ids], dtype=torch.long, device=input_ids.device)
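# Hedged sketch (illustration only, with made-up token ids): the pad id 0 is
# dropped and only the first two of the three mask tokens (id 103) are kept.
def _demo_trim_input_ids() -> torch.Tensor:
    ids = torch.tensor([[5, 103, 0, 103, 7, 103, 0]])
    return trim_input_ids(ids, pad_token_id=0, mask_token_id=103, num_masks=2)
    # -> tensor([[5, 103, 103, 7]])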
def exact_match(predictions: np.ndarray, actuals: np.ndarray, question_ids: np.ndarray):
"""Compute the exact match (EM) for a sequence of predictions and actual labels"""
unique_questions = set(question_ids)
q_actuals = list(zip(question_ids, actuals))
q_predictions = list(zip(question_ids, predictions))
actuals_per_question = defaultdict(list)
predictions_per_question = defaultdict(list)
for qid, val in q_actuals:
actuals_per_question[qid].append(val)
for qid, val in q_predictions:
predictions_per_question[qid].append(val)
em = 0
for qid in unique_questions:
if actuals_per_question[qid] == predictions_per_question[qid]:
em += 1
em /= len(unique_questions)
return em
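# Hedged sketch (illustration only): question 0 is answered fully correctly,
# question 1 is not, so the exact match score is 0.5.
def _demo_exact_match() -> float:
    predictions = np.array([1, 0, 1])
    actuals = np.array([1, 0, 0])
    question_ids = np.array([0, 0, 1])
    return exact_match(predictions, actuals, question_ids)  # 0.5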
def distillation_loss(predictions, targets, temperature):
"""Compute the distillation loss (KL divergence between predictions and targets) as described in the PET paper"""
p = F.log_softmax(predictions / temperature, dim=1)
q = F.softmax(targets / temperature, dim=1)
return F.kl_div(p, q, reduction='sum') * (temperature ** 2) / predictions.shape[0]
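# Hedged sketch (illustration only): identical predictions and targets yield a
# distillation loss of zero, regardless of the temperature.
def _demo_distillation_loss() -> torch.Tensor:
    logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])
    return distillation_loss(logits, logits.clone(), temperature=2.0)  # ~0.0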
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains code for wrapping a transformer language model and
provides convenience methods for training and inference.
"""
import json
import jsonpickle
import os
from typing import List, Dict, Optional
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
from tqdm import trange, tqdm
from transformers import InputExample, AdamW, get_linear_schedule_with_warmup, PreTrainedTokenizer, BertForMaskedLM, \
RobertaForMaskedLM, XLMRobertaForMaskedLM, XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer, \
XLNetLMHeadModel, BertConfig, BertForSequenceClassification, BertTokenizer, RobertaConfig, \
RobertaForSequenceClassification, RobertaTokenizer, XLMRobertaConfig, XLMRobertaForSequenceClassification, \
XLMRobertaTokenizer, AlbertForSequenceClassification, AlbertForMaskedLM, AlbertTokenizer, AlbertConfig, \
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import __version__ as transformers_version
import log
from pet import preprocessor
from pet.tasks import TASK_HELPERS
from pet.utils import InputFeatures, DictDataset, distillation_loss
logger = log.get_logger('root')
CONFIG_NAME = 'wrapper_config.json'
SEQUENCE_CLASSIFIER_WRAPPER = "sequence_classifier"
MLM_WRAPPER = "mlm"
PLM_WRAPPER = "plm"
WRAPPER_TYPES = [SEQUENCE_CLASSIFIER_WRAPPER, MLM_WRAPPER, PLM_WRAPPER]
PREPROCESSORS = {
SEQUENCE_CLASSIFIER_WRAPPER: preprocessor.SequenceClassifierPreprocessor,
MLM_WRAPPER: preprocessor.MLMPreprocessor,
PLM_WRAPPER: preprocessor.PLMPreprocessor,
}
MODEL_CLASSES = {
'bert': {
'config': BertConfig,
'tokenizer': BertTokenizer,
SEQUENCE_CLASSIFIER_WRAPPER: BertForSequenceClassification,
MLM_WRAPPER: BertForMaskedLM
},
'roberta': {
'config': RobertaConfig,
'tokenizer': RobertaTokenizer,
SEQUENCE_CLASSIFIER_WRAPPER: RobertaForSequenceClassification,
MLM_WRAPPER: RobertaForMaskedLM
},
'xlm-roberta': {
'config': XLMRobertaConfig,
'tokenizer': XLMRobertaTokenizer,
SEQUENCE_CLASSIFIER_WRAPPER: XLMRobertaForSequenceClassification,
MLM_WRAPPER: XLMRobertaForMaskedLM
},
'xlnet': {
'config': XLNetConfig,
'tokenizer': XLNetTokenizer,
SEQUENCE_CLASSIFIER_WRAPPER: XLNetForSequenceClassification,
PLM_WRAPPER: XLNetLMHeadModel
},
'albert': {
'config': AlbertConfig,
'tokenizer': AlbertTokenizer,
SEQUENCE_CLASSIFIER_WRAPPER: AlbertForSequenceClassification,
MLM_WRAPPER: AlbertForMaskedLM
},
'gpt2': {
'config': GPT2Config,
'tokenizer': GPT2Tokenizer,
MLM_WRAPPER: GPT2LMHeadModel
},
}
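# Hedged lookup sketch (illustration only): classes are resolved in two steps,
# first by model type and then by wrapper type, e.g.:
#
#     MODEL_CLASSES['roberta'][MLM_WRAPPER]  # -> RobertaForMaskedLM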
EVALUATION_STEP_FUNCTIONS = {
MLM_WRAPPER: lambda wrapper: wrapper.mlm_eval_step,
PLM_WRAPPER: lambda wrapper: wrapper.plm_eval_step,
SEQUENCE_CLASSIFIER_WRAPPER: lambda wrapper: wrapper.sequence_classifier_eval_step,
}
TRAIN_STEP_FUNCTIONS = {
MLM_WRAPPER: lambda wrapper: wrapper.mlm_train_step,
PLM_WRAPPER: lambda wrapper: wrapper.plm_train_step,
SEQUENCE_CLASSIFIER_WRAPPER: lambda wrapper: wrapper.sequence_classifier_train_step,
}
class WrapperConfig(object):
"""A configuration for a :class:`TransformerModelWrapper`."""
def __init__(self, model_type: str, model_name_or_path: str, wrapper_type: str, task_name: str, max_seq_length: int,
label_list: List[str], pattern_id: int = 0, verbalizer_file: str = None, cache_dir: str = None):
"""
Create a new config.
:param model_type: the model type (e.g., 'bert', 'roberta', 'albert')
:param model_name_or_path: the model name (e.g., 'roberta-large') or path to a pretrained model
:param wrapper_type: the wrapper type (one of 'mlm', 'plm' and 'sequence_classifier')
:param task_name: the task to solve
:param max_seq_length: the maximum number of tokens in a sequence
:param label_list: the list of labels for the task
:param pattern_id: the id of the pattern to use
:param verbalizer_file: optional path to a verbalizer file
:param cache_dir: optional path to a cache dir
"""
self.model_type = model_type
self.model_name_or_path = model_name_or_path
self.wrapper_type = wrapper_type
self.task_name = task_name
self.max_seq_length = max_seq_length
self.label_list = label_list
self.pattern_id = pattern_id
self.verbalizer_file = verbalizer_file
self.cache_dir = cache_dir
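# Hedged sketch (illustration only; the task name and label list are
# placeholders): constructing a config is cheap, as the underlying model is only
# downloaded when a TransformerModelWrapper is built from it.
def _demo_wrapper_config() -> WrapperConfig:
    return WrapperConfig(model_type='bert', model_name_or_path='bert-base-uncased',
                         wrapper_type=MLM_WRAPPER, task_name='my-task',
                         max_seq_length=128, label_list=['0', '1'], pattern_id=0)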
class TransformerModelWrapper:
"""A wrapper around a Transformer-based language model."""
def __init__(self, config: WrapperConfig):
"""Create a new wrapper from the given config."""
self.config = config
config_class = MODEL_CLASSES[self.config.model_type]['config']
tokenizer_class = MODEL_CLASSES[self.config.model_type]['tokenizer']
model_class = MODEL_CLASSES[self.config.model_type][self.config.wrapper_type]
model_config = config_class.from_pretrained(
config.model_name_or_path, num_labels=len(config.label_list), finetuning_task=config.task_name,
cache_dir=config.cache_dir if config.cache_dir else None, use_cache=False)
self.tokenizer = tokenizer_class.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir if config.cache_dir else None) # type: PreTrainedTokenizer
if self.config.model_type == 'gpt2':
self.tokenizer.pad_token, self.tokenizer.mask_token = self.tokenizer.eos_token, self.tokenizer.eos_token
self.model = model_class.from_pretrained(config.model_name_or_path, config=model_config,
cache_dir=config.cache_dir if config.cache_dir else None)
self.preprocessor = PREPROCESSORS[self.config.wrapper_type](self, self.config.task_name, self.config.pattern_id,
self.config.verbalizer_file)
self.task_helper = TASK_HELPERS[self.config.task_name](self) if self.config.task_name in TASK_HELPERS else None
@classmethod
def from_pretrained(cls, path: str) -> 'TransformerModelWrapper':
"""Load a pretrained wrapper from a given path."""
wrapper = TransformerModelWrapper.__new__(TransformerModelWrapper)
wrapper.config = wrapper._load_config(path)
tokenizer_class = MODEL_CLASSES[wrapper.config.model_type]['tokenizer']
model_class = MODEL_CLASSES[wrapper.config.model_type][wrapper.config.wrapper_type]
wrapper.model = model_class.from_pretrained(path)
wrapper.tokenizer = tokenizer_class.from_pretrained(path)
wrapper.preprocessor = PREPROCESSORS[wrapper.config.wrapper_type](
wrapper, wrapper.config.task_name, wrapper.config.pattern_id, wrapper.config.verbalizer_file)
wrapper.task_helper = TASK_HELPERS[wrapper.config.task_name](wrapper) \
if wrapper.config.task_name in TASK_HELPERS else None
return wrapper
def save(self, path: str) -> None:
"""Save a pretrained wrapper."""
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(path)
self.tokenizer.save_pretrained(path)
self._save_config(path)
def _save_config(self, path: str) -> None:
with open(os.path.join(path, CONFIG_NAME), 'w') as f:
f.write(jsonpickle.encode(self.config))
@staticmethod
def _load_config(path: str) -> WrapperConfig:
with open(os.path.join(path, CONFIG_NAME), 'r') as f:
return jsonpickle.decode(f.read())
def train(self, task_train_data: List[InputExample], device, per_gpu_train_batch_size: int = 8, n_gpu: int = 1,
num_train_epochs: int = 3, gradient_accumulation_steps: int = 1, weight_decay: float = 0.0,
learning_rate: float = 5e-5, adam_epsilon: float = 1e-8, warmup_steps=0, max_grad_norm: float = 1,
logging_steps: int = 50, per_gpu_unlabeled_batch_size: int = 8, unlabeled_data: List[InputExample] = None,
lm_training: bool = False, use_logits: bool = False, alpha: float = 0.8, temperature: float = 1,
max_steps=-1, **_):
"""
Train the underlying language model.
:param task_train_data: the training examples to use
:param device: the training device (cpu/gpu)
:param per_gpu_train_batch_size: the number of training examples per batch and gpu
:param n_gpu: the number of gpus to use
:param num_train_epochs: the number of epochs to train
:param gradient_accumulation_steps: the number of gradient accumulation steps before performing an update
:param weight_decay: the weight decay to use
:param learning_rate: the learning rate to use
:param adam_epsilon: epsilon parameter for the Adam optimizer
:param warmup_steps: the number of warmup steps
:param max_grad_norm: the maximum norm for the gradient
:param logging_steps: the number of steps after which logging information is printed
:param per_gpu_unlabeled_batch_size: the number of unlabeled examples per batch and gpu
:param unlabeled_data: the unlabeled examples to use
:param lm_training: whether to perform auxiliary language modeling (only for MLMs)
:param use_logits: whether to use the example's logits instead of their labels to compute the loss
:param alpha: the alpha parameter for auxiliary language modeling
:param temperature: the temperature for knowledge distillation
:param max_steps: the maximum number of training steps, overrides ``num_train_epochs``
:return: a tuple consisting of the total number of steps and the average training loss
"""
train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
train_dataset = self._generate_dataset(task_train_data)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_batch_size)
unlabeled_dataloader, unlabeled_iter = None, None
if lm_training or use_logits:
# we need unlabeled data both for auxiliary language modeling and for knowledge distillation
assert unlabeled_data is not None
unlabeled_batch_size = per_gpu_unlabeled_batch_size * max(1, n_gpu)
unlabeled_dataset = self._generate_dataset(unlabeled_data, labelled=False)
unlabeled_sampler = RandomSampler(unlabeled_dataset)
unlabeled_dataloader = DataLoader(unlabeled_dataset, sampler=unlabeled_sampler,
batch_size=unlabeled_batch_size)
unlabeled_iter = unlabeled_dataloader.__iter__()
if use_logits:
train_dataloader = unlabeled_dataloader
if max_steps > 0:
t_total = max_steps
num_train_epochs = max_steps // (max(1, len(train_dataloader) // gradient_accumulation_steps)) + 1
else:
t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
num_training_steps=t_total)
# multi-gpu training
if n_gpu > 1:
self.model = torch.nn.DataParallel(self.model)
step = 0
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
self.model.zero_grad()
train_iterator = trange(int(num_train_epochs), desc="Epoch")
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration")
for _, batch in enumerate(epoch_iterator):
self.model.train()
unlabeled_batch = None
batch = {k: t.to(device) for k, t in batch.items()}
if lm_training:
while unlabeled_batch is None:
try:
unlabeled_batch = unlabeled_iter.__next__()
except StopIteration:
logger.info("Resetting unlabeled dataset")
unlabeled_iter = unlabeled_dataloader.__iter__()
lm_input_ids = unlabeled_batch['input_ids']
unlabeled_batch['input_ids'], unlabeled_batch['mlm_labels'] = self._mask_tokens(lm_input_ids)
unlabeled_batch = {k: t.to(device) for k, t in unlabeled_batch.items()}
train_step_inputs = {
'unlabeled_batch': unlabeled_batch, 'lm_training': lm_training, 'alpha': alpha,
'use_logits': use_logits, 'temperature': temperature
}
loss = self.task_helper.train_step(batch, **train_step_inputs) if self.task_helper else None
if loss is None:
loss = TRAIN_STEP_FUNCTIONS[self.config.wrapper_type](self)(batch, **train_step_inputs)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
if (step + 1) % gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
optimizer.step()
scheduler.step()
self.model.zero_grad()
global_step += 1
if logging_steps > 0 and global_step % logging_steps == 0:
logs = {}
loss_scalar = (tr_loss - logging_loss) / logging_steps
learning_rate_scalar = scheduler.get_lr()[0]
logs['learning_rate'] = learning_rate_scalar
logs['loss'] = loss_scalar
logging_loss = tr_loss
print(json.dumps({**logs, **{'step': global_step}}))
if 0 < max_steps < global_step:
epoch_iterator.close()
break
step += 1
if 0 < max_steps < global_step:
train_iterator.close()
break
return global_step, (tr_loss / global_step if global_step > 0 else -1)
def eval(self, eval_data: List[InputExample], device, per_gpu_eval_batch_size: int = 8, n_gpu: int = 1,
priming: bool = False, decoding_strategy: str = 'default') -> Dict:
"""
Evaluate the underlying language model.
:param eval_data: the evaluation examples to use
:param device: the evaluation device (cpu/gpu)
:param per_gpu_eval_batch_size: the number of evaluation examples per batch and gpu
:param n_gpu: the number of gpus to use
:param priming: whether to use priming
:param decoding_strategy: the decoding strategy for PET with multiple masks ('default', 'ltr' or 'parallel')
:return: a dictionary of numpy arrays containing the indices, logits, labels, and (optional) question_ids for
each evaluation example.
"""
eval_dataset = self._generate_dataset(eval_data, priming=priming)
eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)
if n_gpu > 1:
self.model = torch.nn.DataParallel(self.model)
preds = None
all_indices, out_label_ids, question_ids = None, None, None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
self.model.eval()
batch = {k: t.to(device) for k, t in batch.items()}
labels = batch['labels']
indices = batch['idx']
with torch.no_grad():
# some tasks require special evaluation
logits = self.task_helper.eval_step(batch,
decoding_strategy=decoding_strategy) if self.task_helper else None
if logits is None:
logits = EVALUATION_STEP_FUNCTIONS[self.config.wrapper_type](self)(batch)
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = labels.detach().cpu().numpy()
all_indices = indices.detach().cpu().numpy()
if 'question_idx' in batch:
question_ids = batch['question_idx'].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)
all_indices = np.append(all_indices, indices.detach().cpu().numpy(), axis=0)
if 'question_idx' in batch:
question_ids = np.append(question_ids, batch['question_idx'].detach().cpu().numpy(), axis=0)
return {
'indices': all_indices,
'logits': preds,
'labels': out_label_ids,
'question_ids': question_ids
}
def _generate_dataset(self, data: List[InputExample], labelled: bool = True, priming: bool = False):
features = self._convert_examples_to_features(data, labelled=labelled, priming=priming)
feature_dict = {
'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),
'attention_mask': torch.tensor([f.attention_mask for f in features], dtype=torch.long),
'token_type_ids': torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
'labels': torch.tensor([f.label for f in features], dtype=torch.long),
'mlm_labels': torch.tensor([f.mlm_labels for f in features], dtype=torch.long),
'logits': torch.tensor([f.logits for f in features], dtype=torch.float),
'idx': torch.tensor([f.idx for f in features], dtype=torch.long)
}
if self.config.wrapper_type == PLM_WRAPPER:
feature_dict['perm_mask'] = torch.tensor([f.perm_mask for f in features], dtype=torch.float)
feature_dict['target_mapping'] = torch.tensor([f.target_mapping for f in features], dtype=torch.float)
if self.task_helper:
self.task_helper.add_features_to_dict(features, feature_dict)
return DictDataset(**feature_dict)
def _convert_examples_to_features(self, examples: List[InputExample], labelled: bool = True,
priming: bool = False) -> List[InputFeatures]:
features = []
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logger.info("Writing example {}".format(ex_index))
input_features = self.preprocessor.get_input_features(example, labelled=labelled, priming=priming)
if self.task_helper:
self.task_helper.add_special_input_features(example, input_features)
features.append(input_features)
if ex_index < 5:
logger.info(f'--- Example {ex_index} ---')
logger.info(input_features.pretty_print(self.tokenizer))
return features
def _mask_tokens(self, input_ids):
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
labels = input_ids.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability 0.15)
probability_matrix = torch.full(labels.shape, 0.15)
special_tokens_mask = [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
labels.tolist()]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
# if a version of transformers < 2.4.0 is used, -1 is the expected value for indices to ignore
if [int(v) for v in transformers_version.split('.')][:3] >= [2, 4, 0]:
ignore_value = -100
else:
ignore_value = -1
labels[~masked_indices] = ignore_value # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return input_ids, labels
def generate_default_inputs(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Generate the default inputs required by almost every language model."""
inputs = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
if self.config.model_type in ['bert', 'xlnet']:
inputs['token_type_ids'] = batch['token_type_ids']
return inputs
def mlm_train_step(self, labeled_batch: Dict[str, torch.Tensor],
unlabeled_batch: Optional[Dict[str, torch.Tensor]] = None, lm_training: bool = False,
alpha: float = 0, **_) -> torch.Tensor:
"""Perform a MLM training step."""
inputs = self.generate_default_inputs(labeled_batch)
mlm_labels, labels = labeled_batch['mlm_labels'], labeled_batch['labels']
outputs = self.model(**inputs)
prediction_scores = self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(mlm_labels, outputs[0])
loss = nn.CrossEntropyLoss()(prediction_scores.view(-1, len(self.config.label_list)), labels.view(-1))
if lm_training:
lm_inputs = self.generate_default_inputs(unlabeled_batch)
lm_inputs['masked_lm_labels'] = unlabeled_batch['mlm_labels']
lm_loss = self.model(**lm_inputs)[0]
loss = alpha * loss + (1 - alpha) * lm_loss
return loss
def plm_train_step(self, labeled_batch: Dict[str, torch.Tensor], lm_training: bool = False, **_):
"""Perform a PLM training step."""
inputs = self.generate_default_inputs(labeled_batch)
inputs['perm_mask'], inputs['target_mapping'] = labeled_batch['perm_mask'], labeled_batch['target_mapping']
labels = labeled_batch['labels']
outputs = self.model(**inputs)
prediction_scores = self.preprocessor.pvp.convert_plm_logits_to_cls_logits(outputs[0])
loss = nn.CrossEntropyLoss()(prediction_scores.view(-1, len(self.config.label_list)), labels.view(-1))
if lm_training:
raise NotImplementedError("Language model training is currently not implemented for PLMs")
return loss
def sequence_classifier_train_step(self, batch: Dict[str, torch.Tensor], use_logits: bool = False,
temperature: float = 1, **_) -> torch.Tensor:
"""Perform a sequence classifier training step."""
inputs = self.generate_default_inputs(batch)
if not use_logits:
inputs['labels'] = batch['labels']
outputs = self.model(**inputs)
if use_logits:
logits_predicted, logits_target = outputs[0], batch['logits']
return distillation_loss(logits_predicted, logits_target, temperature)
else:
return outputs[0]
def mlm_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Perform a MLM evaluation step."""
inputs = self.generate_default_inputs(batch)
outputs = self.model(**inputs)
return self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(batch['mlm_labels'], outputs[0])
def plm_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Perform a PLM evaluation step."""
inputs = self.generate_default_inputs(batch)
inputs['perm_mask'], inputs['target_mapping'] = batch['perm_mask'], batch['target_mapping']
outputs = self.model(**inputs)
return self.preprocessor.pvp.convert_plm_logits_to_cls_logits(outputs[0])
def sequence_classifier_eval_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Perform a sequence classifier evaluation step."""
inputs = self.generate_default_inputs(batch)
return self.model(**inputs)[0]
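# Hedged end-to-end sketch (illustration only; the device and the example lists
# are placeholders, and the underlying model is downloaded on first use):
#
#     wrapper = TransformerModelWrapper(_demo_wrapper_config())
#     wrapper.train(train_examples, device='cuda', num_train_epochs=3)
#     results = wrapper.eval(dev_examples, device='cuda')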
home = /usr/local/bin
include-system-site-packages = false
version = 3.6.9
-f https://download.pytorch.org/whl/torch_stable.html
numpy==1.19
jsonpickle==1.1
scikit-learn==0.23.1
torch===1.5.0
torchvision==0.6.0
transformers==3.0.2
tqdm==4.48.1
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet2 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet3 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet4 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_ipet5 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir ../data/SST-2 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern2 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern3 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir /home/mist/projects/LM-BFF/data/k-shot/SST-2/16-100 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir results \
--do_train \
--do_eval
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 \
--data_dir /home/mist/projects/LM-BFF/data/k-shot/SST-2/16-100 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name lm-bff \
--output_dir results1 \
--do_train \
--do_eval
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 5 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir result_300 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 4 5 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir result4_300 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \