0816XLNet_for_down_stream_tasks_by_Qiao.ipynb
Using Pretrained Language Models from HuggingFace for Downstream Tasks
by Qiao for NLP7 2020-8-16
How a pretrained language model is used:
- as a feature extractor
- as an encoder that is fine-tuned along with the downstream task

The two usages are very similar in practice; the difference is that in the latter the parameters of the original pretrained model are also updated during training (see the sketch below).
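The difference is easy to see in code. Here is a minimal sketch of the feature-extractor mode, assuming any HuggingFace-wrapped model (xlnet-base-cased used for illustration); fine-tuning mode simply skips the freezing loop:

import torch
from transformers import XLNetModel

xlnet = XLNetModel.from_pretrained('xlnet-base-cased')

# Feature-extractor mode: freeze every pretrained parameter so that only the
# task-specific head (defined elsewhere) receives gradient updates.
for p in xlnet.parameters():
    p.requires_grad = False

# Fine-tuning mode: leave requires_grad=True (the default), so the optimizer
# updates the pretrained weights together with the downstream head.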
Main contents:
- Use XLNet to walk through the usual workflow of the HuggingFace transformers components
- Use XLNet as an example of how to attach downstream text classification and extractive question answering

Although XLNet is used as the example throughout, the workflow for the other pretrained language models wrapped by HuggingFace is similar.
In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F  # note: torch.nn.functional, not torch.functional

!pip install transformers
from transformers import XLNetModel, XLNetTokenizer, XLNetConfig
Out [1]:
Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (3.0.2)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: tokenizers==0.8.1.rc1 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.8.1rc1)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.43)
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: sentencepiece!=0.1.92 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.1.91)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.15.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
In [2]:
# Quick sanity check: torch.transpose swaps dims 1 and 2, so (3, 2, 4) becomes (3, 4, 2).
# The same trick is used later to rearrange beam-search scores in the QA example.
torch.transpose(torch.ones((3, 2, 4)), 2, 1).shape
Out [2]:
torch.Size([3, 4, 2])
In [3]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetModel.from_pretrained('xlnet-base-cased',
                                   output_hidden_states=True,
                                   output_attentions=True)
注意:
When the modules are initialized from a model name as above, the pretrained XLNet weights are downloaded in the background and then loaded. If direct downloads are slow or blocked (e.g. for users in mainland China), besides changing how you access the network, you can also download the model manually and load it from a local path.
手动下载模型:
On the official HuggingFace model hub, find the model you need and click its link, e.g. the xlnet-base-cased model. On the model page, click List all files in model
(the link text is small, so look carefully) and save the model files (PyTorch or TF version) from the popup to your local disk.
In [4]:
# # Load the XLNet model from a local path
# MODEL_PATH = r"D:\data\nlp\xlnet-model/"
# config = XLNetConfig.from_json_file(os.path.join(MODEL_PATH, "xlnet-base-cased-config.json"))
# # The config file not only sets model hyperparameters but can also control the model's behavior
# config.output_hidden_states = True
# config.output_attentions = True
# tokenizer = XLNetTokenizer(os.path.join(MODEL_PATH, 'xlnet-base-cased-spiece.model'))
# model = XLNetModel.from_pretrained(MODEL_PATH, config=config)
1. Converting sentences to token IDs
In [5]:
# Use the tokenizer to turn a raw sentence into model inputs
sentence = "This is an interesting review session"

# tokenization
tokens = tokenizer.tokenize(sentence)
print("Tokens: {}".format(tokens))

# map tokens to IDs
tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
print("Tokens id: {}".format(tokens_ids))

# add the special tokens <cls>, <sep>; unlike BERT, XLNet appends them at the end
tokens_ids = tokenizer.build_inputs_with_special_tokens(tokens_ids)

# wrap as a PyTorch tensor
tokens_pt = torch.tensor([tokens_ids])
print("Tokens PyTorch: {}".format(tokens_pt))

# print(tokenizer.convert_ids_to_tokens([122, 27, 48, 5272, 717, 4, 3]))
Out [5]:
Tokens: ['▁This', '▁is', '▁an', '▁interesting', '▁review', '▁session']
Tokens id: [122, 27, 48, 2456, 1398, 1961]
Tokens PyTorch: tensor([[ 122, 27, 48, 2456, 1398, 1961, 4, 3]])
In [6]:
# The lazy one-stop version: calling the tokenizer directly does all of the above in one step
tokens_pt2 = tokenizer(sentence, return_tensors="pt")
print("Tokens PyTorch: {}".format(tokens_pt2))
Out [6]:
Tokens PyTorch: {'input_ids': tensor([[ 122, 27, 48, 2456, 1398, 1961, 4, 3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
In [7]:
# Batch processing with padding
sentences = ["The ultimate answer to life, universe and time is 42.",
             "Take a towel for a space travel."]
print("Batch tokenization:\n", tokenizer(sentences)['input_ids'])
print("With Padding:\n", tokenizer(sentences, padding=True)['input_ids'])
Out [7]:
Batch tokenization:
 [[32, 6452, 1543, 22, 235, 19, 6486, 21, 92, 27, 4087, 9, 4, 3], [3636, 24, 14680, 28, 24, 888, 1316, 9, 4, 3]]
With Padding:
 [[32, 6452, 1543, 22, 235, 19, 6486, 21, 92, 27, 4087, 9, 4, 3], [5, 5, 5, 5, 3636, 24, 14680, 28, 24, 888, 1316, 9, 4, 3]]
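Note that the shorter sentence was padded at the front: XLNet's tokenizer pads on the left, which keeps the <cls> summary token at the last position. A quick check, reusing the tokenizer loaded above, that id 5 is indeed the padding token:

# id 5 appears at the front of the padded sequence; confirm it is <pad>
print(tokenizer.convert_ids_to_tokens([5]))  # expected: ['<pad>']
print(tokenizer.padding_side)                # XLNet's tokenizer pads on the left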
In [8]:
# Input a sentence pair:
multi_seg_input = tokenizer("This is segment A", "This is segment B")
print("Multi segment token (str): {}".format(tokenizer.convert_ids_to_tokens(multi_seg_input['input_ids'])))
print("Multi segment token (int): {}".format(multi_seg_input['input_ids']))
print("Multi segment type : {}".format(multi_seg_input['token_type_ids']))
Out [8]:
Multi segment token (str): ['▁This', '▁is', '▁segment', '▁A', '<sep>', '▁This', '▁is', '▁segment', '▁B', '<sep>', '<cls>']
Multi segment token (int): [122, 27, 7295, 79, 4, 122, 27, 7295, 322, 4, 3]
Multi segment type : [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2]
2. Model encoding
In [9]:
# By default the model is in eval mode (model.eval()). Below we use it to encode the input sentence.
# Because the config asks for every layer's hidden states and the attentions, in addition to the
# last-layer hidden state returned by default, the output has 3 parts.
print("Is training mode ? ", model.training)

sentence = "The ultimate answer to life, universe and time is 42."
tokens_pt = tokenizer(sentence, return_tensors="pt")
print("Token (str): {}".format(tokenizer.convert_ids_to_tokens(tokens_pt['input_ids'][0])))

final_layer_h, all_layer_h, attentions = model(**tokens_pt)
# the last entry of all_layer_h is the final layer's hidden state, so the difference is zero
print(torch.sum(final_layer_h - all_layer_h[-1]).item())
final_layer_h.shape, len(all_layer_h), len(attentions)
Out [9]:
Is training mode ?  False
Token (str): ['▁The', '▁ultimate', '▁answer', '▁to', '▁life', ',', '▁universe', '▁and', '▁time', '▁is', '▁42', '.', '<sep>', '<cls>']
0.0
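This is already everything needed for the feature-extractor usage: run the model without gradient tracking and treat the hidden states as fixed features. A minimal sketch, reusing the model and tokenizer above (mean pooling is just one choice of sentence summary):

# Feature extraction: no gradient tracking, pretrained weights stay fixed
with torch.no_grad():
    outputs = model(**tokenizer("A sentence to embed.", return_tensors="pt"))
    last_hidden = outputs[0]                # shape (1, seq_len, 768)
    sentence_vec = last_hidden.mean(dim=1)  # pool tokens into one vector, shape (1, 768)
print(sentence_vec.shape)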
3. 下游任务
Example 1. Text classification
In [10]:
class XLNetSeqSummary(nn.Module):
    def __init__(self, how='cls', hidden_size=768, activation=None,
                 first_dropout=None, last_dropout=None):
        super().__init__()
        self.how = how
        self.summary = nn.Linear(hidden_size, hidden_size)
        self.activation = activation if activation else nn.GELU()
        self.first_dropout = first_dropout if first_dropout else nn.Dropout(0.5)
        self.last_dropout = last_dropout if last_dropout else nn.Dropout(0.5)

    def forward(self, hidden_states):
        """
        Pool the hidden-state sequence, or take the representation at <cls>,
        as the encoding of the sentence.
        Args:
            hidden_states: the last-layer hidden-state sequence output by XLNet.
        Returns:
            the sentence vector representation
        """
        if self.how == "cls":
            # XLNet places <cls> at the *end* of the sequence
            output = hidden_states[:, -1]
        elif self.how == "mean":
            output = hidden_states.mean(dim=1)
        elif self.how == "max":
            # .max(dim=1) returns (values, indices); keep the values
            output = hidden_states.max(dim=1)[0]
        else:
            raise Exception("Summary type '{}' not implemented.".format(self.how))
        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)
        return output


class XLNetSentenceClassifier(nn.Module):
    def __init__(self, num_labels, xlnet_model, d_model=768):
        super().__init__()
        self.num_labels = num_labels
        self.d_model = d_model
        self.transformer = xlnet_model
        self.sequence_summary = XLNetSeqSummary('cls', d_model, nn.GELU())
        self.logits_proj = nn.Linear(d_model, num_labels)

    def forward(self, model_inputs):
        transformer_outputs = self.transformer(**model_inputs)
        output = transformer_outputs[0]
        output = self.sequence_summary(output)
        logits = self.logits_proj(output)
        return logits


def get_loss(criterion, logits, labels):
    return criterion(logits, labels)
In [11]:
# Check the forward pass and backpropagation
# toy examples
sentences = ["The ultimate answer to life, universe and time is 42.",
             "Take a towel for a space travel."]
labels = torch.LongTensor([0, 1])

# instantiate the components
criterion = nn.CrossEntropyLoss()
classifier = XLNetSentenceClassifier(2, model, 768)
optimizer = torch.optim.AdamW(classifier.parameters())

# forward + loss
classifier.train()
optimizer.zero_grad()
logits = classifier(tokenizer(sentences, padding=True, return_tensors='pt'))
loss = get_loss(criterion, logits, labels)
print("Loss: ", loss.item())

# backward step
loss.backward()
optimizer.step()

print("="*25)
print("Confirm that the gradients are computed for the original XLNet parameters.\n")
print("="*25)
for param in classifier.parameters():
    print(param.shape, param.grad.sum() if not param.grad is None else param.grad)
Out [11]:
/usr/local/lib/python3.6/dist-packages/transformers/modeling_xlnet.py:283: UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (Triggered internally at /pytorch/aten/src/ATen/native/TensorIterator.cpp:918.) attn_score = (ac + bd + ef) * self.scale
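After training, inference with the classifier is just a forward pass in eval mode plus a softmax over the logits. A minimal sketch using the toy classifier above (the input sentence is made up):

classifier.eval()
with torch.no_grad():
    logits = classifier(tokenizer(["Is 42 really the answer?"], return_tensors='pt'))
    probs = torch.softmax(logits, dim=-1)  # shape (1, num_labels)
    pred = probs.argmax(dim=-1)            # predicted label index
print(probs, pred)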
Example 2. Extractive question answering (SQuAD-style)
In [12]:
class AnsStartLogits(nn.Module):
    """
    Predict, for each token, whether it is the start of the answer span.
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, p_mask=None):
        x = self.linear(hidden_states).squeeze(-1)
        if p_mask is not None:
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x


class AnsEndLogits(nn.Module):
    """
    Predict, for each token, whether it is the end of the answer span.
    Intuitively, this is conditioned on the start position.
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.Tanh(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, hidden_states, start_states, p_mask=None):
        x = self.layer(torch.cat([hidden_states, start_states], dim=-1))
        x = x.squeeze(-1)
        if p_mask is not None:
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x


class XLNetQuestionAnswering(nn.Module):
    def __init__(self, num_labels, xlnet_model, d_model=768,
                 top_k_start=2, top_k_end=2):
        super().__init__()
        self.transformer = xlnet_model
        self.start_logits = AnsStartLogits(d_model)
        self.end_logits = AnsEndLogits(d_model)
        self.top_k_start = top_k_start  # for beam search
        self.top_k_end = top_k_end      # for beam search

    def forward(self, model_inputs, p_mask=None, start_positions=None):
        """
        p_mask: optional mask of positions that cannot contain the answer
                (e.g. [CLS], [PAD], ...). 1.0 means the position is masked,
                0.0 means it is not.
        start_positions: gold start positions of the answers. Needed during
                training so that end_logits can be computed with teacher
                forcing. Not needed at inference time, where beam search
                returns the top k start and end positions.
        """
        transformer_outputs = self.transformer(**model_inputs)
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        if start_positions is not None:
            # during training, use the teacher-forcing trick to train end_logits
            slen, hsz = hidden_states.shape[-2:]
            start_positions = start_positions.expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
            start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
            start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
            end_logits = self.end_logits(hidden_states, start_states=start_states, p_mask=p_mask)
            return start_logits, end_logits
        else:
            # at inference time, compute end_logits via beam search
            bsz, slen, hsz = hidden_states.size()
            start_probs = torch.softmax(start_logits, dim=-1)  # shape (bsz, slen)
            start_top_probs, start_top_index = torch.topk(
                start_probs, self.top_k_start, dim=-1
            )  # shape (bsz, top_k_start)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, top_k_start, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, top_k_start, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, top_k_start, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, top_k_start, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_probs = torch.softmax(end_logits, dim=1)  # shape (bsz, slen, top_k_start)

            end_top_probs, end_top_index = torch.topk(
                end_probs, self.top_k_end, dim=1
            )  # shape (bsz, top_k_end, top_k_start)
            end_top_probs = torch.transpose(end_top_probs, 2, 1)  # shape (bsz, top_k_start, top_k_end)
            end_top_index = torch.transpose(end_top_index, 2, 1)  # shape (bsz, top_k_start, top_k_end)
            end_top_probs = end_top_probs.reshape(-1, self.top_k_start * self.top_k_end)
            end_top_index = end_top_index.reshape(-1, self.top_k_start * self.top_k_end)

            return (start_top_probs, start_top_index, end_top_probs,
                    end_top_index, start_logits, end_logits)


def get_loss(criterion,
             start_logits, start_positions,
             end_logits, end_positions):
    start_loss = criterion(start_logits, start_positions)
    end_loss = criterion(end_logits, end_positions)
    return (start_loss + end_loss) / 2
Testing the forward and backward passes used for training
In [13]:
context = r""" Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. """ questions = [ "How many pretrained models are available in Transformers?", "What does Transformers provide?", "Transformers provides interoperability between which frameworks?", ] start_positions = torch.LongTensor([95, 36, 110]) end_positions = torch.LongTensor([97, 88, 123]) p_mask = [[1]*12 + [0]* (125 -14) + [1,1], [1]* 7 + [0]* (120 - 9) + [1,1], [1]*12 + [0]* (125 -14) + [1,1], ] neg_log_loss = nn.CrossEntropyLoss() q_answer = XLNetQuestionAnswering(2, model, 768, 2, 2) optimizer = torch.optim.AdamW(q_answer.parameters())
In [14]:
q_answer.train()
optimizer.zero_grad()

for ith, question in enumerate(questions):
    start_logits, end_logits = q_answer(
        tokenizer(question, context, add_special_tokens=True, return_tensors='pt'),
        p_mask=torch.ByteTensor(p_mask[ith]),
        start_positions=start_positions[ith].view(1, 1, 1)
    )
    loss = get_loss(criterion,
                    start_logits, start_positions[ith].view(-1),
                    end_logits, end_positions[ith].view(-1))
    print("\nTrue Start: {}, True End: {}\nPred Start Prob: {}, Pred End Prob: {}\nPred Max Start: {}, Pred Max End: {}\nPred Max Start Prob: {}, Pred Max end Prob:{}\nLoss: {}\n".format(
        start_positions[ith].item(), end_positions[ith].item(),
        torch.sigmoid(start_logits[:, start_positions[ith]]).item(),
        torch.sigmoid(end_logits[:, end_positions[ith]]).item(),
        torch.argmax(start_logits).item(),
        torch.argmax(end_logits).item(),
        torch.sigmoid(torch.max(start_logits)).item(),
        torch.sigmoid(torch.max(end_logits)).item(),
        loss.item()))
    print("="*25)
    loss.backward()
    optimizer.step()

print("\nConfirm that the gradients are computed for the original XLNet parameters.")
for param in q_answer.parameters():
    print(param.shape, param.grad.sum() if not param.grad is None else param.grad)
Out [14]:
True Start: 95, True End: 97
Pred Start Prob: 0.3128710687160492, Pred End Prob: 0.6340706944465637
Pred Max Start: 78, Pred Max End: 39
Pred Max Start Prob: 0.6634821891784668, Pred Max end Prob:0.6745399832725525
Loss: 4.699392318725586
=========================
True Start: 36, True End: 88
Pred Start Prob: 0.6726937294006348, Pred End Prob: 0.93719482421875
Pred Max Start: 25, Pred Max End: 96
Pred Max Start Prob: 0.8701046705245972, Pred Max end Prob:0.9859831929206848
Loss: 5.206717491149902
=========================
True Start: 110, True End: 123
Pred Start Prob: 0.8552238941192627, Pred End Prob: 0.0
Pred Max Start: 97, Pred Max End: 98
Pred Max Start Prob: 0.9279953241348267, Pred Max end Prob:0.4768109917640686
Loss: 5.000000075237331e+29
=========================
Confirm that the gradients are computed for the original XLNet parameters.
torch.Size([1, 1, 768]) None torch.Size([32000, 768]) tensor(-0.6884) torch.Size([768, 12, 64]) tensor(-0.0334) torch.Size([768, 12, 64]) tensor(0.0020) torch.Size([768, 12, 64]) tensor(-0.2530) torch.Size([768, 12, 64]) tensor(0.1670) torch.Size([768, 12, 64]) tensor(-0.3664) torch.Size([12, 64]) tensor(0.0001) torch.Size([12, 64]) tensor(-0.0014) torch.Size([12, 64]) tensor(0.0102) torch.Size([2, 12, 64]) tensor(-9.2882e-10) torch.Size([768]) tensor(-0.0579) torch.Size([768]) tensor(-0.0086) torch.Size([768]) tensor(0.3276) torch.Size([768]) tensor(-0.4108) torch.Size([3072, 768]) tensor(-0.7222) torch.Size([3072]) tensor(0.0196) torch.Size([768, 3072]) tensor(1.9823) torch.Size([768]) tensor(-0.0305) torch.Size([768, 12, 64]) tensor(0.1537) torch.Size([768, 12, 64]) tensor(0.1984) torch.Size([768, 12, 64]) tensor(-0.1475) torch.Size([768, 12, 64]) tensor(-0.3304) torch.Size([768, 12, 64]) tensor(-0.3685) torch.Size([12, 64]) tensor(-0.0085) torch.Size([12, 64]) tensor(-9.0756e-05) torch.Size([12, 64]) tensor(0.0146) torch.Size([2, 12, 64]) tensor(-6.3806e-09) torch.Size([768]) tensor(-0.0117) torch.Size([768]) tensor(0.0255) torch.Size([768]) tensor(0.0031) torch.Size([768]) tensor(-0.4477) torch.Size([3072, 768]) tensor(-0.6170) torch.Size([3072]) tensor(-0.0502) torch.Size([768, 3072]) tensor(-0.3674) torch.Size([768]) tensor(0.0046) torch.Size([768, 12, 64]) tensor(0.1544) torch.Size([768, 12, 64]) tensor(0.0277) torch.Size([768, 12, 64]) tensor(-0.7144) torch.Size([768, 12, 64]) tensor(0.0200) torch.Size([768, 12, 64]) tensor(2.8775) torch.Size([12, 64]) tensor(-0.0066) torch.Size([12, 64]) tensor(-0.0012) torch.Size([12, 64]) tensor(-0.0104) torch.Size([2, 12, 64]) tensor(5.9663e-09) torch.Size([768]) tensor(0.0509) torch.Size([768]) tensor(-0.0095) torch.Size([768]) tensor(-0.0630) torch.Size([768]) tensor(0.0002) torch.Size([3072, 768]) tensor(0.3525) torch.Size([3072]) tensor(0.0530) torch.Size([768, 3072]) tensor(2.0844) torch.Size([768]) tensor(-0.0176) torch.Size([768, 12, 64]) tensor(0.2019) torch.Size([768, 12, 64]) tensor(0.2833) torch.Size([768, 12, 64]) tensor(0.6085) torch.Size([768, 12, 64]) tensor(-0.1063) torch.Size([768, 12, 64]) tensor(-0.7396) torch.Size([12, 64]) tensor(0.0198) torch.Size([12, 64]) tensor(-0.0023) torch.Size([12, 64]) tensor(-0.0082) torch.Size([2, 12, 64]) tensor(-2.4584e-09) torch.Size([768]) tensor(-0.2631) torch.Size([768]) tensor(-0.0389) torch.Size([768]) tensor(0.0537) torch.Size([768]) tensor(0.2082) torch.Size([3072, 768]) tensor(5.2597) torch.Size([3072]) tensor(-0.2162) torch.Size([768, 3072]) tensor(1.0177) torch.Size([768]) tensor(-0.0024) torch.Size([768, 12, 64]) tensor(0.2092)
torch.Size([768, 12, 64]) tensor(-0.0467) torch.Size([768, 12, 64]) tensor(2.2503) torch.Size([768, 12, 64]) tensor(0.3033) torch.Size([768, 12, 64]) tensor(-1.2354) torch.Size([12, 64]) tensor(-0.0038) torch.Size([12, 64]) tensor(-0.0013) torch.Size([12, 64]) tensor(0.0100) torch.Size([2, 12, 64]) tensor(5.4883e-09) torch.Size([768]) tensor(-0.1086) torch.Size([768]) tensor(0.0420) torch.Size([768]) tensor(0.1451) torch.Size([768]) tensor(0.3843) torch.Size([3072, 768]) tensor(1.1731) torch.Size([3072]) tensor(-0.0912) torch.Size([768, 3072]) tensor(-0.0175) torch.Size([768]) tensor(0.0038) torch.Size([768, 12, 64]) tensor(-0.1283) torch.Size([768, 12, 64]) tensor(0.0159) torch.Size([768, 12, 64]) tensor(1.8471) torch.Size([768, 12, 64]) tensor(-0.0603) torch.Size([768, 12, 64]) tensor(-0.5668) torch.Size([12, 64]) tensor(-0.0011) torch.Size([12, 64]) tensor(0.0008) torch.Size([12, 64]) tensor(0.0101) torch.Size([2, 12, 64]) tensor(1.6712e-09) torch.Size([768]) tensor(0.2400) torch.Size([768]) tensor(-0.0586) torch.Size([768]) tensor(0.0619) torch.Size([768]) tensor(-0.2086) torch.Size([3072, 768]) tensor(1.8524) torch.Size([3072]) tensor(-0.1007) torch.Size([768, 3072]) tensor(1.4429) torch.Size([768]) tensor(-0.0088) torch.Size([768, 12, 64]) tensor(0.5602) torch.Size([768, 12, 64]) tensor(-0.0269) torch.Size([768, 12, 64]) tensor(12.1219) torch.Size([768, 12, 64]) tensor(-0.5672) torch.Size([768, 12, 64]) tensor(-1.1886) torch.Size([12, 64]) tensor(-0.0125) torch.Size([12, 64]) tensor(-0.0005) torch.Size([12, 64]) tensor(-0.0047) torch.Size([2, 12, 64]) tensor(8.1666e-09) torch.Size([768]) tensor(0.3063) torch.Size([768]) tensor(-0.0601) torch.Size([768]) tensor(0.0642) torch.Size([768]) tensor(-1.2854) torch.Size([3072, 768]) tensor(0.6916) torch.Size([3072]) tensor(0.0861) torch.Size([768, 3072]) tensor(-5.8404) torch.Size([768]) tensor(-0.0209) torch.Size([768, 12, 64]) tensor(-0.2547) torch.Size([768, 12, 64]) tensor(0.0448) torch.Size([768, 12, 64]) tensor(-11.1511) torch.Size([768, 12, 64]) tensor(-0.9826) torch.Size([768, 12, 64]) tensor(-4.3894) torch.Size([12, 64]) tensor(-0.0002) torch.Size([12, 64]) tensor(0.0011) torch.Size([12, 64]) tensor(0.0086) torch.Size([2, 12, 64]) tensor(1.4228e-08) torch.Size([768]) tensor(0.3623) torch.Size([768]) tensor(0.4062) torch.Size([768]) tensor(-0.1065) torch.Size([768]) tensor(-3.0040) torch.Size([3072, 768]) tensor(5.2218) torch.Size([3072]) tensor(-0.2266) torch.Size([768, 3072]) tensor(9.2306) torch.Size([768]) tensor(0.0284) torch.Size([768, 12, 64]) tensor(-0.4832) torch.Size([768, 12, 64]) tensor(0.0096) torch.Size([768, 12, 64]) tensor(-10.0414) torch.Size([768, 12, 64]) tensor(0.5542) torch.Size([768, 12, 64]) tensor(-1.1537) torch.Size([12, 64]) tensor(0.0121) torch.Size([12, 64]) tensor(-0.0002) torch.Size([12, 64]) tensor(0.0021) torch.Size([2, 12, 64]) tensor(1.1473e-08) torch.Size([768]) tensor(0.1046) torch.Size([768]) tensor(-0.5176) torch.Size([768]) tensor(-0.1872) torch.Size([768]) tensor(-0.4094) torch.Size([3072, 768]) tensor(-2.1765) torch.Size([3072]) tensor(0.2907) torch.Size([768, 3072]) tensor(-17.4914) torch.Size([768]) tensor(-0.2188) torch.Size([768, 12, 64]) tensor(0.1278) torch.Size([768, 12, 64]) tensor(0.0075) torch.Size([768, 12, 64]) tensor(4.4422) torch.Size([768, 12, 64]) tensor(0.1824) torch.Size([768, 12, 64]) tensor(-0.2453) torch.Size([12, 64]) tensor(-0.0021) torch.Size([12, 64]) tensor(0.0010) torch.Size([12, 64]) tensor(-0.0038) torch.Size([2, 12, 64]) tensor(3.8835e-08) torch.Size([768]) 
tensor(0.0947) torch.Size([768]) tensor(-0.1559) torch.Size([768]) tensor(-0.4608) torch.Size([768]) tensor(-3.1537) torch.Size([3072, 768]) tensor(-0.0686) torch.Size([3072]) tensor(-0.1366) torch.Size([768, 3072]) tensor(-23.9523) torch.Size([768]) tensor(-0.0108) torch.Size([768, 12, 64]) tensor(-1.3291) torch.Size([768, 12, 64]) tensor(0.0523) torch.Size([768, 12, 64]) tensor(-1.8151) torch.Size([768, 12, 64]) tensor(-0.4039) torch.Size([768, 12, 64]) tensor(-0.0539) torch.Size([12, 64]) tensor(0.0087) torch.Size([12, 64]) tensor(-0.0025) torch.Size([12, 64]) tensor(0.0149) torch.Size([2, 12, 64]) tensor(-7.3160e-09) torch.Size([768]) tensor(-0.4885) torch.Size([768]) tensor(1.7355) torch.Size([768]) tensor(-0.3012) torch.Size([768]) tensor(1.5989) torch.Size([3072, 768]) tensor(10.9814) torch.Size([3072]) tensor(-1.0466) torch.Size([768, 3072]) tensor(7.2909) torch.Size([768]) tensor(0.0650) torch.Size([768, 12, 64]) tensor(4.3189) torch.Size([768, 12, 64]) tensor(-0.3560) torch.Size([768, 12, 64]) tensor(36.1090) torch.Size([768, 12, 64]) tensor(-2.3517) torch.Size([768, 12, 64]) tensor(-2.1882) torch.Size([12, 64]) tensor(0.0060) torch.Size([12, 64]) tensor(-0.0008) torch.Size([12, 64]) tensor(-0.0341) torch.Size([2, 12, 64]) tensor(4.8329e-08) torch.Size([768]) tensor(11.1667) torch.Size([768]) tensor(17.8769) torch.Size([768]) tensor(0.2330) torch.Size([768]) tensor(-0.1388) torch.Size([3072, 768]) tensor(20.5294) torch.Size([3072]) tensor(-1.7453) torch.Size([768, 3072]) tensor(-39.9618) torch.Size([768]) tensor(0.1220) torch.Size([1, 768]) tensor(-10.1781) torch.Size([1]) tensor(9.1270e-08) torch.Size([768, 1536]) tensor(64.5171) torch.Size([768]) tensor(-0.1971) torch.Size([768]) tensor(0.1439) torch.Size([768]) tensor(-0.2693) torch.Size([1, 768]) tensor(0.0336) torch.Size([1]) tensor(0.5000)
Forward pass at inference time, with Beam Search decoding
In [15]:
import numpy as np

def decode(start_probs, end_probs, topk):
    """
    Given the start and end probabilities predicted within the beam,
    search for the topk best answers.
    """
    top_k_start = start_probs.shape[-1]
    top_k_end = end_probs.shape[-1] // top_k_start

    # score each (start, end) pair:
    # P(start, end | sentence) = P(start | sentence) * P(end | start, sentence)
    joint_probs = dict()
    for i in range(top_k_start):
        for j in range(top_k_end):
            end_idx = i*top_k_end + j
            joint_probs[(i, end_idx)] = start_probs[i] * end_probs[end_idx]

    id_pairs, probs = zip(*sorted(joint_probs.items(), key=lambda kv: kv[1], reverse=True)[:topk])
    start_ids, end_ids = zip(*id_pairs)
    return start_ids, end_ids, probs
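A quick toy check of decode with made-up probabilities; with top_k_start = top_k_end = 2, end_probs holds the two end candidates of each start candidate back to back:

start_probs = np.array([0.7, 0.3])          # probabilities of the two start candidates
end_probs = np.array([0.6, 0.4, 0.9, 0.1])  # two end candidates per start candidate
start_ids, end_ids, probs = decode(start_probs, end_probs, topk=2)
print(start_ids, end_ids, probs)
# expected roughly: (0, 0) (0, 1) (0.42, 0.28) — 0.7*0.6 and 0.7*0.4 beat 0.3*0.9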
In [16]:
# inference
# note: the QA head is randomly initialized and has seen only three toy gradient steps,
# so the decoded answers below are expected to be essentially random
context = r"""
Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.
"""
questions = [
    "How many pretrained models are available in Transformers?",
    "What does Transformers provide?",
    "Transformers provides interoperability between which frameworks?",
]

q_answer.eval()
for ith, question in enumerate(questions):
    inputs = tokenizer(question, context, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]
    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    start_probs, start_index, end_probs, end_index, stt_logits, end_logits = q_answer(
        inputs, p_mask=torch.ByteTensor(p_mask[ith])
    )
    pred_starts, pred_ends, probs = decode(
        start_probs.detach().squeeze().numpy(),
        end_probs.detach().squeeze().numpy(), 2)

    # only print the top answer
    start = start_index[:, pred_starts[0]].item()
    end = end_index[:, pred_ends[0]].item()
    # print(probs, pred_starts, pred_ends)
    # print(len(input_ids), stt_logits.shape, end_logits.shape)
    # print(tokenizer.convert_ids_to_tokens(input_ids).index('?'))
    print("="*25)
    print("True start: {}, True end: {}".format(
        start_positions[ith].item(), end_positions[ith].item()))
    print("Max answer prob: {:0.8f}, start idx: {}, end idx: {}".format(
        probs[0], start, end))
    print("-"*25)
    print("Question: '{}'".format(question))
    print("Answer: '{}'".format(tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[start:end]))))
    print("="*25)
Out [16]:
=========================
True start: 95, True end: 97
Max answer prob: 0.00008122, start idx: 120, end idx: 122
-------------------------
Question: 'How many pretrained models are available in Transformers?'
Answer: 'orch'
=========================
=========================
True start: 36, True end: 88
Max answer prob: 0.00008121, start idx: 115, end idx: 117
-------------------------
Question: 'What does Transformers provide?'
Answer: 'orch'
=========================
=========================
True start: 110, True end: 123
Max answer prob: 0.00008122, start idx: 120, end idx: 122
-------------------------
Question: 'Transformers provides interoperability between which frameworks?'
Answer: 'orch'
=========================