# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the pattern-verbalizer pairs (PVPs) for all tasks.
"""
import random
import string
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Tuple, List, Union, Dict

import torch
from transformers import PreTrainedTokenizer, GPT2Tokenizer

from pet.task_helpers import MultiMaskTaskHelper
from pet.tasks import TASK_HELPERS
from pet.utils import InputExample, get_verbalization_ids

import log
from pet import wrapper as wrp

logger = log.get_logger('root')

# A filled pattern is a pair of two token sequences (text_a, text_b); each element is
# either a plain string or a (string, shortenable) tuple marking truncatable segments.
FilledPattern = Tuple[List[Union[str, Tuple[str, bool]]], List[Union[str, Tuple[str, bool]]]]


class PVP(ABC):
    """
    This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its
    own custom implementation of a PVP.
    """

    def __init__(self, wrapper, pattern_id: int = 0, verbalizer_file: str = None, seed: int = 42):
        """
        Create a new PVP.

        :param wrapper: the wrapper for the underlying language model
        :param pattern_id: the pattern id to use
        :param verbalizer_file: an optional file that contains the verbalizer to be used
        :param seed: a seed to be used for generating random numbers if necessary
        """
        self.wrapper = wrapper
        self.pattern_id = pattern_id
        self.rng = random.Random(seed)

        if verbalizer_file:
            # Overwrite the instance's verbalize method with the one loaded from file.
            self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id)

        # Multi-mask task helpers handle verbalization themselves, so the single-token
        # logits-mapping tensor is only needed for single-mask MLM/PLM setups.
        use_multimask = (self.wrapper.config.task_name in TASK_HELPERS) and (
            issubclass(TASK_HELPERS[self.wrapper.config.task_name], MultiMaskTaskHelper)
        )
        if not use_multimask and self.wrapper.config.wrapper_type in [wrp.MLM_WRAPPER, wrp.PLM_WRAPPER]:
            self.mlm_logits_to_cls_logits_tensor = self._build_mlm_logits_to_cls_logits_tensor()

    def _build_mlm_logits_to_cls_logits_tensor(self):
        """Build a [num_labels x max_num_verbalizers] tensor that maps each label to the token ids
        of its verbalizers; unused slots are filled with -1."""
        label_list = self.wrapper.config.label_list
        m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1

        for label_idx, label in enumerate(label_list):
            verbalizers = self.verbalize(label)
            for verbalizer_idx, verbalizer in enumerate(verbalizers):
                verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
                assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, "verbalization was tokenized as <UNK>"
                m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id
        return m2c_tensor

    @property
    def mask(self) -> str:
        """Return the underlying LM's mask token"""
        return self.wrapper.tokenizer.mask_token

    @property
    def mask_id(self) -> int:
        """Return the underlying LM's mask id"""
        return self.wrapper.tokenizer.mask_token_id

    @property
    def max_num_verbalizers(self) -> int:
        """Return the maximum number of verbalizers across all labels"""
        return max(len(self.verbalize(label)) for label in self.wrapper.config.label_list)

    @staticmethod
    def shortenable(s):
        """Return an instance of this string that is marked as shortenable"""
        return s, True

    @staticmethod
    def remove_final_punc(s: Union[str, Tuple[str, bool]]):
        """Remove the final punctuation mark"""
        if isinstance(s, tuple):
            return PVP.remove_final_punc(s[0]), s[1]
        return s.rstrip(string.punctuation)

    @staticmethod
    def lowercase_first(s: Union[str, Tuple[str, bool]]):
        """Lowercase the first character"""
        if isinstance(s, tuple):
            return PVP.lowercase_first(s[0]), s[1]
        return s[0].lower() + s[1:]

    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \
            -> Tuple[List[int], List[int]]:
        """
        Encode an input example using this pattern-verbalizer pair.

        :param example: the input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the label should be appended to this example
        :return: A tuple, consisting of a list of input ids and a list of token type ids
        """

        if not priming:
            assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

        tokenizer = self.wrapper.tokenizer  # type: PreTrainedTokenizer
        parts_a, parts_b = self.get_parts(example)

        # GPT-2's BPE treats a leading space as part of the token, so segments must be
        # encoded with an explicit prefix space to match in-sentence tokenization.
        kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}

        parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
        parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]

        if parts_b:
            parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
            parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]

        self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)

        tokens_a = [token_id for part, _ in parts_a for token_id in part]
        tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None

        if priming:
            input_ids = tokens_a
            if tokens_b:
                input_ids += tokens_b
            if labeled:
                # Replace the mask with the (single) verbalizer token so the example
                # can serve as a labeled in-context demonstration.
                mask_idx = input_ids.index(self.mask_id)
                assert mask_idx >= 0, 'sequence of input_ids must contain a mask token'
                assert len(self.verbalize(example.label)) == 1, 'priming only supports one verbalization per label'
                verbalizer = self.verbalize(example.label)[0]
                verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer, force_single_token=True)
                input_ids[mask_idx] = verbalizer_id
            return input_ids, []

        input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
        token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)

        return input_ids, token_type_ids

    @staticmethod
    def _seq_length(parts: List[Tuple[str, bool]], only_shortenable: bool = False):
        """Return the total token count of ``parts`` (optionally only the shortenable segments)."""
        return sum([len(x) for x, shortenable in parts if not only_shortenable or shortenable]) if parts else 0

    @staticmethod
    def _remove_last(parts: List[Tuple[str, bool]]):
        """Drop one token from the end of the last non-empty shortenable segment (in place)."""
        last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq)
        parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1])

    def truncate(self, parts_a: List[Tuple[str, bool]], parts_b: List[Tuple[str, bool]], max_length: int):
        """Truncate two sequences of text to a predefined total maximum length"""
        total_len = self._seq_length(parts_a) + self._seq_length(parts_b)
        total_len += self.wrapper.tokenizer.num_special_tokens_to_add(bool(parts_b))
        num_tokens_to_remove = total_len - max_length

        if num_tokens_to_remove <= 0:
            return parts_a, parts_b

        # Remove one token at a time, always from the currently longer shortenable side.
        for _ in range(num_tokens_to_remove):
            if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True):
                self._remove_last(parts_a)
            else:
                self._remove_last(parts_b)

    @abstractmethod
    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one
        mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a
        single sequence of text, the second sequence should be an empty list.

        :param example: the input example to process
        :return: Two sequences of text. All text segments can optionally be marked as being shortenable.
        """
        pass

    @abstractmethod
    def verbalize(self, label) -> List[str]:
        """
        Return all verbalizations for a given label.

        :param label: the label
        :return: the list of verbalizations
        """
        pass

    def get_mask_positions(self, input_ids: List[int]) -> List[int]:
        """Return a list with 1 at the (first) mask position and -1 everywhere else."""
        label_idx = input_ids.index(self.mask_id)
        labels = [-1] * len(input_ids)
        labels[label_idx] = 1
        return labels

    def convert_mlm_logits_to_cls_logits(self, mlm_labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Map per-token MLM logits at the mask positions to per-label classification logits."""
        masked_logits = logits[mlm_labels >= 0]
        cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(ml) for ml in masked_logits])
        return cls_logits

    def _convert_single_mlm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Convert the vocabulary logits of a single mask position to one logit per label by
        averaging the logits of each label's verbalizer tokens."""
        m2c = self.mlm_logits_to_cls_logits_tensor.to(logits.device)
        # filler_len.shape() == max_fillers
        filler_len = torch.tensor([len(self.verbalize(label)) for label in self.wrapper.config.label_list],
                                  dtype=torch.float)
        filler_len = filler_len.to(logits.device)

        # cls_logits.shape() == num_labels x max_fillers (and 0 when there are not as many fillers).
        # -1 slots are clamped to index 0 before the gather and zeroed out afterwards.
        cls_logits = logits[torch.max(torch.zeros_like(m2c), m2c)]
        cls_logits = cls_logits * (m2c > 0).float()

        # cls_logits.shape() == num_labels
        cls_logits = cls_logits.sum(dim=1) / filler_len
        return cls_logits

    def convert_plm_logits_to_cls_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Map PLM logits (exactly one mask per example) to per-label classification logits."""
        assert logits.shape[1] == 1
        logits = torch.squeeze(logits, 1)  # remove second dimension as we always have exactly one <mask> per example
        cls_logits = torch.stack([self._convert_single_mlm_logits_to_cls_logits(lgt) for lgt in logits])
        return cls_logits

    @staticmethod
    def _load_verbalizer_from_file(path: str, pattern_id: int):
        """Load a verbalizer for ``pattern_id`` from a file where a bare integer line starts a new
        pattern section and each following line is ``<label> <verbalization> [<verbalization> ...]``."""

        verbalizers = defaultdict(dict)  # type: Dict[int, Dict[str, List[str]]]
        current_pattern_id = None

        with open(path, 'r') as fh:
            for line in fh.read().splitlines():
                if line.isdigit():
                    current_pattern_id = int(line)
                elif line:
                    label, *realizations = line.split()
                    verbalizers[current_pattern_id][label] = realizations

        logger.info("Automatically loaded the following verbalizer: \n {}".format(verbalizers[pattern_id]))

        def verbalize(label) -> List[str]:
            return verbalizers[pattern_id][label]

        return verbalize


class AgnewsPVP(PVP):
    VERBALIZER = {
        "1": ["World"],
        "2": ["Sports"],
        "3": ["Business"],
        "4": ["Tech"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b)

        if self.pattern_id == 0:
            return [self.mask, ':', text_a, text_b], []
        elif self.pattern_id == 1:
            return [self.mask, 'News:', text_a, text_b], []
        elif self.pattern_id == 2:
            return [text_a, '(', self.mask, ')', text_b], []
        elif self.pattern_id == 3:
            return [text_a, text_b, '(', self.mask, ')'], []
        elif self.pattern_id == 4:
            return ['[ Category:', self.mask, ']', text_a, text_b], []
        elif self.pattern_id == 5:
            return [self.mask, '-', text_a, text_b], []
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        return AgnewsPVP.VERBALIZER[label]


class YahooPVP(PVP):
    VERBALIZER = {
        "1": ["Society"],
        "2": ["Science"],
        "3": ["Health"],
        "4": ["Education"],
        "5": ["Computer"],
        "6": ["Sports"],
        "7": ["Business"],
        "8": ["Entertainment"],
        "9": ["Relationship"],
        "10": ["Politics"],
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b)

        if self.pattern_id == 0:
            return [self.mask, ':', text_a, text_b], []
        elif self.pattern_id == 1:
            return [self.mask, 'Question:', text_a, text_b], []
        elif self.pattern_id == 2:
            return [text_a, '(', self.mask, ')', text_b], []
        elif self.pattern_id == 3:
            return [text_a, text_b, '(', self.mask, ')'], []
        elif self.pattern_id == 4:
            return ['[ Category:', self.mask, ']', text_a, text_b], []
        elif self.pattern_id == 5:
            return [self.mask, '-', text_a, text_b], []
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        return YahooPVP.VERBALIZER[label]


class MnliPVP(PVP):
    VERBALIZER_A = {
        "contradiction": ["Wrong"],
        "entailment": ["Right"],
        "neutral": ["Maybe"]
    }
    VERBALIZER_B = {
        "contradiction": ["No"],
        "entailment": ["Yes"],
        "neutral": ["Maybe"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text_a = self.shortenable(self.remove_final_punc(example.text_a))
        text_b = self.shortenable(example.text_b)

        if self.pattern_id == 0 or self.pattern_id == 2:
            return ['"', text_a, '" ?'], [self.mask, ', "', text_b, '"']
        elif self.pattern_id == 1 or self.pattern_id == 3:
            return [text_a, '?'], [self.mask, ',', text_b]
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 0 or self.pattern_id == 1:
            return MnliPVP.VERBALIZER_A[label]
        return MnliPVP.VERBALIZER_B[label]


class YelpPolarityPVP(PVP):
    VERBALIZER = {
        "1": ["bad"],
        "2": ["good"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text = self.shortenable(example.text_a)

        if self.pattern_id == 0:
            return ['It was', self.mask, '.', text], []
        elif self.pattern_id == 1:
            return [text, '. All in all, it was', self.mask, '.'], []
        elif self.pattern_id == 2:
            return ['Just', self.mask, "!"], [text]
        elif self.pattern_id == 3:
            return [text], ['In summary, the restaurant is', self.mask, '.']
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        return YelpPolarityPVP.VERBALIZER[label]


class YelpFullPVP(YelpPolarityPVP):
    VERBALIZER = {
        "1": ["terrible"],
        "2": ["bad"],
        "3": ["okay"],
        "4": ["good"],
        "5": ["great"]
    }

    def verbalize(self, label) -> List[str]:
        return YelpFullPVP.VERBALIZER[label]


class XStancePVP(PVP):
    VERBALIZERS = {
        'en': {"FAVOR": ["Yes"], "AGAINST": ["No"]},
        'de': {"FAVOR": ["Ja"], "AGAINST": ["Nein"]},
        'fr': {"FAVOR": ["Oui"], "AGAINST": ["Non"]}
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b)

        if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4:
            return ['"', text_a, '"'], [self.mask, '. "', text_b, '"']
        elif self.pattern_id == 1 or self.pattern_id == 3 or self.pattern_id == 5:
            return [text_a], [self.mask, '.', text_b]
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        # Pattern ids 0/1 use German, 2/3 English, 4/5 French verbalizers.
        lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr'
        return XStancePVP.VERBALIZERS[lang][label]


class RtePVP(PVP):
    VERBALIZER = {
        "not_entailment": ["No"],
        "entailment": ["Yes"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        # switch text_a and text_b to get the correct order
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b.rstrip(string.punctuation))

        if self.pattern_id == 0:
            return ['"', text_b, '" ?'], [self.mask, ', "', text_a, '"']
        elif self.pattern_id == 1:
            return [text_b, '?'], [self.mask, ',', text_a]
        elif self.pattern_id == 2:
            return ['"', text_b, '" ?'], [self.mask, '. "', text_a, '"']
        elif self.pattern_id == 3:
            return [text_b, '?'], [self.mask, '.', text_a]
        elif self.pattern_id == 4:
            return [text_a, ' question: ', self.shortenable(example.text_b), ' True or False? answer:',
                    self.mask], []
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 4:
            return ['true'] if label == 'entailment' else ['false']
        return RtePVP.VERBALIZER[label]


class CbPVP(RtePVP):
    VERBALIZER = {
        "contradiction": ["No"],
        "entailment": ["Yes"],
        "neutral": ["Maybe"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        if self.pattern_id == 4:
            text_a = self.shortenable(example.text_a)
            text_b = self.shortenable(example.text_b)
            return [text_a, ' question: ', text_b, ' true, false or neither? answer:', self.mask], []
        return super().get_parts(example)

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 4:
            return ['true'] if label == 'entailment' else ['false'] if label == 'contradiction' else ['neither']
        return CbPVP.VERBALIZER[label]


class CopaPVP(PVP):

    def get_parts(self, example: InputExample) -> FilledPattern:
        premise = self.remove_final_punc(self.shortenable(example.text_a))
        choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
        choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))

        question = example.meta['question']
        assert question in ['cause', 'effect']

        example.meta['choice1'], example.meta['choice2'] = choice1, choice2
        # Use as many masks as the longer choice needs (verbalization is handled by a task helper).
        num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in [choice1, choice2])

        if question == 'cause':
            if self.pattern_id == 0:
                return ['"', choice1, '" or "', choice2, '"?', premise, 'because', self.mask * num_masks, '.'], []
            elif self.pattern_id == 1:
                return [choice1, 'or', choice2, '?', premise, 'because', self.mask * num_masks, '.'], []
            else:
                raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
        else:
            if self.pattern_id == 0:
                return ['"', choice1, '" or "', choice2, '"?', premise, ', so', self.mask * num_masks, '.'], []
            elif self.pattern_id == 1:
                return [choice1, 'or', choice2, '?', premise, ', so', self.mask * num_masks, '.'], []
            else:
                raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        # Verbalization is handled by the multi-mask task helper, not by a fixed verbalizer.
        return []


class WscPVP(PVP):

    def get_parts(self, example: InputExample) -> FilledPattern:
        pronoun = example.meta['span2_text']
        target = example.meta['span1_text']
        pronoun_idx = example.meta['span2_index']

        # Highlight the pronoun in the text with surrounding asterisks.
        words_a = example.text_a.split()
        words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
        text_a = ' '.join(words_a)
        text_a = self.shortenable(text_a)

        # During training, pad with a random number of extra masks as a form of augmentation.
        num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
        num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)) + num_pad
        masks = self.mask * num_masks

        if self.pattern_id == 0:
            return [text_a, "The pronoun '*" + pronoun + "*' refers to", masks + '.'], []
        elif self.pattern_id == 1:
            return [text_a, "In the previous sentence, the pronoun '*" + pronoun + "*' refers to", masks + '.'], []
        elif self.pattern_id == 2:
            return [text_a,
                    "Question: In the passage above, what does the pronoun '*" + pronoun + "*' refer to? Answer: ",
                    masks + '.'], []
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        # Verbalization is handled by the multi-mask task helper, not by a fixed verbalizer.
        return []


class BoolQPVP(PVP):
    VERBALIZER_A = {
        "False": ["No"],
        "True": ["Yes"]
    }
    VERBALIZER_B = {
        "False": ["false"],
        "True": ["true"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        passage = self.shortenable(example.text_a)
        question = self.shortenable(example.text_b)

        if self.pattern_id < 2:
            return [passage, '. Question: ', question, '? Answer: ', self.mask, '.'], []
        elif self.pattern_id < 4:
            return [passage, '. Based on the previous passage, ', question, '?', self.mask, '.'], []
        else:
            return ['Based on the following passage, ', question, '?', self.mask, '.', passage], []

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 0 or self.pattern_id == 2 or self.pattern_id == 4:
            return BoolQPVP.VERBALIZER_A[label]
        else:
            return BoolQPVP.VERBALIZER_B[label]


class MultiRcPVP(PVP):
    VERBALIZER = {
        "0": ["No"],
        "1": ["Yes"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        passage = self.shortenable(example.text_a)
        question = example.text_b
        answer = example.meta['answer']

        if self.pattern_id == 0:
            return [passage, '. Question: ', question, '? Is it ', answer, '?', self.mask, '.'], []
        if self.pattern_id == 1:
            return [passage, '. Question: ', question, '? Is the correct answer "', answer, '"?', self.mask, '.'], []
        if self.pattern_id == 2:
            return [passage, '. Based on the previous passage, ', question, '? Is "', answer, '" a correct answer?',
                    self.mask, '.'], []
        if self.pattern_id == 3:
            return [passage, question, '- [', self.mask, ']', answer], []
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 3:
            return ['False'] if label == "0" else ['True']
        return MultiRcPVP.VERBALIZER[label]


class WicPVP(PVP):
    VERBALIZER_A = {
        "F": ["No"],
        "T": ["Yes"]
    }
    VERBALIZER_B = {
        "F": ["2"],
        "T": ["b"]
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b)
        word = example.meta['word']

        if self.pattern_id == 0:
            return ['"', text_a, '" / "', text_b, '" Similar sense of "' + word + '"?', self.mask, '.'], []
        if self.pattern_id == 1:
            return [text_a, text_b, 'Does ' + word + ' have the same meaning in both sentences?', self.mask], []
        if self.pattern_id == 2:
            return [word, ' . Sense (1) (a) "', text_a, '" (', self.mask, ') "', text_b, '"'], []
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        if self.pattern_id == 2:
            return WicPVP.VERBALIZER_B[label]
        return WicPVP.VERBALIZER_A[label]


class RecordPVP(PVP):

    def get_parts(self, example: InputExample) -> FilledPattern:
        premise = self.shortenable(example.text_a)
        choices = example.meta['candidates']

        assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token'
        # Use as many masks as the longest candidate needs (verbalization is handled by a task helper).
        num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in choices)
        question = example.text_b.replace('@placeholder', self.mask * num_masks)

        return [premise, question], []

    def verbalize(self, label) -> List[str]:
        # Verbalization is handled by the multi-mask task helper, not by a fixed verbalizer.
        return []


PVPS = {
    'agnews': AgnewsPVP,
    'mnli': MnliPVP,
    'yelp-polarity': YelpPolarityPVP,
    'yelp-full': YelpFullPVP,
    'yahoo': YahooPVP,
    'xstance': XStancePVP,
    'xstance-de': XStancePVP,
    'xstance-fr': XStancePVP,
    'rte': RtePVP,
    'wic': WicPVP,
    'cb': CbPVP,
    'wsc': WscPVP,
    'boolq': BoolQPVP,
    'copa': CopaPVP,
    'multirc': MultiRcPVP,
    'record': RecordPVP,
    'ax-b': RtePVP,
    'ax-g': RtePVP,
}


class MyTaskPVP(PVP):
    """
    Example for a pattern-verbalizer pair (PVP).
    """

    # Set this to the name of the task
    TASK_NAME = "my-task"

    # Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
    # the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
    # VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
    # NOTE(review): keys are ints here — confirm they match the label type returned by the task's DataProcessor.
    VERBALIZER = {
        1: ["fascinating", 'effective', "irresistible", "thrilling"],  # "breathtaking",
        0: ["boring", "embarrassing", "weird", "nothing"],  # "depressing",
    }

    def get_parts(self, example: InputExample):
        text_a = self.shortenable(example.text_a)

        if self.pattern_id == 0:
            # *cls*_A*mask*_idea.*+sent_0**sep+* 93 {"0": "boring", "1": "fascinating"}
            return ['A', self.mask, 'idea.', text_a], []
        if self.pattern_id == 1:
            # *cls*_A*mask*_show.*+sent_0**sep+* 95 {"0": "embarrassing", "1": "effective"}
            return ['A', self.mask, 'show.', text_a], []
        if self.pattern_id == 2:
            # *cls*_The_story_is*mask*.*+sent_0**sep+* 81 {"0": "depressing", "1": "irresistible"}
            return ['The story is', self.mask, '.', text_a], []
        if self.pattern_id == 3:
            # *cls**sent_0*_The_result_is*mask*.*sep+* 49 {"0": "weird", "1": "breathtaking"}
            return [text_a, 'The result is', self.mask, '.'], []
        if self.pattern_id == 4:
            # *cls**sent_0*_Very*mask*!*sep+* 30 {"0": "embarrassing", "1": "thrilling"}
            return [text_a, 'Very', self.mask, '!'], []
        if self.pattern_id == 5:
            # *cls**sent_0*_This_is*mask*.*sep+* 7 {"0": "nothing", "1": "thrilling"}
            return [text_a, 'This is', self.mask, '.'], []
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        return MyTaskPVP.VERBALIZER[label]


class MyTaskPVP2(PVP):
    """
    Example for a pattern-verbalizer pair (PVP).
    """

    # Set this to the name of the task
    TASK_NAME = "my-task2"

    # Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
    # the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
    # VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
    VERBALIZER = {
        1: ["Good"],
        0: ["Bad"],
    }

    def get_parts(self, example: InputExample):
        text_a = self.shortenable(example.text_a)
        return [text_a, 'It is', self.mask, '.'], []

    def verbalize(self, label) -> List[str]:
        # Fixed: previously returned MyTaskPVP.VERBALIZER (copy-paste error), ignoring this class's own verbalizer.
        return MyTaskPVP2.VERBALIZER[label]


class AutoBest5(PVP):
    """
    Example for a pattern-verbalizer pair (PVP).
    """

    # Set this to the name of the task
    TASK_NAME = "autobest5"

    # Set this to the verbalizer for the given task: a mapping from the task's labels (which can be obtained using
    # the corresponding DataProcessor's get_labels method) to tokens from the language model's vocabulary
    # VERBALIZER = {"+1": ["Good"], "-1": ["Bad"]}
    VERBALIZER = {
        1: ["good"],
        0: ["bad"],
    }

    def get_parts(self, example: InputExample):
        text_a = self.shortenable(example.text_a)
        # Top-scoring automatically searched patterns:
        # 0.89967 *cls*_A*mask*_film!*+sent_0**sep+*
        # 0.89298 *cls**sent_0*_This_is_really*mask*.*sep+*
        # 0.89298 *cls*_It_is*mask*!*+sent_0**sep+*
        # 0.89298 *cls*_This_movie_is*mask*.*+sent_0**sep+*
        # 0.88963 *cls*_Just*mask*.*+sent_0**sep+*
        # 0.88963 *cls*_The_movie_is*mask*.*+sent_0**sep+*
        # 0.88963 *cls*_This_film_is*mask*.*+sent_0**sep+*
        # 0.88629 *cls**sent_0*_It's*mask*.*sep+*
        # 0.88629 *cls**sent_0*_The_movie_is*mask*.*sep+*
        # 0.88629 *cls**sent_0*_This_is_just*mask*.*sep+*
        if self.pattern_id == 0:
            return ['A', self.mask, 'film.', text_a], []
        if self.pattern_id == 1:
            return [text_a, 'This is really', self.mask, '.'], []
        if self.pattern_id == 2:
            return ['It is', self.mask, '!', text_a], []
        if self.pattern_id == 3:
            return ['The movie is', self.mask, '.', text_a], []
        if self.pattern_id == 4:
            return ['Just', self.mask, '.', text_a], []
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        # Fixed: previously returned MyTaskPVP.VERBALIZER (copy-paste error), ignoring this class's own verbalizer.
        return AutoBest5.VERBALIZER[label]


# register the PVP for this task with its name
PVPS[MyTaskPVP.TASK_NAME] = MyTaskPVP
PVPS[MyTaskPVP2.TASK_NAME] = MyTaskPVP2
PVPS[AutoBest5.TASK_NAME] = AutoBest5