# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
from abc import ABC
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from pet.utils import InputFeatures, InputExample, get_verbalization_ids, chunks, trim_input_ids, remove_final_punc, \
    lowercase_first


class TaskHelper(ABC):
    """
    A helper class that provides custom training and evaluation methods for tasks that do not fit into PET's default
    schema, for example because they require more than two sequences of text, different evaluation metrics, or
    verbalizers consisting of multiple tokens.
    """

    def __init__(self, wrapper):
        """
        Create a new task helper.

        :param wrapper: The wrapper for the language model being used.
        """
        self.wrapper = wrapper
        self.output = None

    def train_step(self, batch: Dict[str, torch.Tensor], **kwargs) -> Optional[torch.Tensor]:
        """
        Custom implementation of the train step for this task.

        :param batch: a batch of examples
        :return: a scalar loss tensor, or ``None`` if the default training behaviour should be used
        """
        pass

    def eval_step(self, batch: Dict[str, torch.Tensor], **kwargs) -> Optional[torch.Tensor]:
        """
        Custom implementation of the eval step for this task.

        :param batch: a batch of examples
        :return: a tensor of logits, or ``None`` if the default evaluation behaviour should be used
        """
        pass

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        """
        Add special features to the ``meta`` dictionary of a feature set.

        :param input_example: the input example considered
        :param input_features: the set of features corresponding to this example
        """

        pass

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        """
        Add special features from the ``meta`` dictionaries of a sequence of features to the corresponding feature dictionary.

        :param features: the sequence of features
        :param feature_dict: the dictionary that stores aggregated feature views as tensors
        """
        pass

    def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
        """
        Get the inputs for sequence classification. Override this method if the input for the task considered is of a
        more complicated form than `text_a` or `text_a [SEP] text_b`.

        :param example: the input example
        :return: the dictionary of inputs
        """
        pass


class MultiMaskTaskHelper(TaskHelper):
    """A custom task helper for classification datasets where multiple masks are required for one or more verbalizers."""

    def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for MultiMaskTaskHelper is only implemented for MLM models'
        inputs = self.wrapper.generate_default_inputs(batch)
        loss_fct = CrossEntropyLoss(reduction='none')

        # prediction_scores.shape == max_seq_len x vocab_size x batch_size
        prediction_scores = self.wrapper.model(**inputs)[0].permute(1, 2, 0)

        # all_choice_token_ids.shape == batch_size x num_choices x max_seq_len
        all_choice_token_ids = batch['choice_token_ids']
        batch_size, num_choices, max_seq_len = all_choice_token_ids.shape

        # all_labels.shape == batch_size
        all_labels = batch['labels']

        # correct_choice_token_ids.shape == max_seq_len x batch_size
        correct_choice_token_ids = all_choice_token_ids[torch.arange(batch_size), all_labels].permute(1, 0)

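        # build a mask that zeroes out each example's correct choice so that only the (num_choices - 1) wrong
        # choices remain as negatives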
        wrong_choices_mask = torch.ones_like(all_choice_token_ids)
        wrong_choices_mask.scatter_(1, all_labels.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, max_seq_len), 0)
        wrong_choices_token_ids = all_choice_token_ids[wrong_choices_mask.bool()].view(batch_size, num_choices - 1, max_seq_len)

        # wrong_choices_token_ids.shape == (num_choices - 1) x max_seq_len x batch_size
        wrong_choices_token_ids = wrong_choices_token_ids.permute(1, 2, 0)

        total_loss = 0

        # loss_correct_choice.shape == batch_size
        loss_correct_choice = loss_fct(prediction_scores, correct_choice_token_ids).sum(dim=0)

        # compute hinge loss
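        # the loss of the correct choice should be lower than that of every wrong choice by a margin of at least 1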
        for wrong_choice_token_ids in wrong_choices_token_ids:
            loss_wrong_choice = loss_fct(prediction_scores, wrong_choice_token_ids).sum(dim=0)
            hinge_loss = 1 + loss_correct_choice - loss_wrong_choice
            hinge_loss[hinge_loss < 0] = 0
            total_loss += hinge_loss

        return total_loss.mean()

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_size: int = 8, decoding_strategy: str = 'default'):
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for MultiMaskTaskHelper is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, "eval_step() for MultiMaskTaskHelper is only implemented for batch_size=1"

        all_choice_token_ids = batch['choice_token_ids'][0]
        log_probabilities = torch.tensor([[-math.inf] * len(all_choice_token_ids)], dtype=torch.float, device=all_choice_token_ids.device)

        # group choices by length to speed up decoding
        choices_grouped_by_length = defaultdict(list)

        for idx, choice_token_ids in enumerate(all_choice_token_ids):
            num_masks = sum(1 for x in choice_token_ids if x != -100)
            choices_grouped_by_length[num_masks].append((idx, choice_token_ids))

        input_ids = {}
        initial_outputs = {}

        for num_masks in choices_grouped_by_length.keys():
            # modify the input ids to contain the correct number of masks
            input_ids[num_masks] = trim_input_ids(batch['input_ids'], num_masks=num_masks,
                                                  pad_token_id=self.wrapper.tokenizer.pad_token_id,
                                                  mask_token_id=self.wrapper.tokenizer.mask_token_id)

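            # the first decoding step is identical for all choices with this number of masks, so the corresponding
            # model output is computed once here and reused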
            initial_outputs[num_masks] = self.wrapper.model(input_ids[num_masks])

        for num_masks, choices_with_labels in choices_grouped_by_length.items():

            for chunk in chunks(choices_with_labels, batch_size):
                batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
                choice_token_ids = torch.stack([choice_token_ids for idx, choice_token_ids in chunk])

                batch_probabilities = self._get_choice_probabilities_batched(choice_token_ids, batch_input_ids, initial_outputs[num_masks],
                                                                             decoding_strategy=decoding_strategy)

                for batch_idx, (idx, choice_token_ids) in enumerate(chunk):
                    log_probabilities[0][idx] = batch_probabilities[batch_idx]

        return log_probabilities

    def _get_choice_probabilities_batched(self, target_sequences, input_ids, initial_output, decoding_strategy):

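        # decoding strategies: 'parallel' scores all mask positions from a single forward pass, 'ltr' fills them
        # left to right (one forward pass per mask), and the default strategy fills the most confident position first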
        log_probabilities = defaultdict(list)
        first_call = True

        while True:
            masks = {batch_idx: [(idx, tok) for idx, tok in enumerate(target_sequences[batch_idx]) if tok >= 0] for
                     batch_idx in range(len(target_sequences))}

            if not masks[0]:  # there are no masks left to process, we are done
                break

            if first_call:
                outputs = initial_output
            else:
                outputs = self.wrapper.model(input_ids)

            next_token_logits = outputs[0]
            next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)

            if decoding_strategy == 'ltr':
                masks = {batch_idx: [batch_masks[0]] for batch_idx, batch_masks in masks.items()}

            for batch_idx in range(len(target_sequences)):

                ntl = next_token_logits[batch_idx] if not first_call else next_token_logits[0]

                if decoding_strategy == 'parallel':
                    for m_pos, m_id in masks[batch_idx]:
                        log_probabilities[batch_idx].append(math.log(ntl[m_pos][m_id].item()))
                        target_sequences[batch_idx][m_pos] = -100

                else:
                    mask_pos, masked_id = None, None
                    highest_prob = None
                    for m_pos, m_id in masks[batch_idx]:
                        m_prob = ntl[m_pos][m_id]
                        if highest_prob is None or m_prob > highest_prob:
                            highest_prob = m_prob
                            mask_pos, masked_id = m_pos, m_id

                    log_probabilities[batch_idx].append(math.log(ntl[mask_pos][masked_id].item()))
                    input_ids[batch_idx][mask_pos] = masked_id
                    target_sequences[batch_idx][mask_pos] = -100

            first_call = False

        return {batch_idx: sum(log_prob for log_prob in log_probabilities[batch_idx]) for batch_idx in
                range(len(target_sequences))}

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

        if 'choices' in input_example.meta:
            choices = [choice for choice in input_example.meta['choices']]
        else:
            label_list = self.wrapper.config.label_list
            choices = [self.wrapper.preprocessor.pvp.verbalize(label)[0] for label in label_list]

        input_features.meta['choice_token_ids'] = []

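        # for each choice, align its token ids with the mask positions in the input; all other positions are
        # ignored (-100) when computing the loss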
        for idx, choice_text in enumerate(choices):
            choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
            mask_end = mask_start + len(choice_token_ids)
            candidate_token_ids = [-100] * len(input_features.input_ids)
            candidate_token_ids[mask_start:mask_end] = choice_token_ids
            input_features.meta['choice_token_ids'].append(candidate_token_ids)

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return
        
        max_num_choices = max(len(f.meta['choice_token_ids']) for f in features)
        for feature in features:
            if len(feature.meta['choice_token_ids']) != max_num_choices:
                raise ValueError(f"The number of output choices must be identical for all examples, got "
                                 f"{len(feature.meta['choice_token_ids'])} and {max_num_choices}")

        feature_dict['choice_token_ids'] = torch.tensor([f.meta['choice_token_ids'] for f in features], dtype=torch.long)


class WicTaskHelper(TaskHelper):
    """A custom task helper for the WiC dataset."""

    def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
        text_a = example.meta['word'] + ': ' + example.text_a
        return self.wrapper.tokenizer.encode_plus(text_a, example.text_b, add_special_tokens=True,
                                                  max_length=self.wrapper.config.max_seq_length)


class MultiRcTaskHelper(TaskHelper):
    """A custom task helper for the MultiRC dataset."""

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        input_features.meta['question_idx'] = input_example.meta['question_idx']

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        feature_dict['question_idx'] = torch.tensor([f.meta['question_idx'] for f in features], dtype=torch.long)

    def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
        text_a = example.text_a
        text_b = ' '.join([example.text_b, self.wrapper.tokenizer.sep_token, example.meta['answer']])

        return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
                                                  max_length=self.wrapper.config.max_seq_length)


class CopaTaskHelper(TaskHelper):
    """A custom task helper for the COPA dataset."""

    def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
        premise = remove_final_punc(example.text_a)
        choice1, choice2 = lowercase_first(example.meta['choice1']), lowercase_first(example.meta['choice2'])
        question = example.meta['question']
        joiner = 'because' if question == 'cause' else 'so'
        text_a, text_b = ' '.join([premise, joiner, choice1]), ' '.join([premise, joiner, choice2])
        return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
                                                  max_length=self.wrapper.config.max_seq_length)

    def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for COPA is only implemented for MLM models'

        inputs = self.wrapper.generate_default_inputs(batch)
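        # a label of 0 selects choice1 as the correct target and a label of 1 selects choice2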
        mask = batch['labels'].unsqueeze(1)
        correct_targets = batch['choice1_token_ids'] * (1 - mask) + batch['choice2_token_ids'] * mask
        wrong_targets = batch['choice1_token_ids'] * mask + batch['choice2_token_ids'] * (1 - mask)

        prediction_scores = self.wrapper.model(**inputs)[0].view(-1, self.wrapper.model.config.vocab_size)
        loss_fct = CrossEntropyLoss()

        loss_correct_label = loss_fct(prediction_scores, correct_targets.view(-1))
        loss_wrong_label = loss_fct(prediction_scores, wrong_targets.view(-1))
        loss = 1 + loss_correct_label - loss_wrong_label
        loss[loss < 0] = 0
        return loss

    def eval_step(self, batch: Dict[str, torch.Tensor], decoding_strategy: str = 'default', **kwargs):
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for COPA is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, 'eval_step() for COPA is only implemented for batch_size=1'

        log_probs = []
        for choice in ['choice1', 'choice2']:
            labels = batch[f'{choice}_token_ids']
            log_prob = self._get_choice_log_probability(batch, labels, decoding_strategy=decoding_strategy)
            log_probs.append(log_prob)

        return torch.tensor([log_probs])

    def _get_choice_log_probability(self, batch, target_sequence, decoding_strategy: str = 'default'):
        # adjust the number of masks
        num_masks = sum(1 for tok_id in target_sequence[0] if tok_id != -100)
        input_ids = trim_input_ids(batch['input_ids'], num_masks=num_masks,
                                   pad_token_id=self.wrapper.tokenizer.pad_token_id,
                                   mask_token_id=self.wrapper.tokenizer.mask_token_id)

        log_probabilities = []

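        # fill one mask position per forward pass (the most confident one by default, the leftmost one for 'ltr');
        # for 'parallel', all positions are scored with a single forward pass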
        while True:
            masks = [(idx, tok_id) for idx, tok_id in enumerate(target_sequence[0]) if tok_id != -100]
            if not masks:  # there are no masks left to process, we are done
                break

            outputs = self.wrapper.model(input_ids)
            next_token_logits = torch.nn.Softmax(dim=2)(outputs[0])[0]

            if decoding_strategy == 'ltr':
                mask_pos, masked_id = masks[0]
                max_prob = next_token_logits[mask_pos][masked_id].item()
            elif decoding_strategy == 'parallel':
                for m_pos, m_id in masks:
                    log_probabilities.append(math.log(next_token_logits[m_pos][m_id].item()))
                break
            else:
                mask_pos, masked_id = None, None
                max_prob = None
                for m_pos, m_id in masks:
                    m_prob = next_token_logits[m_pos][m_id].item()
                    if max_prob is None or m_prob > max_prob:
                        max_prob = m_prob
                        mask_pos, masked_id = m_pos, m_id

            log_probabilities.append(math.log(max_prob))
            input_ids[0][mask_pos] = masked_id
            target_sequence[0][mask_pos] = -100

        return sum(log_probabilities)

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

        for choice in ['choice1', 'choice2']:
            choice_text = input_example.meta[choice]
            choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
            mask_end = mask_start + len(choice_token_ids)
            input_features.meta[f'{choice}_token_ids'] = [-100] * len(input_features.input_ids)
            input_features.meta[f'{choice}_token_ids'][mask_start:mask_end] = choice_token_ids

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        for choice in ['choice1', 'choice2']:
            feature_dict[f'{choice}_token_ids'] = torch.tensor(
                [f.meta[f'{choice}_token_ids'] for f in features], dtype=torch.long)


class WscTaskHelper(TaskHelper):
    """A custom task helper for the Wsc dataset."""

    def __init__(self, wrapper):
        super().__init__(wrapper)
        self.id_to_target = []

    def get_sequence_classifier_inputs(self, example: InputExample) -> Dict[str, Any]:
        target = example.meta['span1_text']
        pronoun_idx = example.meta['span2_index']

        # mark the pronoun with asterisks
        words_a = example.text_a.split()
        words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
        text_a = ' '.join(words_a)
        text_b = target

        return self.wrapper.tokenizer.encode_plus(text_a, text_b, add_special_tokens=True,
                                                  max_length=self.wrapper.config.max_seq_length)

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
        num_masks = input_features.input_ids.count(self.wrapper.tokenizer.mask_token_id)
        mask_end = mask_start + num_masks

        target = input_example.meta['span1_text']
        input_features.meta['target'] = target
        target_token_ids = get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)
        input_features.meta['target_token_ids'] = [-100] * len(input_features.input_ids)

        # we also predict <pad> tokens at the missing positions
        target_token_ids += [self.wrapper.tokenizer.pad_token_id] * (num_masks - len(target_token_ids))
        input_features.meta['target_token_ids'][mask_start:mask_end] = target_token_ids

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        feature_dict['target_id'] = torch.tensor([len(self.id_to_target) + idx for idx, f in enumerate(features)],
                                                 dtype=torch.long)
        self.id_to_target += [f.meta['target'] for f in features]
        feature_dict['target_token_ids'] = torch.tensor([f.meta['target_token_ids'] for f in features],
                                                        dtype=torch.long)

    def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for WSC is only implemented for MLM models'
        inputs = self.wrapper.generate_default_inputs(batch)
        inputs['labels'] = batch['target_token_ids']
        loss = self.wrapper.model(**inputs)[0]
        return loss

    def eval_step(self, batch: Dict[str, torch.Tensor], decoding_strategy: str = 'default', **kwargs):
        if self.wrapper.config.wrapper_type in ['sequence_classifier', 'span_pair_classifier']:
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for WSC is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, 'eval_step() for WSC is only implemented for batch_size=1'

        inputs = self.wrapper.generate_default_inputs(batch)
        input_ids = inputs['input_ids']

        orig_mask_positions = [
            idx for idx, input_id in enumerate(input_ids[0]) if input_id == self.wrapper.tokenizer.mask_token_id
        ]

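        # iteratively fill the mask positions with the model's most confident predictions; once no masks are left,
        # decode the filled-in tokens and compare them with the target span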
        while True:
            mask_positions = [
                idx for idx, input_id in enumerate(input_ids[0]) if input_id == self.wrapper.tokenizer.mask_token_id
            ]
            if not mask_positions:  # there are no masks left to process, we are done
                input_ids = input_ids[0].detach().cpu().tolist()
                output_actual = self.wrapper.tokenizer.decode([
                    input_id for idx, input_id in enumerate(input_ids)
                    if idx in orig_mask_positions and input_id not in self.wrapper.tokenizer.all_special_ids
                ])

                output_expected = self.id_to_target[batch["target_id"][0].item()]

                # transform both outputs as described in the T5 paper
                output_actual = output_actual.lower().strip()
                output_actual = [w for w in re.split('[^a-zA-Z]', output_actual) if w]
                output_expected = output_expected.lower().strip()
                output_expected = [w for w in re.split('[^a-zA-Z]', output_expected) if w]

                # compare outputs
                if all(x in output_expected for x in output_actual) or all(
                        x in output_actual for x in output_expected):
                    return torch.tensor([[0, 1]])
                return torch.tensor([[1, 0]])

            outputs = self.wrapper.model(**inputs)
            next_token_logits = outputs[0]
            next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)
            next_token_logits = next_token_logits[0].detach().cpu().numpy()

            most_confident = ()
            most_confident_score = -1

            if decoding_strategy == 'ltr':
                mask_positions = [mask_positions[0]]

            for mask_position in mask_positions:
                ntl = next_token_logits[mask_position]
                top_token_id = np.argmax(ntl)
                top_score = ntl[top_token_id]

                if decoding_strategy == 'parallel':
                    input_ids[0][mask_position] = top_token_id

                elif top_score > most_confident_score:
                    most_confident_score = top_score
                    most_confident = (mask_position, top_token_id)

            if decoding_strategy == 'parallel':
                continue

            input_ids[0][most_confident[0]] = most_confident[1]


class RecordTaskHelper(TaskHelper):
    """A custom task helper for the ReCoRD dataset."""

    def __init__(self, wrapper):
        super().__init__(wrapper)
        self.output = []
        self.original_choices = {}

    def train_step(self, batch, **kwargs) -> Optional[torch.Tensor]:
        assert self.wrapper.config.wrapper_type == 'mlm', 'train_step() for ReCoRD is only implemented for MLM models'
        inputs = self.wrapper.generate_default_inputs(batch)

        prediction_scores = self.wrapper.model(**inputs)[0].view(-1, self.wrapper.model.config.vocab_size)
        loss_fct = CrossEntropyLoss()

        # all_candidate_token_ids.shape == batch_size x max_num_candidates x max_seq_len
        all_candidate_token_ids = batch['candidate_token_ids']

        # all_candidate_labels.shape == batch_size x max_num_candidates
        all_candidate_labels = batch['candidate_labels']

        all_candidate_token_ids = all_candidate_token_ids.permute(1, 0, 2)
        all_candidate_labels = all_candidate_labels.permute(1, 0)

        total_loss = 0
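        # the first candidate (index 0) is treated as the correct answer; all remaining candidates are negatives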
        loss_correct_label = loss_fct(prediction_scores, all_candidate_token_ids[0].view(-1))

        # compute hinge loss
        for candidate_token_ids, candidate_labels in zip(all_candidate_token_ids[1:], all_candidate_labels[1:]):
            loss_wrong_label = loss_fct(prediction_scores, candidate_token_ids.view(-1))
            hinge_loss = 1 + loss_correct_label - loss_wrong_label
            hinge_loss[hinge_loss < 0] = 0
            total_loss += hinge_loss

        return total_loss

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_size: int = 8, decoding_strategy: str = 'default'):
        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for ReCoRD is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, "eval_step() for ReCoRD is only implemented for batch_size=1"

        best_choice_correct, best_choice, max_prob = False, None, None
        question_idx = batch['question_idx'][0].item()
        output_line = {'idx': question_idx, 'choices': {}}

        # group choices by length to speed up decoding
        choices_grouped_by_length = defaultdict(list)

        for idx, (choice_ids, label) in enumerate(zip(batch['candidate_token_ids'][0], batch['candidate_labels'][0])):
            if label < 0:
                continue
            num_masks = sum(1 for x in choice_ids if x != -100)
            choice = self.original_choices[question_idx][idx]
            choices_grouped_by_length[num_masks].append((choice, choice_ids, label))

        input_ids = {}
        initial_outputs = {}

        for num_masks in choices_grouped_by_length.keys():
            # modify the input ids to contain the correct number of masks
            input_ids[num_masks] = trim_input_ids(batch['input_ids'], num_masks=num_masks,
                                                  pad_token_id=self.wrapper.tokenizer.pad_token_id,
                                                  mask_token_id=self.wrapper.tokenizer.mask_token_id)

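            # the first decoding step is identical for all choices with this number of masks, so compute it once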
            initial_outputs[num_masks] = self.wrapper.model(input_ids[num_masks])

        for num_masks, choices_with_labels in choices_grouped_by_length.items():

            for chunk in chunks(choices_with_labels, batch_size):
                batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
                choice_ids = torch.stack([choice_id for choice, choice_id, label in chunk])

                probs = self._get_choice_probabilities_batched(choice_ids, batch_input_ids, initial_outputs[num_masks],
                                                               decoding_strategy=decoding_strategy)

                for idx, (choice, choice_ids, label) in enumerate(chunk):
                    prob = probs[idx]
                    output_line['choices'][choice] = prob

                    if max_prob is None or prob > max_prob:
                        best_choice_correct, max_prob = (label == 1), prob

        self.output.append(output_line)

        if best_choice_correct:
            return torch.tensor([[0, 1]])
        return torch.tensor([[1, 0]])

    def _get_choice_probabilities_batched(self, target_sequences, input_ids, initial_output, decoding_strategy):

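        # same decoding logic as MultiMaskTaskHelper._get_choice_probabilities_batched: 'parallel' scores all mask
        # positions from a single forward pass, 'ltr' fills them left to right, and the default strategy fills the
        # most confident position first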
        log_probabilities = defaultdict(list)
        first_call = True

        while True:
            masks = {batch_idx: [(idx, tok) for idx, tok in enumerate(target_sequences[batch_idx]) if tok >= 0] for
                     batch_idx in range(len(target_sequences))}

            if not masks[0]:  # there are no masks left to process, we are done
                break

            if first_call:
                outputs = initial_output
            else:
                outputs = self.wrapper.model(input_ids)

            next_token_logits = outputs[0]
            next_token_logits = torch.nn.Softmax(dim=2)(next_token_logits)

            if decoding_strategy == 'ltr':
                masks = {batch_idx: [batch_masks[0]] for batch_idx, batch_masks in masks.items()}

            for batch_idx in range(len(target_sequences)):

                ntl = next_token_logits[batch_idx] if not first_call else next_token_logits[0]

                if decoding_strategy == 'parallel':
                    for m_pos, m_id in masks[batch_idx]:
                        log_probabilities[batch_idx].append(math.log(ntl[m_pos][m_id].item()))
                        target_sequences[batch_idx][m_pos] = -100

                else:
                    mask_pos, masked_id = None, None
                    highest_prob = None
                    for m_pos, m_id in masks[batch_idx]:
                        m_prob = ntl[m_pos][m_id]
                        if highest_prob is None or m_prob > highest_prob:
                            highest_prob = m_prob
                            mask_pos, masked_id = m_pos, m_id

                    log_probabilities[batch_idx].append(math.log(ntl[mask_pos][masked_id].item()))
                    input_ids[batch_idx][mask_pos] = masked_id
                    target_sequences[batch_idx][mask_pos] = -100

            first_call = False

        return {batch_idx: sum(log_prob for log_prob in log_probabilities[batch_idx]) for batch_idx in
                range(len(target_sequences))}

    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

        choices = input_example.meta['candidates']
        question_idx = input_example.meta['question_idx']

        input_features.meta['candidate_token_ids'] = []
        input_features.meta['candidate_labels'] = []
        input_features.meta['question_idx'] = question_idx

        self.original_choices[question_idx] = []

        for idx, choice_text in enumerate(choices):
            choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer, force_single_token=False)
            choice_label = 1 if choice_text in input_example.meta['answers'] else 0

            mask_end = mask_start + len(choice_token_ids)
            candidate_token_ids = [-100] * len(input_features.input_ids)
            candidate_token_ids[mask_start:mask_end] = choice_token_ids

            input_features.meta['candidate_token_ids'].append(candidate_token_ids)
            input_features.meta['candidate_labels'].append(choice_label)
            self.original_choices[question_idx].append(choice_text)

    def add_features_to_dict(self, features: List[InputFeatures], feature_dict: Dict[str, torch.Tensor]) -> None:
        # apply padding if necessary
        max_num_candidates = max(len(f.meta['candidate_token_ids']) for f in features)
        for feature in features:
            while len(feature.meta['candidate_token_ids']) < max_num_candidates:
                feature.meta['candidate_token_ids'].append([-100] * len(feature.input_ids))
                feature.meta['candidate_labels'].append(-100)

        feature_dict['candidate_token_ids'] = \
            torch.tensor([f.meta['candidate_token_ids'] for f in features], dtype=torch.long)
        feature_dict['candidate_labels'] = \
            torch.tensor([f.meta['candidate_labels'] for f in features], dtype=torch.long)

        feature_dict['question_idx'] = torch.tensor([f.meta['question_idx'] for f in features], dtype=torch.long)