Commit 59302e20 by 20210828028

v1

parent 8e48a3b2
LM-BFF @ d3f96076
Subproject commit d3f960766ece2006bfc83b1567b149f146eb2c05
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To add a new task to PET, both a DataProcessor and a PVP for this task must
be added. The DataProcessor is responsible for loading training and test data.
This file shows an example of a DataProcessor for a new task.
"""
import csv
import os
from typing import List
from pet.task_helpers import MultiMaskTaskHelper
from pet.tasks import DataProcessor, PROCESSORS, TASK_HELPERS
from pet.utils import InputExample
class MyTaskDataProcessor(DataProcessor):
"""
Example for a data processor.
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the name of the file containing the train examples
TRAIN_FILE_NAME = "train.csv"
# Set this to the name of the file containing the dev examples
DEV_FILE_NAME = "dev.csv"
# Set this to the name of the file containing the test examples
TEST_FILE_NAME = "test.csv"
# Set this to the name of the file containing the unlabeled examples
UNLABELED_FILE_NAME = "unlabeled.csv"
# Set this to a list of all labels in the train + test data
LABELS = ["1", "2", "3", "4"]
# Set this to the column of the train/test csv files containing the input's text a
TEXT_A_COLUMN = 1
# Set this to the column of the train/test csv files containing the input's text b or to -1 if there is no text b
TEXT_B_COLUMN = 2
# Set this to the column of the train/test csv files containing the input's gold label
LABEL_COLUMN = 0
def get_train_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads train examples from a file with name `TRAIN_FILE_NAME` in the given directory.
:param data_dir: the directory in which the training data can be found
:return: a list of train examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TRAIN_FILE_NAME), "train")
def get_dev_examples(self, data_dir: str) -> List[InputExample]:
"""
This method loads dev examples from a file with name `DEV_FILE_NAME` in the given directory.
:param data_dir: the directory in which the dev data can be found
:return: a list of dev examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.DEV_FILE_NAME), "dev")
def get_test_examples(self, data_dir) -> List[InputExample]:
"""
This method loads test examples from a file with name `TEST_FILE_NAME` in the given directory.
:param data_dir: the directory in which the test data can be found
:return: a list of test examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.TEST_FILE_NAME), "test")
def get_unlabeled_examples(self, data_dir) -> List[InputExample]:
"""
This method loads unlabeled examples from a file with name `UNLABELED_FILE_NAME` in the given directory.
:param data_dir: the directory in which the unlabeled data can be found
:return: a list of unlabeled examples
"""
return self._create_examples(os.path.join(data_dir, MyTaskDataProcessor.UNLABELED_FILE_NAME), "unlabeled")
def get_labels(self) -> List[str]:
"""This method returns all possible labels for the task."""
return MyTaskDataProcessor.LABELS
def _create_examples(self, path, set_type, max_examples=-1, skip_first=0):
"""Creates examples for the training and dev sets."""
examples = []
with open(path) as f:
reader = csv.reader(f, delimiter=',')
for idx, row in enumerate(reader):
guid = "%s-%s" % (set_type, idx)
label = row[MyTaskDataProcessor.LABEL_COLUMN]
text_a = row[MyTaskDataProcessor.TEXT_A_COLUMN]
text_b = row[MyTaskDataProcessor.TEXT_B_COLUMN] if MyTaskDataProcessor.TEXT_B_COLUMN >= 0 else None
example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
examples.append(example)
return examples
# register the processor for this task with its name
PROCESSORS[MyTaskDataProcessor.TASK_NAME] = MyTaskDataProcessor
# optional: if you have to use verbalizers that correspond to multiple tokens, uncomment the following line
# TASK_HELPERS[MyTaskDataProcessor.TASK_NAME] = MultiMaskTaskHelper
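# A minimal smoke test for the processor above (not part of PET; the path and
# row layout are illustrative). It assumes a directory `data/` containing a
# train.csv whose rows look like `2,"some headline","some body text"`.
if __name__ == "__main__":
    processor = MyTaskDataProcessor()
    train_examples = processor.get_train_examples("data")
    print("loaded %d train examples" % len(train_examples))
    print(train_examples[0].guid, train_examples[0].label, train_examples[0].text_a)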
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To add a new task to PET, both a DataProcessor and a PVP for this task must
be added. The PVP is responsible for applying patterns to inputs and mapping
labels to their verbalizations (see the paper for more details on PVPs).
This file shows an example of a PVP for a new task.
"""
from typing import List
from pet.pvp import PVP, PVPS
from pet.utils import InputExample
class MyTaskPVP(PVP):
"""
Example for a pattern-verbalizer pair (PVP).
"""
# Set this to the name of the task
TASK_NAME = "my-task"
# Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
# the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
VERBALIZER = {
"1": ["World"],
"2": ["Sports"],
"3": ["Business"],
"4": ["Tech"]
}
def get_parts(self, example: InputExample):
"""
This function defines the actual patterns: It takes as input an example and outputs the result of applying a
pattern to it. To allow for multiple patterns, a pattern_id can be passed to the PVP's constructor. This
method must implement the application of all patterns.
"""
# We tell the tokenizer that both text_a and text_b can be truncated if the resulting sequence is longer than
# our language model's max sequence length.
text_a = self.shortenable(example.text_a)
text_b = self.shortenable(example.text_b)
# For each pattern_id, we define the corresponding pattern and return a pair of text a and text b (where text b
# can also be empty).
if self.pattern_id == 0:
# this corresponds to the pattern [MASK]: a b
return [self.mask, ':', text_a, text_b], []
elif self.pattern_id == 1:
# this corresponds to the pattern [MASK] News: a || (b)
return [self.mask, 'News:', text_a], ['(', text_b, ')']
else:
raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
def verbalize(self, label) -> List[str]:
return MyTaskPVP.VERBALIZER[label]
# register the PVP for this task with its name
PVPS[MyTaskPVP.TASK_NAME] = MyTaskPVP
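# Standalone illustration of what pattern_id 0 produces (independent of PET's
# tokenizer; "[MASK]" stands in for the model's real mask token, and the input
# sentences are made up):
def _render_pattern_0(text_a: str, text_b: str = "") -> str:
    parts = ["[MASK]", ":", text_a, text_b]  # mirrors the pattern_id == 0 branch
    return " ".join(p for p in parts if p)

# _render_pattern_0("Stocks rallied today.", "Markets were calm.")
# -> "[MASK] : Stocks rallied today. Markets were calm."
# PET then scores the verbalizer tokens ("World", "Sports", ...) at the mask
# position to pick a label.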
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains basic logging logic.
"""
import logging
names = set()
def __setup_custom_logger(name: str) -> logging.Logger:
root_logger = logging.getLogger()
root_logger.handlers.clear()
formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
names.add(name)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
return logger
def get_logger(name: str) -> logging.Logger:
if name in names:
return logging.getLogger(name)
else:
return __setup_custom_logger(name)
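# Minimal usage sketch: the first get_logger call for a given name attaches
# the handler; later calls with the same name reuse the existing logger.
if __name__ == "__main__":
    logger = get_logger("example")
    logger.info("handler and formatter are now configured")
    assert get_logger("example") is logger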
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
from pet.modeling import *
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List
import numpy as np
from pet.utils import InputFeatures, InputExample, PLMInputFeatures
from pet.pvp import PVP, PVPS
class Preprocessor(ABC):
"""
A preprocessor that transforms an :class:`InputExample` into a :class:`InputFeatures` object so that it can be
processed by the model being used.
"""
def __init__(self, wrapper, task_name, pattern_id: int = 0, verbalizer_file: str = None):
"""
Create a new preprocessor.
:param wrapper: the wrapper for the language model to use
:param task_name: the name of the task
:param pattern_id: the id of the PVP to be used
:param verbalizer_file: path to a file containing a verbalizer that overrides the default verbalizer
"""
self.wrapper = wrapper
self.pvp = PVPS[task_name](self.wrapper, pattern_id, verbalizer_file) # type: PVP
self.label_map = {label: i for i, label in enumerate(self.wrapper.config.label_list)}
@abstractmethod
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> InputFeatures:
"""Convert the given example into a set of input features"""
pass
class MLMPreprocessor(Preprocessor):
"""Preprocessor for models pretrained using a masked language modeling objective (e.g., BERT)."""
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> InputFeatures:
if priming:
input_ids, token_type_ids = self.pvp.encode(example, priming=True)
priming_data = example.meta['priming_data'] # type: List[InputExample]
priming_input_ids = []
for priming_example in priming_data:
pe_input_ids, _ = self.pvp.encode(priming_example, priming=True, labeled=True)
priming_input_ids += pe_input_ids
input_ids = priming_input_ids + input_ids
token_type_ids = self.wrapper.tokenizer.create_token_type_ids_from_sequences(input_ids)
input_ids = self.wrapper.tokenizer.build_inputs_with_special_tokens(input_ids)
else:
input_ids, token_type_ids = self.pvp.encode(example)
attention_mask = [1] * len(input_ids)
padding_length = self.wrapper.config.max_seq_length - len(input_ids)
if padding_length < 0:
raise ValueError(f"Maximum sequence length is too small, got {len(input_ids)} input ids")
input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
token_type_ids = token_type_ids + ([0] * padding_length)
assert len(input_ids) == self.wrapper.config.max_seq_length
assert len(attention_mask) == self.wrapper.config.max_seq_length
assert len(token_type_ids) == self.wrapper.config.max_seq_length
label = self.label_map[example.label] if example.label is not None else -100
logits = example.logits if example.logits else [-1]
if labelled:
mlm_labels = self.pvp.get_mask_positions(input_ids)
if self.wrapper.config.model_type == 'gpt2':
# shift labels to the left by one
mlm_labels.append(mlm_labels.pop(0))
else:
mlm_labels = [-1] * self.wrapper.config.max_seq_length
return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx)
class PLMPreprocessor(MLMPreprocessor):
"""Preprocessor for models pretrained using a permuted language modeling objective (e.g., XLNet)."""
def get_input_features(self, example: InputExample, labelled: bool, priming: bool = False,
**kwargs) -> PLMInputFeatures:
input_features = super().get_input_features(example, labelled, priming, **kwargs)
input_ids = input_features.input_ids
num_masks = 1 # currently, PLMPreprocessor supports only replacements that require exactly one mask
perm_mask = np.zeros((len(input_ids), len(input_ids)), dtype=float)  # np.float is a deprecated alias for the builtin float
label_idx = input_ids.index(self.pvp.mask_id)
perm_mask[:, label_idx] = 1 # the masked token is not seen by any other token
target_mapping = np.zeros((num_masks, len(input_ids)), dtype=float)
target_mapping[0, label_idx] = 1.0
return PLMInputFeatures(perm_mask=perm_mask, target_mapping=target_mapping, **input_features.__dict__)
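# Shape sketch for the two masks built above (illustrative values: a sequence
# of length 4 whose single mask token sits at index 2):
#   perm_mask[:, 2] == 1  -> no position may attend to the masked token, so
#                            XLNet must predict it from the remaining context;
#   target_mapping == [[0., 0., 1., 0.]] -> exactly one prediction target (the
#                            masked position), matching num_masks == 1.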
class SequenceClassifierPreprocessor(Preprocessor):
"""Preprocessor for a regular sequence classification model."""
def get_input_features(self, example: InputExample, **kwargs) -> InputFeatures:
inputs = self.wrapper.task_helper.get_sequence_classifier_inputs(example) if self.wrapper.task_helper else None
if inputs is None:
inputs = self.wrapper.tokenizer.encode_plus(
example.text_a if example.text_a else None,
example.text_b if example.text_b else None,
add_special_tokens=True,
max_length=self.wrapper.config.max_seq_length,
)
input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids")
attention_mask = [1] * len(input_ids)
padding_length = self.wrapper.config.max_seq_length - len(input_ids)
input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
if not token_type_ids:
token_type_ids = [0] * self.wrapper.config.max_seq_length
else:
token_type_ids = token_type_ids + ([0] * padding_length)
mlm_labels = [-1] * len(input_ids)
assert len(input_ids) == self.wrapper.config.max_seq_length
assert len(attention_mask) == self.wrapper.config.max_seq_length
assert len(token_type_ids) == self.wrapper.config.max_seq_length
label = self.label_map[example.label] if example.label is not None else -100
logits = example.logits if example.logits else [-1]
return InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
label=label, mlm_labels=mlm_labels, logits=logits, idx=example.idx)
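# Standalone sketch of the pad-to-max_seq_length logic shared by the
# preprocessors above (illustrative values: pad_token_id = 0, max_seq_length = 8):
if __name__ == "__main__":
    max_seq_length, pad_token_id = 8, 0
    input_ids = [101, 2023, 2003, 102]      # an already-encoded sequence
    attention_mask = [1] * len(input_ids)   # real tokens get mask value 1
    padding_length = max_seq_length - len(input_ids)
    input_ids += [pad_token_id] * padding_length
    attention_mask += [0] * padding_length  # padded positions are masked out
    assert len(input_ids) == len(attention_mask) == max_seq_length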
home = /usr/local/bin
include-system-site-packages = false
version = 3.6.9
-f https://download.pytorch.org/whl/torch_stable.html
numpy==1.19
jsonpickle==1.1
scikit-learn==0.23.1
torch===1.5.0
torchvision==0.6.0
transformers==3.0.2
tqdm==4.48.1
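# Illustrative install command (assuming a Python 3.6 environment, per the
# pyvenv.cfg above): pip install -r requirements.txt
# The -f line adds PyTorch's wheel index so torch 1.5.0 can be resolved.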
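# Note on the run scripts below (summary added for readability): each wraps
# PET's cli.py; --method selects plain PET or its iterative variant (ipet),
# --pattern_ids lists the PVP pattern ids to train and ensemble, and
# --eval_set test evaluates on the test rather than the dev split.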
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet2 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet3 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet4 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 4 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_ipet5 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name autobest5 \
--output_dir autobest5_pet \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir ../data/SST-2 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern2 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir ../data/SST-2.1000 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern3 \
--do_train \
--do_eval \
--eval_set test && \
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task2 \
--output_dir result_handcrafted_pattern \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
python3 cli.py \
--method pet \
--pattern_ids 0 \
--data_dir /home/mist/projects/LM-BFF/data/k-shot/SST-2/16-100 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir results \
--do_train \
--do_eval
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 \
--data_dir /home/mist/projects/LM-BFF/data/k-shot/SST-2/16-100 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name lm-bff \
--output_dir results1 \
--do_train \
--do_eval
date
python3 cli.py \
--method pet \
--pattern_ids 0 1 2 3 4 5 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir result_300 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \
date
python3 cli.py \
--method ipet \
--pattern_ids 0 1 2 3 4 5 \
--data_dir SST-2.300 \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name my-task \
--output_dir result4_300 \
--do_train \
--do_eval
date
#--data_dir /data/projects/LM-BFF/data/k-shot/SST-2/16-100/ \