
Using HuggingFace Pretrained Language Models for Downstream Tasks

by Qiao for NLP7 2020-8-16

Two ways to use a pretrained language model:

  1. As a feature extractor
  2. As an encoder fine-tuned together with the downstream task

The two usages are very similar in practice; the difference is that in the second case the parameters of the original pretrained model are also updated during training. A minimal sketch of the switch between the two appears below.
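The following sketch is not from the original notebook; it assumes a `model` loaded as in the cells below:

# Usage 1 - feature extractor: freeze the pretrained weights so the
# optimizer only updates the downstream head.
for param in model.parameters():
    param.requires_grad = False

# Usage 2 - fine-tuning: leave the pretrained weights trainable
# (requires_grad=True is PyTorch's default).
for param in model.parameters():
    param.requires_grad = True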

Main contents:

  1. Introduce the usage pattern of the HuggingFace transformers library, with XLNet as the running example
  2. Show, again with XLNet, how to attach downstream text classification and extractive question answering.

Main references: the official documentation and example code.

XLNet serves as the example throughout; the workflow for the other pretrained language models wrapped by HuggingFace is similar.

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install transformers
from transformers import XLNetModel, XLNetTokenizer, XLNetConfig
Out [1]:
Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (3.0.2)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: tokenizers==0.8.1.rc1 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.8.1rc1)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.43)
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: sentencepiece!=0.1.92 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.1.91)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.15.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
In [2]:
# Quick sanity check of torch.transpose; the same dim-swap pattern is reused in beam-search decoding later
torch.transpose(torch.ones((3,2,4)), 2, 1).shape
Out [2]:
torch.Size([3, 4, 2])
In [3]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetModel.from_pretrained('xlnet-base-cased', 
                                   output_hidden_states=True,
                                   output_attentions=True)

Note:

When the modules above are initialized by model name, the library downloads the pretrained XLNet weights in the background and loads them. Readers in mainland China can, besides changing how they access the network, also download the model manually and load it from a local path.

Manual download:

On the official HuggingFace model hub, find the model you need and click its link, e.g. the xlnet-base-cased model. On the model page, click "List all files in model" (the text is small, look carefully) and save the listed model files (PyTorch or TF version) from the pop-up to local disk. (screenshot omitted)

In [4]:
# # Load the XLNet model from a local path
# MODEL_PATH = r"D:\data\nlp\xlnet-model/"
# config = XLNetConfig.from_json_file(os.path.join(MODEL_PATH, "xlnet-base-cased-config.json"))

# # The config file not only sets model hyperparameters, it can also control the model's behavior
# config.output_hidden_states = True
# config.output_attentions = True

# tokenizer = XLNetTokenizer(os.path.join(MODEL_PATH, 'xlnet-base-cased-spiece.model'))
# model = XLNetModel.from_pretrained(MODEL_PATH, config=config)

1. Converting sentences to token ids

In [5]:
# Use the tokenizer to turn a raw sentence into model inputs
sentence = "This is an interesting review session"

# tokenization
tokens = tokenizer.tokenize(sentence)
print("Tokens: {}".format(tokens))

# Convert tokens to ids
tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
print("Tokens id: {}".format(tokens_ids))

# Add the special tokens <sep> and <cls> (XLNet appends them at the end)
tokens_ids = tokenizer.build_inputs_with_special_tokens(tokens_ids)

# Wrap as a PyTorch tensor
tokens_pt = torch.tensor([tokens_ids])
print("Tokens PyTorch: {}".format(tokens_pt))

# print(tokenizer.convert_ids_to_tokens([122,   27,   48, 5272,  717,    4,    3]))
Out [5]:
Tokens: ['▁This', '▁is', '▁an', '▁interesting', '▁review', '▁session']
Tokens id: [122, 27, 48, 2456, 1398, 1961]
Tokens PyTorch: tensor([[ 122,   27,   48, 2456, 1398, 1961,    4,    3]])
In [6]:
# The lazy one-stop call: tokenization, special tokens and tensor conversion in one go
tokens_pt2 = tokenizer(sentence, return_tensors="pt")
print("Tokens PyTorch: {}".format(tokens_pt2))
Out [6]:
Tokens PyTorch: {'input_ids': tensor([[ 122,   27,   48, 2456, 1398, 1961,    4,    3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
In [7]:
# Batch processing with padding
# (note in the output below: XLNet pads on the left, prepending pad id 5)
sentences = ["The ultimate answer to life, universe and time is 42.", "Take a towel for a space travel."]
print("Batch tokenization:\n", tokenizer(sentences)['input_ids'])
print("With Padding:\n", tokenizer(sentences, padding=True)['input_ids'])
Out [7]:
Batch tokenization:
 [[32, 6452, 1543, 22, 235, 19, 6486, 21, 92, 27, 4087, 9, 4, 3], [3636, 24, 14680, 28, 24, 888, 1316, 9, 4, 3]]
With Padding:
 [[32, 6452, 1543, 22, 235, 19, 6486, 21, 92, 27, 4087, 9, 4, 3], [5, 5, 5, 5, 3636, 24, 14680, 28, 24, 888, 1316, 9, 4, 3]]
In [8]:
# Sentence-pair input (XLNet assigns segment ids 0/1 to the two segments and 2 to the trailing <cls>):
multi_seg_input = tokenizer("This is segment A", "This is segment B")
print("Multi segment token (str): {}".format(tokenizer.convert_ids_to_tokens(multi_seg_input['input_ids'])))
print("Multi segment token (int): {}".format(multi_seg_input['input_ids']))
print("Multi segment type       : {}".format(multi_seg_input['token_type_ids']))
Out [8]:
Multi segment token (str): ['▁This', '▁is', '▁segment', '▁A', '<sep>', '▁This', '▁is', '▁segment', '▁B', '<sep>', '<cls>']
Multi segment token (int): [122, 27, 7295, 79, 4, 122, 27, 7295, 322, 4, 3]
Multi segment type       : [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2]

2. Encoding with the model

In [9]:
# By default the model is in eval mode (model.training is False). Below we use the model to encode the input sentence.
# Because the config asks for every layer's hidden states and attentions, in addition to the
# final-layer hidden state returned by default, the output has 3 parts
print("Is training mode ? ", model.training)

sentence = "The ultimate answer to life, universe and time is 42."

tokens_pt = tokenizer(sentence, return_tensors="pt")
print("Token (str): {}".format(
    tokenizer.convert_ids_to_tokens(tokens_pt['input_ids'][0])
    ))

final_layer_h, all_layer_h, attentions = model(**tokens_pt)

print(torch.sum(final_layer_h - all_layer_h[-1]).item())

final_layer_h.shape, len(all_layer_h), len(attentions)
Out [9]:
Is training mode ?  False
Token (str): ['▁The', '▁ultimate', '▁answer', '▁to', '▁life', ',', '▁universe', '▁and', '▁time', '▁is', '▁42', '.', '<sep>', '<cls>']
0.0
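For xlnet-base-cased, final_layer_h has shape (batch_size, seq_len, 768), here (1, 14, 768) for the 14 tokens printed above. all_layer_h should hold one tensor per transformer layer plus the embedding output (13 for the 12-layer base model), and attentions one tensor per layer (12). The 0.0 printed above confirms that all_layer_h[-1] is identical to final_layer_h.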

3. Downstream tasks

Example 1. Text classification

In [10]:
class XLNetSeqSummary(nn.Module):
    
    def __init__(self, 
                 how='cls', 
                 hidden_size=768, 
                 activation=None, 
                 first_dropout=None, 
                 last_dropout=None):
        super().__init__()
        self.how = how
        self.summary = nn.Linear(hidden_size, hidden_size)
        self.activation = activation if activation else nn.GELU()
        self.first_dropout = first_dropout if first_dropout else nn.Dropout(0.5)
        self.last_dropout = last_dropout if last_dropout else nn.Dropout(0.5)

    def forward(self, hidden_states):
        """
        Pool the hidden-state sequence, or take the representation at <cls>,
        to use as the sentence encoding.
        Args:
            hidden_states:
                The final-layer hidden-state sequence output by the XLNet model.
        Returns:
            The sentence vector representation.
        """
        if self.how == "cls":
            output = hidden_states[:, -1]  # XLNet places <cls> at the end of the sequence
        elif self.how == "mean":
            output = hidden_states.mean(dim=1)
        elif self.how == "max":
            output = hidden_states.max(dim=1).values  # .max(dim=...) returns (values, indices)
        else:
            raise Exception("Summary type '{}' not implemented.".format(self.how))

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output


class XLNetSentenceClassifier(nn.Module):
    
    def __init__(self,
                 num_labels,
                 xlnet_model,
                 d_model=768):
        super().__init__()
        self.num_labels = num_labels
        self.d_model = d_model
        self.transformer = xlnet_model
        self.sequence_summary = XLNetSeqSummary('cls', d_model, nn.GELU())
        self.logits_proj = nn.Linear(d_model, num_labels)
        
    def forward(self, model_inputs):
        transformer_outputs = self.transformer(**model_inputs)
            
        output = transformer_outputs[0]
        output = self.sequence_summary(output)
        logits = self.logits_proj(output)

        return logits
    
def get_loss(criterion, logits, labels):
    return criterion(logits, labels)

In [11]:
# Verify the forward pass and backpropagation

# toy examples
sentences = ["The ultimate answer to life, universe and time is 42.", 
             "Take a towel for a space travel."]
labels = torch.LongTensor([0, 1])

# Instantiate the components
criterion = nn.CrossEntropyLoss()
classifier = XLNetSentenceClassifier(2, model, 768)
optimizer = torch.optim.AdamW(classifier.parameters())

# forward + loss
classifier.train()
optimizer.zero_grad()
logits = classifier(tokenizer(sentences, padding=True, return_tensors='pt'))
loss = get_loss(criterion, logits, labels)

print("Loss: ", loss.item())

# backward step
loss.backward()
optimizer.step()

print("="*25)
print("Confirm that the gradients are computed for the original XLNet parameters.\n")
print("="*25)
for param in classifier.parameters():
    print(param.shape, param.grad.sum() if not param.grad is None else param.grad)
Out [11]:
/usr/local/lib/python3.6/dist-packages/transformers/modeling_xlnet.py:283: UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (Triggered internally at  /pytorch/aten/src/ATen/native/TensorIterator.cpp:918.)
  attn_score = (ac + bd + ef) * self.scale
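The UserWarning above is emitted from inside transformers' XLNet attention computation during the forward pass; it can safely be ignored for this demo.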

Example 2. Extractive question answering (SQuAD-style)


In [12]:
class AnsStartLogits(nn.Module):
    """
    用于预测每个token是否为答案span开始位置
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, 
                hidden_states, 
                p_mask=None
               ):
        x = self.linear(hidden_states).squeeze(-1)

        if p_mask is not None:
            # masked positions get a very large negative logit, i.e. ~0 probability after softmax
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x
    
    
class AnsEndLogits(nn.Module):
    """
    用于预测每个token是否为答案span结束位置,符合直觉,conditioned on 开始位置
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.layer = nn.Sequential(
                nn.Linear(hidden_size * 2, hidden_size),
                nn.Tanh(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, 1)
        )

    def forward(self,
                hidden_states,
                start_states,
                p_mask = None,
               ):

        x = self.layer(torch.cat([hidden_states, start_states], dim=-1))
        x = x.squeeze(-1)

        if p_mask is not None:
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x
    

class XLNetQuestionAnswering(nn.Module):
    
    def __init__(self,
                 num_labels,
                 xlnet_model,
                 d_model=768,
                 top_k_start=2,
                 top_k_end=2
                 ):
        super().__init__()
        self.transformer = xlnet_model
        self.start_logits = AnsStartLogits(d_model)
        self.end_logits = AnsEndLogits(d_model)
        self.top_k_start = top_k_start # for beam search
        self.top_k_end = top_k_end # for beam search       
        
    def forward(self, 
                model_inputs,
                p_mask=None,
                start_positions=None
                ):
        """
        p_mask:
            可选的mask, 被mask掉的位置不可能存在答案(e.g. [CLS], [PAD], ...)。
            1.0 表示应当被mask. 0.0反之。
        start_positions:
            正确答案标注的开始位置。训练时需要输入模型以利用teacher forcing计算end_logits。
            Inference时不需输入,beam search返回top k个开始和结束位置。
        """
        transformer_outputs = self.transformer(**model_inputs)
        
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
        
        if start_positions is not None:
            # During training, use the teacher-forcing trick to train end_logits
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions.expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
            end_logits = self.end_logits(hidden_states, 
                                         start_states=start_states, 
                                         p_mask=p_mask)
            
            return start_logits, end_logits
        else:
            # At inference time, compute end_logits with beam search over the top-k starts
            bsz, slen, hsz = hidden_states.size()
            start_probs = torch.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_probs, start_top_index = torch.topk(
                start_probs, self.top_k_start, dim=-1
            )  # shape (bsz, top_k_start)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, top_k_start, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, top_k_start, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, top_k_start, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, top_k_start, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, 
                                         start_states=start_states, 
                                         p_mask=p_mask) 
            end_probs = torch.softmax(end_logits, dim=1)  # shape (bsz, slen, top_k_start)

            end_top_probs, end_top_index = torch.topk(
                end_probs, self.top_k_end, dim=1
            )  # shape (bsz, top_k_end, top_k_start)

            end_top_probs = torch.transpose(end_top_probs, 2, 1) # shape (bsz, top_k_start, top_k_end)
            end_top_index = torch.transpose(end_top_index, 2, 1) # shape (bsz, top_k_start, top_k_end)

            end_top_probs = end_top_probs.reshape(-1, self.top_k_start * self.top_k_end)
            end_top_index = end_top_index.reshape(-1, self.top_k_start * self.top_k_end)

            
            return start_top_probs, start_top_index, end_top_probs, end_top_index, start_logits, end_logits
    
def get_loss(criterion, 
             start_logits, 
             start_positions,
             end_logits,
             end_positions
            ):
    start_loss = criterion(start_logits, start_positions)
    end_loss = criterion(end_logits, end_positions)
    return (start_loss + end_loss) / 2
    
    

Check the training-time forward and backward passes

In [13]:
context = r"""
    Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
    architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
    Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
    TensorFlow 2.0 and PyTorch.
    """
questions = [
    "How many pretrained models are available in Transformers?",
    "What does Transformers provide?",
    "Transformers provides interoperability between which frameworks?",
]

# Gold answer spans (token indices) for each question+context pair
start_positions = torch.LongTensor([95, 36, 110])
end_positions = torch.LongTensor([97, 88, 123])
# Hand-built p_mask per pair: 1 = cannot contain the answer (question and special tokens), 0 = candidate
p_mask = [[1]*12 + [0]* (125 -14) + [1,1],
          [1]* 7 + [0]* (120 - 9) + [1,1],
          [1]*12 + [0]* (125 -14) + [1,1],
         ]

neg_log_loss = nn.CrossEntropyLoss()

q_answer = XLNetQuestionAnswering(2, model, 768, 2, 2)

optimizer = torch.optim.AdamW(q_answer.parameters())
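The p_mask rows above are hand-built with hardcoded sequence lengths. As a more robust alternative (a sketch, not part of the original notebook, assuming the question/context encoding shown earlier: tokenizer(question, context) gives the question segment id 0, the context 1, and the trailing <cls> 2), the mask can be derived from token_type_ids:

# Sketch: derive p_mask from token_type_ids instead of hardcoding lengths.
def build_p_mask(inputs, tokenizer):
    input_ids = inputs["input_ids"][0].tolist()
    segment_ids = inputs["token_type_ids"][0].tolist()
    # 1 = cannot contain the answer (question tokens, <sep>, <cls>), 0 = candidate position
    mask = [0 if seg == 1 and tok != tokenizer.sep_token_id else 1
            for tok, seg in zip(input_ids, segment_ids)]
    return torch.ByteTensor(mask)

# e.g.: build_p_mask(tokenizer(questions[0], context, return_tensors="pt"), tokenizer)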
In [14]:
q_answer.train()
optimizer.zero_grad()  # called once, so gradients accumulate across the three questions below
for ith, question in enumerate(questions):
    start_logits, end_logits = q_answer(
        tokenizer(question, 
                  context, 
                  add_special_tokens=True,
                  return_tensors='pt'),
        p_mask=torch.ByteTensor(p_mask[ith]),
        start_positions=start_positions[ith].view(1,1,1)
    )
    loss = get_loss(
        neg_log_loss,
        start_logits, 
        start_positions[ith].view(-1),
        end_logits,
        end_positions[ith].view(-1)
    )
    print("\nTrue Start: {}, True End: {}\nPred Start Prob: {}, Pred End Prob: {}\nPred Max Start: {}, Pred Max End: {}\nPred Max Start Prob: {}, Pred Max end Prob:{}\nLoss: {}\n".format(
        start_positions[ith].item(),
        end_positions[ith].item(),
        torch.sigmoid(start_logits[:,start_positions[ith]]).item(), 
        torch.sigmoid(end_logits[:, end_positions[ith]]).item(),
        torch.argmax(start_logits).item(),
        torch.argmax(end_logits).item(),
        torch.sigmoid(torch.max(start_logits)).item(), 
        torch.sigmoid(torch.max(end_logits)).item(),
        loss.item()
        )
    )
    print("="*25)
    loss.backward()
    optimizer.step()

print("\nConfirm that the gradients are computed for the original XLNet parameters.")
for param in q_answer.parameters():
    print(param.shape, param.grad.sum() if not param.grad is None else param.grad)
Out [14]:
True Start: 95, True End: 97
Pred Start Prob: 0.3128710687160492, Pred End Prob: 0.6340706944465637
Pred Max Start: 78, Pred Max End: 39
Pred Max Start Prob: 0.6634821891784668, Pred Max end Prob:0.6745399832725525
Loss: 4.699392318725586

=========================

True Start: 36, True End: 88
Pred Start Prob: 0.6726937294006348, Pred End Prob: 0.93719482421875
Pred Max Start: 25, Pred Max End: 96
Pred Max Start Prob: 0.8701046705245972, Pred Max end Prob:0.9859831929206848
Loss: 5.206717491149902

=========================

True Start: 110, True End: 123
Pred Start Prob: 0.8552238941192627, Pred End Prob: 0.0
Pred Max Start: 97, Pred Max End: 98
Pred Max Start Prob: 0.9279953241348267, Pred Max end Prob:0.4768109917640686
Loss: 5.000000075237331e+29

=========================

Confirm that the gradients are computed for the original XLNet parameters.
torch.Size([1, 1, 768]) None
torch.Size([32000, 768]) tensor(-0.6884)
torch.Size([768, 12, 64]) tensor(-0.0334)
torch.Size([768, 12, 64]) tensor(0.0020)
torch.Size([768, 12, 64]) tensor(-0.2530)
torch.Size([768, 12, 64]) tensor(0.1670)
torch.Size([768, 12, 64]) tensor(-0.3664)
torch.Size([12, 64]) tensor(0.0001)
torch.Size([12, 64]) tensor(-0.0014)
torch.Size([12, 64]) tensor(0.0102)
torch.Size([2, 12, 64]) tensor(-9.2882e-10)
torch.Size([768]) tensor(-0.0579)
torch.Size([768]) tensor(-0.0086)
torch.Size([768]) tensor(0.3276)
torch.Size([768]) tensor(-0.4108)
torch.Size([3072, 768]) tensor(-0.7222)
torch.Size([3072]) tensor(0.0196)
torch.Size([768, 3072]) tensor(1.9823)
torch.Size([768]) tensor(-0.0305)
torch.Size([768, 12, 64]) tensor(0.1537)
torch.Size([768, 12, 64]) tensor(0.1984)
torch.Size([768, 12, 64]) tensor(-0.1475)
torch.Size([768, 12, 64]) tensor(-0.3304)
torch.Size([768, 12, 64]) tensor(-0.3685)
torch.Size([12, 64]) tensor(-0.0085)
torch.Size([12, 64]) tensor(-9.0756e-05)
torch.Size([12, 64]) tensor(0.0146)
torch.Size([2, 12, 64]) tensor(-6.3806e-09)
torch.Size([768]) tensor(-0.0117)
torch.Size([768]) tensor(0.0255)
torch.Size([768]) tensor(0.0031)
torch.Size([768]) tensor(-0.4477)
torch.Size([3072, 768]) tensor(-0.6170)
torch.Size([3072]) tensor(-0.0502)
torch.Size([768, 3072]) tensor(-0.3674)
torch.Size([768]) tensor(0.0046)
torch.Size([768, 12, 64]) tensor(0.1544)
torch.Size([768, 12, 64]) tensor(0.0277)
torch.Size([768, 12, 64]) tensor(-0.7144)
torch.Size([768, 12, 64]) tensor(0.0200)
torch.Size([768, 12, 64]) tensor(2.8775)
torch.Size([12, 64]) tensor(-0.0066)
torch.Size([12, 64]) tensor(-0.0012)
torch.Size([12, 64]) tensor(-0.0104)
torch.Size([2, 12, 64]) tensor(5.9663e-09)
torch.Size([768]) tensor(0.0509)
torch.Size([768]) tensor(-0.0095)
torch.Size([768]) tensor(-0.0630)
torch.Size([768]) tensor(0.0002)
torch.Size([3072, 768]) tensor(0.3525)
torch.Size([3072]) tensor(0.0530)
torch.Size([768, 3072]) tensor(2.0844)
torch.Size([768]) tensor(-0.0176)
torch.Size([768, 12, 64]) tensor(0.2019)
torch.Size([768, 12, 64]) tensor(0.2833)
torch.Size([768, 12, 64]) tensor(0.6085)
torch.Size([768, 12, 64]) tensor(-0.1063)
torch.Size([768, 12, 64]) tensor(-0.7396)
torch.Size([12, 64]) tensor(0.0198)
torch.Size([12, 64]) tensor(-0.0023)
torch.Size([12, 64]) tensor(-0.0082)
torch.Size([2, 12, 64]) tensor(-2.4584e-09)
torch.Size([768]) tensor(-0.2631)
torch.Size([768]) tensor(-0.0389)
torch.Size([768]) tensor(0.0537)
torch.Size([768]) tensor(0.2082)
torch.Size([3072, 768]) tensor(5.2597)
torch.Size([3072]) tensor(-0.2162)
torch.Size([768, 3072]) tensor(1.0177)
torch.Size([768]) tensor(-0.0024)
torch.Size([768, 12, 64]) tensor(0.2092)
torch.Size([768, 12, 64]) tensor(-0.0467)
torch.Size([768, 12, 64]) tensor(2.2503)
torch.Size([768, 12, 64]) tensor(0.3033)
torch.Size([768, 12, 64]) tensor(-1.2354)
torch.Size([12, 64]) tensor(-0.0038)
torch.Size([12, 64]) tensor(-0.0013)
torch.Size([12, 64]) tensor(0.0100)
torch.Size([2, 12, 64]) tensor(5.4883e-09)
torch.Size([768]) tensor(-0.1086)
torch.Size([768]) tensor(0.0420)
torch.Size([768]) tensor(0.1451)
torch.Size([768]) tensor(0.3843)
torch.Size([3072, 768]) tensor(1.1731)
torch.Size([3072]) tensor(-0.0912)
torch.Size([768, 3072]) tensor(-0.0175)
torch.Size([768]) tensor(0.0038)
torch.Size([768, 12, 64]) tensor(-0.1283)
torch.Size([768, 12, 64]) tensor(0.0159)
torch.Size([768, 12, 64]) tensor(1.8471)
torch.Size([768, 12, 64]) tensor(-0.0603)
torch.Size([768, 12, 64]) tensor(-0.5668)
torch.Size([12, 64]) tensor(-0.0011)
torch.Size([12, 64]) tensor(0.0008)
torch.Size([12, 64]) tensor(0.0101)
torch.Size([2, 12, 64]) tensor(1.6712e-09)
torch.Size([768]) tensor(0.2400)
torch.Size([768]) tensor(-0.0586)
torch.Size([768]) tensor(0.0619)
torch.Size([768]) tensor(-0.2086)
torch.Size([3072, 768]) tensor(1.8524)
torch.Size([3072]) tensor(-0.1007)
torch.Size([768, 3072]) tensor(1.4429)
torch.Size([768]) tensor(-0.0088)
torch.Size([768, 12, 64]) tensor(0.5602)
torch.Size([768, 12, 64]) tensor(-0.0269)
torch.Size([768, 12, 64]) tensor(12.1219)
torch.Size([768, 12, 64]) tensor(-0.5672)
torch.Size([768, 12, 64]) tensor(-1.1886)
torch.Size([12, 64]) tensor(-0.0125)
torch.Size([12, 64]) tensor(-0.0005)
torch.Size([12, 64]) tensor(-0.0047)
torch.Size([2, 12, 64]) tensor(8.1666e-09)
torch.Size([768]) tensor(0.3063)
torch.Size([768]) tensor(-0.0601)
torch.Size([768]) tensor(0.0642)
torch.Size([768]) tensor(-1.2854)
torch.Size([3072, 768]) tensor(0.6916)
torch.Size([3072]) tensor(0.0861)
torch.Size([768, 3072]) tensor(-5.8404)
torch.Size([768]) tensor(-0.0209)
torch.Size([768, 12, 64]) tensor(-0.2547)
torch.Size([768, 12, 64]) tensor(0.0448)
torch.Size([768, 12, 64]) tensor(-11.1511)
torch.Size([768, 12, 64]) tensor(-0.9826)
torch.Size([768, 12, 64]) tensor(-4.3894)
torch.Size([12, 64]) tensor(-0.0002)
torch.Size([12, 64]) tensor(0.0011)
torch.Size([12, 64]) tensor(0.0086)
torch.Size([2, 12, 64]) tensor(1.4228e-08)
torch.Size([768]) tensor(0.3623)
torch.Size([768]) tensor(0.4062)
torch.Size([768]) tensor(-0.1065)
torch.Size([768]) tensor(-3.0040)
torch.Size([3072, 768]) tensor(5.2218)
torch.Size([3072]) tensor(-0.2266)
torch.Size([768, 3072]) tensor(9.2306)
torch.Size([768]) tensor(0.0284)
torch.Size([768, 12, 64]) tensor(-0.4832)
torch.Size([768, 12, 64]) tensor(0.0096)
torch.Size([768, 12, 64]) tensor(-10.0414)
torch.Size([768, 12, 64]) tensor(0.5542)
torch.Size([768, 12, 64]) tensor(-1.1537)
torch.Size([12, 64]) tensor(0.0121)
torch.Size([12, 64]) tensor(-0.0002)
torch.Size([12, 64]) tensor(0.0021)
torch.Size([2, 12, 64]) tensor(1.1473e-08)
torch.Size([768]) tensor(0.1046)
torch.Size([768]) tensor(-0.5176)
torch.Size([768]) tensor(-0.1872)
torch.Size([768]) tensor(-0.4094)
torch.Size([3072, 768]) tensor(-2.1765)
torch.Size([3072]) tensor(0.2907)
torch.Size([768, 3072]) tensor(-17.4914)
torch.Size([768]) tensor(-0.2188)
torch.Size([768, 12, 64]) tensor(0.1278)
torch.Size([768, 12, 64]) tensor(0.0075)
torch.Size([768, 12, 64]) tensor(4.4422)
torch.Size([768, 12, 64]) tensor(0.1824)
torch.Size([768, 12, 64]) tensor(-0.2453)
torch.Size([12, 64]) tensor(-0.0021)
torch.Size([12, 64]) tensor(0.0010)
torch.Size([12, 64]) tensor(-0.0038)
torch.Size([2, 12, 64]) tensor(3.8835e-08)
torch.Size([768]) tensor(0.0947)
torch.Size([768]) tensor(-0.1559)
torch.Size([768]) tensor(-0.4608)
torch.Size([768]) tensor(-3.1537)
torch.Size([3072, 768]) tensor(-0.0686)
torch.Size([3072]) tensor(-0.1366)
torch.Size([768, 3072]) tensor(-23.9523)
torch.Size([768]) tensor(-0.0108)
torch.Size([768, 12, 64]) tensor(-1.3291)
torch.Size([768, 12, 64]) tensor(0.0523)
torch.Size([768, 12, 64]) tensor(-1.8151)
torch.Size([768, 12, 64]) tensor(-0.4039)
torch.Size([768, 12, 64]) tensor(-0.0539)
torch.Size([12, 64]) tensor(0.0087)
torch.Size([12, 64]) tensor(-0.0025)
torch.Size([12, 64]) tensor(0.0149)
torch.Size([2, 12, 64]) tensor(-7.3160e-09)
torch.Size([768]) tensor(-0.4885)
torch.Size([768]) tensor(1.7355)
torch.Size([768]) tensor(-0.3012)
torch.Size([768]) tensor(1.5989)
torch.Size([3072, 768]) tensor(10.9814)
torch.Size([3072]) tensor(-1.0466)
torch.Size([768, 3072]) tensor(7.2909)
torch.Size([768]) tensor(0.0650)
torch.Size([768, 12, 64]) tensor(4.3189)
torch.Size([768, 12, 64]) tensor(-0.3560)
torch.Size([768, 12, 64]) tensor(36.1090)
torch.Size([768, 12, 64]) tensor(-2.3517)
torch.Size([768, 12, 64]) tensor(-2.1882)
torch.Size([12, 64]) tensor(0.0060)
torch.Size([12, 64]) tensor(-0.0008)
torch.Size([12, 64]) tensor(-0.0341)
torch.Size([2, 12, 64]) tensor(4.8329e-08)
torch.Size([768]) tensor(11.1667)
torch.Size([768]) tensor(17.8769)
torch.Size([768]) tensor(0.2330)
torch.Size([768]) tensor(-0.1388)
torch.Size([3072, 768]) tensor(20.5294)
torch.Size([3072]) tensor(-1.7453)
torch.Size([768, 3072]) tensor(-39.9618)
torch.Size([768]) tensor(0.1220)
torch.Size([1, 768]) tensor(-10.1781)
torch.Size([1]) tensor(9.1270e-08)
torch.Size([768, 1536]) tensor(64.5171)
torch.Size([768]) tensor(-0.1971)
torch.Size([768]) tensor(0.1439)
torch.Size([768]) tensor(-0.2693)
torch.Size([1, 768]) tensor(0.0336)
torch.Size([1]) tensor(0.5000)
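A note on the third loss above: its huge magnitude (~5e+29) is not a numerical accident. The gold end position 123 lies in the region that the hand-built p_mask marks as impossible, so AnsEndLogits pushes its logit to about -1e30 and the cross-entropy at that position explodes. In practice the mask must be kept consistent with the gold spans.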

Inference-time forward pass and beam-search decoding

In [15]:
import numpy as np
def decode(start_probs, end_probs, topk):
    """
    Given the start and end probabilities predicted in the beam,
    search for the topk best answers.
    """
    top_k_start = start_probs.shape[-1]
    top_k_end = end_probs.shape[-1] // top_k_start

    # Score every (start, end) pair:
    # P(start, end | sentence) = P(start | sentence) * P(end | start, sentence)
    joint_probs = dict()
    for i in range(top_k_start):
        for j in range(top_k_end):
            end_idx = i*top_k_end+j
            joint_probs[(i, end_idx)] = start_probs[i]*end_probs[end_idx]

    id_pairs, probs = zip(*sorted(joint_probs.items(), key=lambda kv:kv[1], reverse=True)[:topk])
    start_ids, end_ids = zip(*id_pairs)
    return start_ids, end_ids, probs
In [16]:
# inference
context = r"""
    Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
    architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
    Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
    TensorFlow 2.0 and PyTorch.
    """
questions = [
    "How many pretrained models are available in Transformers?",
    "What does Transformers provide?",
    "Transformers provides interoperability between which frameworks?",
]
q_answer.eval()
for ith, question in enumerate(questions):
    inputs = tokenizer(question, context, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]
    
    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    start_probs, start_index, end_probs, end_index, stt_logits, end_logits = q_answer(
        inputs, 
        p_mask=torch.ByteTensor(p_mask[ith])
    )

    pred_starts, pred_ends, probs = decode(
        start_probs.detach().squeeze().numpy(), 
        end_probs.detach().squeeze().numpy(), 
        2)
    
    # Print only the top-ranked answer
    start = start_index[:, pred_starts[0]].item()
    end = end_index[:, pred_ends[0]].item()
    
#     print(probs, pred_starts, pred_ends)
#     print(len(input_ids), stt_logits.shape, end_logits.shape)
#     print(tokenizer.convert_ids_to_tokens(input_ids).index('?'))

    print("="*25)
    print("True start: {}, True end: {}".format(
        start_positions[ith].item(),
        end_positions[ith].item()
        ))
    print("Max answer prob: {:0.8f}, start idx: {}, end idx: {}".format(
        probs[0],
        start,
        end,
    ))
    print("-"*25)
    print("Question: '{}'".format(question))
    print("Answer: '{}'".format(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[start:end]))))
    print("="*25)
Out [16]:
=========================
True start: 95, True end: 97
Max answer prob: 0.00008122, start idx: 120, end idx: 122
-------------------------
Question: 'How many pretrained models are available in Transformers?'
Answer: 'orch'
=========================
=========================
True start: 36, True end: 88
Max answer prob: 0.00008121, start idx: 115, end idx: 117
-------------------------
Question: 'What does Transformers provide?'
Answer: 'orch'
=========================
=========================
True start: 110, True end: 123
Max answer prob: 0.00008122, start idx: 120, end idx: 122
-------------------------
Question: 'Transformers provides interoperability between which frameworks?'
Answer: 'orch'
=========================
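The predicted answers are meaningless fragments ('orch' comes from 'PyTorch'): the QA heads have only seen a few gradient steps on three toy examples, so this cell merely demonstrates that the inference path and beam-search decoding run end to end. Sensible answers require proper fine-tuning on a QA dataset such as SQuAD.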