Commit 89dd729d by TeacherZhu

Upload New File

parent 6465a4be
{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.8"},"colab":{"name":"Bert-BiLSTM-CRF_tutorial.ipynb","provenance":[{"file_id":"https://github.com/pytorch/tutorials/blob/gh-pages/_downloads/b3265db81c2bf86cc3e2b0dcdaddc850/advanced_tutorial.ipynb","timestamp":1597973227736}],"collapsed_sections":[]}},"cells":[{"cell_type":"code","metadata":{"id":"Meh59-bZ93fZ","colab_type":"code","colab":{}},"source":["%matplotlib inline"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"B9bKjdsL93fc","colab_type":"text"},"source":["\n","# Bert-BiLSTM-CRF 命名实体识别\n","\n","by Qiao for NLP7 2020-08-21\n","\n","1) 首先请回顾BiLSTM-CRF的review,Bert只作为encoder,替换BiLSTM原本的embedding,下游任务并无变化。\n","\n","2) notebook续用Pytorch官方[BiLSTM-CRF](https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html)的教程。在此基础上加上bert相关处理。\n","\n","### 核心:\n","- 前向算法(forward)\n","- Viterbi算法\n","\n","======================================================\n","### 要点\n","\n","令 $y$ 为标注序列,$x$为token序列, 模型计算的是条件概率:\n","\n","\\begin{align}P(y|x) = \\frac{\\exp{(\\text{Score}(x, y)})}{\\sum_{y'} \\exp{(\\text{Score}(x, y')})}\\end{align}\n","\n","得分函数可定义为句子各个位置的发射分数$f(y_t|x)$(特征)以及先后位置之间的转移分数$g(y_t|x, y_{t-1})$ 之和:\n","\n","\\begin{align}\\text{Score}(x,y) = \\sum_{t=1}^{T} f(y_t|x) + \\sum_{t=2}^{T} g(y_t|x, y_{t-1})\\end{align}\n","\n","请回顾,在Bi-LSTM CRF中, $f(y_t|x)$可由第t个token的隐状态来表示。$g(y_t|x,y_{t-1})$ 由参数矩阵$\\mathbf{P} \\in R^{K \\times K}$中的$P_{t,t-1}$表示, $K$ 是标签集合的元素个数. 在代码中$P_{ij}$表示的是由标签$t_j$ 转移到标签$t_i$,\n","\n","实际我们优化的是$\\log(P(y|x))$, (或最小化Negative Probability):\n","\\begin{align}\n","\\log(P(y|x)) = & \\text{ Score}(x, y) - \\log \\bigg(\\sum_{y'} \\exp \\big (\\text{Score}(x, y') \\big) \\bigg) \\\\ \n","= &\\sum_{t=1}^{T} f(y_t|x) + \\sum_{t=2}^{T} g(y_t|x, y_{t-1}) - \\\\ \n","&- \\log \\bigg ( {\\sum_{y'} \\bigg \\{ \\exp \\big( \\sum_{t=1}^{T} f(y'_t|x) + \\sum_{t=2}^{T} g(y'_t|x, y'_{t-1}) \\big ) \\bigg \\}} \\bigg ) \\\\\n","\\end{align}\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x7J_JGYk93fc","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1598018058243,"user_tz":-480,"elapsed":5092,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"cd202ba3-4efa-45cc-e1d8-9cb4a17b6e05"},"source":["import torch\n","import torch.autograd as autograd\n","import torch.nn as nn\n","import torch.optim as optim\n","!pip install -q transformers\n","torch.manual_seed(1)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["<torch._C.Generator at 0x7f294aa74120>"]},"metadata":{"tags":[]},"execution_count":2}]},{"cell_type":"code","metadata":{"id":"sNdGSyJCrx7V","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":101},"executionInfo":{"status":"ok","timestamp":1598018065368,"user_tz":-480,"elapsed":12175,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"71027e3e-7f05-40ed-c59d-5f7cb1fd5368"},"source":["def show_bert_doctrine():\n"," tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)\n"," bert = BertModel.from_pretrained(BERT_MODEL_NAME)\n"," for k, v in tokenizer(\"I am a boy\", return_tensors=\"pt\").items():\n"," print(k, v)\n"," if k == \"input_ids\":\n"," print(tokenizer.convert_ids_to_tokens(v.squeeze()))\n"," h = bert(**tokenizer(\"I am a boy\", return_tensors=\"pt\"))[0]\n"," print(h.shape)\n","show_bert_doctrine()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["input_ids tensor([[ 101, 146, 1821, 170, 2298, 102]])\n","['[CLS]', 'I', 'am', 'a', 'boy', '[SEP]']\n","token_type_ids tensor([[0, 0, 0, 0, 0, 0]])\n","attention_mask tensor([[1, 1, 1, 1, 1, 1]])\n","torch.Size([1, 6, 768])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"sdTOuNXPmebL","colab_type":"code","colab":{}},"source":["from transformers import BertTokenizer, BertModel, BertConfig\n","BERT_MODEL_NAME = \"bert-base-cased\"\n","\n","class BertEmbedding(nn.Module):\n","\n"," def __init__(self):\n"," super(BertEmbedding, self).__init__()\n"," self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)\n"," \n"," def fix_params(self):\n"," for param in self.bert.parameters():\n"," param.requires_grad = False\n"," \n"," def free_params(self):\n"," for param in self.bert.parameters():\n"," param.requires_grad = True\n","\n"," def forward(self, inputs):\n"," return self.bert(**inputs)[0][:,1:-1,:]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4o4RS7Zk93fe","colab_type":"text"},"source":["Helper functions to make the code more readable.\n","\n"]},{"cell_type":"code","metadata":{"id":"Uj_Idqzj93fe","colab_type":"code","colab":{}},"source":["def argmax(vec):\n"," # return the argmax as a python int\n"," _, idx = torch.max(vec, 1)\n"," return idx.item()\n","\n","def prepare_sequence(seq, tags, tokenizer, tag_to_ix):\n"," tags = tags.split()\n"," targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)\n","\n"," # prepare inputs for bert model and find start tokens (for word piece tokens)\n"," input_ids = tokenizer(seq, return_tensors=\"pt\")\n"," word_pieces = tokenizer.convert_ids_to_tokens(input_ids['input_ids'].squeeze())[1:-1]\n"," token_starts = torch.LongTensor([i for i, wp in enumerate(word_pieces) if not wp.startswith(\"##\")])\n"," return input_ids, targets, token_starts\n","\n","# Compute log sum exp in a numerically stable way for the forward algorithm\n","def log_sum_exp(vec):\n"," max_score = vec[0, argmax(vec)]\n"," max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])\n"," return max_score + \\\n"," torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ogWDAvsb93fg","colab_type":"text"},"source":["Create model\n","\n"]},{"cell_type":"code","metadata":{"id":"pdc-ZNri93fh","colab_type":"code","colab":{}},"source":["class BiLSTM_CRF(nn.Module):\n","\n"," def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=768):\n"," super(BiLSTM_CRF, self).__init__()\n"," self.embedding_dim = embedding_dim\n"," self.hidden_dim = hidden_dim\n"," self.tag_to_ix = tag_to_ix\n"," self.tagset_size = len(tag_to_ix)\n","\n"," self.word_embeds = BertEmbedding()\n"," self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,\n"," num_layers=1, bidirectional=True)\n","\n"," # Maps the output of the LSTM into tag space.\n"," self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)\n","\n"," # Matrix of transition parameters. Entry i,j is the score of\n"," # transitioning *to* i *from* j.\n"," self.transitions = nn.Parameter(\n"," torch.randn(self.tagset_size, self.tagset_size))\n","\n"," # These two statements enforce the constraint that we never transfer\n"," # to the start tag and we never transfer from the stop tag\n"," self.transitions.data[tag_to_ix[START_TAG], :] = -10000\n"," self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000\n","\n"," self.hidden = self.init_hidden()\n","\n"," def fix_bert(self):\n"," self.word_embeds.fix_params()\n","\n"," def free_bert(self):\n"," self.word_embeds.free_params()\n","\n"," def init_hidden(self):\n"," return (torch.randn(2, 1, self.hidden_dim // 2),\n"," torch.randn(2, 1, self.hidden_dim // 2))\n","\n"," def _forward_alg(self, feats):\n"," # Do the forward algorithm to compute the partition function\n"," init_alphas = torch.full((1, self.tagset_size), -10000.)\n"," # START_TAG has all of the score.\n"," init_alphas[0][self.tag_to_ix[START_TAG]] = 0.\n","\n"," # Wrap in a variable so that we will get automatic backprop\n"," forward_var = init_alphas\n","\n"," # Iterate through the sentence\n"," for feat in feats:\n"," # alphas_t = [] # The forward tensors at this timestep\n"," # for next_tag in range(self.tagset_size):\n"," # # broadcast the emission score: it is the same regardless of\n"," # # the previous tag\n"," # emit_score = feat[next_tag].view(\n"," # 1, -1).expand(1, self.tagset_size)\n"," # # the ith entry of trans_score is the score of transitioning to\n"," # # next_tag from i\n"," # trans_score = self.transitions[next_tag].view(1, -1)\n"," # # The ith entry of next_tag_var is the value for the\n"," # # edge (i -> next_tag) before we do log-sum-exp\n"," # next_tag_var = forward_var + trans_score + emit_score\n"," # # The forward variable for this tag is log-sum-exp of all the\n"," # # scores.\n"," # alphas_t.append(log_sum_exp(next_tag_var).view(1))\n"," # forward_var = torch.cat(alphas_t).view(1, -1)\n","\n"," forward_var = torch.logsumexp(feat.expand(self.tagset_size, -1) + \n"," self.transitions.T + forward_var.view(-1, 1), \n"," dim=0, \n"," keepdim=True)\n"," \n"," \n"," terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n"," alpha = log_sum_exp(terminal_var)\n"," return alpha\n","\n"," def _get_lstm_features(self, embeds):\n"," self.hidden = self.init_hidden()\n"," embeds = embeds.view(embeds.shape[1], embeds.shape[0], -1)\n"," lstm_out, self.hidden = self.lstm(embeds, self.hidden)\n"," lstm_out = lstm_out.view(lstm_out.shape[1], lstm_out.shape[0], self.hidden_dim)\n"," lstm_feats = self.hidden2tag(lstm_out)\n"," return lstm_feats\n","\n"," def _score_sentence(self, feats, tags):\n"," # Gives the score of a provided tag sequence\n"," score = torch.zeros(1)\n"," tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])\n"," for i, feat in enumerate(feats):\n"," score = score + \\\n"," self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]\n"," score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]\n"," return score\n","\n"," def _viterbi_decode(self, feats):\n"," backpointers = []\n","\n"," # Initialize the viterbi variables in log space\n"," init_vvars = torch.full((1, self.tagset_size), -10000.)\n"," init_vvars[0][self.tag_to_ix[START_TAG]] = 0\n","\n"," # forward_var at step i holds the viterbi variables for step i-1\n"," forward_var = init_vvars\n"," for feat in feats:\n"," # bptrs_t = [] # holds the backpointers for this step\n"," # viterbivars_t = [] # holds the viterbi variables for this step\n","\n"," # for next_tag in range(self.tagset_size):\n"," # # next_tag_var[i] holds the viterbi variable for tag i at the\n"," # # previous step, plus the score of transitioning\n"," # # from tag i to next_tag.\n"," # # We don't include the emission scores here because the max\n"," # # does not depend on them (we add them in below)\n"," # next_tag_var = forward_var + self.transitions[next_tag]\n"," # # print(next_tag_var)\n"," # best_tag_id = argmax(next_tag_var)\n"," # bptrs_t.append(best_tag_id)\n"," # viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))\n","\n"," # # Now add in the emission scores, and assign forward_var to the set\n"," # # of viterbi variables we just computed\n"," # forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)\n"," # backpointers.append(bptrs_t)\n","\n"," scores = self.transitions + forward_var\n"," forward_var, bptrs = torch.max(scores, dim=1)\n"," forward_var = forward_var.view(1, -1) + feat.view(1, -1)\n"," backpointers.append(bptrs.cpu().numpy().tolist())\n","\n"," # Transition to STOP_TAG\n"," terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n"," best_tag_id = argmax(terminal_var)\n"," path_score = terminal_var[0][best_tag_id]\n","\n"," # Follow the back pointers to decode the best path.\n"," best_path = [best_tag_id]\n"," for bptrs_t in reversed(backpointers):\n"," best_tag_id = bptrs_t[best_tag_id]\n"," best_path.append(best_tag_id)\n"," # Pop off the start tag (we dont want to return that to the caller)\n"," start = best_path.pop()\n"," assert start == self.tag_to_ix[START_TAG] # Sanity check\n"," best_path.reverse()\n"," return path_score, best_path\n","\n"," def neg_log_likelihood(self, input_ids, tags, token_starts):\n"," embeds = self.word_embeds(input_ids)[:, token_starts]\n"," feats = self._get_lstm_features(embeds).squeeze() \n"," forward_score = self._forward_alg(feats)\n"," gold_score = self._score_sentence(feats, tags)\n"," return forward_score - gold_score\n"," \n","\n"," def forward(self, input_ids, token_starts): # dont confuse this with _forward_alg above.\n"," embeds = self.word_embeds(input_ids)[:, token_starts]\n"," # Get the emission scores from the BiLSTM\n"," lstm_feats = self._get_lstm_features(embeds).squeeze()\n","\n"," # Find the best path, given the features.\n"," score, tag_seq = self._viterbi_decode(lstm_feats)\n","\n"," return score, tag_seq"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"AXWIi5ELcdwW","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"M89cYK8P93fi","colab_type":"text"},"source":["Run training\n","\n"]},{"cell_type":"code","metadata":{"id":"rXq14dpY93fj","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"status":"ok","timestamp":1598018477035,"user_tz":-480,"elapsed":423773,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"b9d66845-5b30-4275-89b1-4f25a25d40c8"},"source":["START_TAG = \"<START>\"\n","STOP_TAG = \"<STOP>\"\n","EMBEDDING_DIM = 768\n","HIDDEN_DIM = 768\n","\n","# Make up some training data\n","training_data = [(\n"," \"the wall street journal reported today that apple corporation made money\",\n"," \"B I I I O O O B I O O\"\n","), \n","\n","(\n"," \"georgia tech is a university in georgia\",\n"," \"B I O O O O B\"\n",")\n","\n","]\n","\n","tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)\n","\n","tag_to_ix = {\"B\": 0, \"I\": 1, \"O\": 2, START_TAG: 3, STOP_TAG: 4}\n","\n","model = BiLSTM_CRF(tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)\n","# model.fix_bert()\n","# model.free_bert()\n","optimizer = optim.SGD([param for param in model.parameters() if param.requires_grad], lr=0.01, weight_decay=1e-4)\n","\n","# Check predictions before training\n","with torch.no_grad():\n"," precheck_sent, targets, token_starts = prepare_sequence(training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)\n"," print(precheck_sent)\n"," print(targets)\n"," print(token_starts)\n"," print(model(precheck_sent, token_starts))\n","\n","\n","# Make sure prepare_sequence from earlier in the LSTM section is loaded\n","for epoch in range(\n"," 300): # again, normally you would NOT do 300 epochs, it is toy data\n"," for sentence, tags in training_data:\n"," # Step 1. Remember that Pytorch accumulates gradients.\n"," # We need to clear them out before each instance\n"," model.zero_grad()\n","\n"," # Step 2. Get our inputs ready for the network, that is,\n"," # turn them into Tensors of word indices.\n"," sentence_in, targets, token_starts = prepare_sequence(sentence, tags, tokenizer, tag_to_ix)\n"," # Step 3. Run our forward pass.\n"," loss = model.neg_log_likelihood(sentence_in, targets, token_starts)\n"," print(loss.item())\n"," # Step 4. Compute the loss, gradients, and update the parameters by\n"," # calling optimizer.step()\n"," loss.backward()\n"," optimizer.step()\n","\n","# Check predictions after training\n","with torch.no_grad():\n"," precheck_sent, targets, token_starts = prepare_sequence(training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)\n"," print(model(precheck_sent, token_starts))\n","# We got it!"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{'input_ids': tensor([[ 101, 1103, 2095, 2472, 4897, 2103, 2052, 1115, 12075, 9715,\n"," 1189, 1948, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n","tensor([0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])\n","tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n","(tensor(13.2155), [0, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2])\n","18.524065017700195\n","9.69771957397461\n","14.04741096496582\n","8.564186096191406\n","12.153670310974121\n","6.623537540435791\n","10.821130752563477\n","4.869028091430664\n","6.799995422363281\n","3.2920217514038086\n","4.022052764892578\n","2.0438756942749023\n","2.6931686401367188\n","1.468423843383789\n","1.6920318603515625\n","1.061056137084961\n","1.3279285430908203\n","0.7923049926757812\n","0.9889354705810547\n","0.6197624206542969\n","0.8111076354980469\n","0.49001502990722656\n","0.5839614868164062\n","0.4056587219238281\n","0.4795722961425781\n","0.3584175109863281\n","0.40242767333984375\n","0.28954315185546875\n","0.35292816162109375\n","0.26021385192871094\n","0.31992340087890625\n","0.23885726928710938\n","0.28530120849609375\n","0.20218849182128906\n","0.2623939514160156\n","0.18845367431640625\n","0.23521804809570312\n","0.18128585815429688\n","0.2100372314453125\n","0.166534423828125\n","0.18750381469726562\n","0.1447906494140625\n","0.18436050415039062\n","0.14190673828125\n","0.1678924560546875\n","0.12019920349121094\n","0.15915298461914062\n","0.12767410278320312\n","0.156585693359375\n","0.11075401306152344\n","0.13356399536132812\n","0.112640380859375\n","0.13848114013671875\n","0.09398269653320312\n","0.11859512329101562\n","0.11187934875488281\n","0.12015533447265625\n","0.097015380859375\n","0.1128692626953125\n","0.08238029479980469\n","0.11193466186523438\n","0.0859375\n","0.10143661499023438\n","0.07348251342773438\n","0.09615325927734375\n","0.08072662353515625\n","0.09915542602539062\n","0.07670974731445312\n","0.09319305419921875\n","0.07427406311035156\n","0.08670425415039062\n","0.07334327697753906\n","0.08296585083007812\n","0.058940887451171875\n","0.08138656616210938\n","0.07139396667480469\n","0.07628631591796875\n","0.057964324951171875\n","0.07128524780273438\n","0.06273269653320312\n","0.07092666625976562\n","0.0547943115234375\n","0.06726837158203125\n","0.05399322509765625\n","0.07040786743164062\n","0.04906463623046875\n","0.06667709350585938\n","0.059417724609375\n","0.06259536743164062\n","0.04938507080078125\n","0.0604248046875\n","0.04944610595703125\n","0.0582275390625\n","0.045818328857421875\n","0.056118011474609375\n","0.05054473876953125\n","0.058383941650390625\n","0.046283721923828125\n","0.05562591552734375\n","0.041900634765625\n","0.05492401123046875\n","0.04074859619140625\n","0.056545257568359375\n","0.039409637451171875\n","0.04883575439453125\n","0.043308258056640625\n","0.048919677734375\n","0.04802703857421875\n","0.048126220703125\n","0.0400238037109375\n","0.04732513427734375\n","0.03691864013671875\n","0.045360565185546875\n","0.034305572509765625\n","0.04703521728515625\n","0.036441802978515625\n","0.04430389404296875\n","0.033931732177734375\n","0.04705810546875\n","0.0346527099609375\n","0.04831695556640625\n","0.032558441162109375\n","0.04573822021484375\n","0.03507232666015625\n","0.041988372802734375\n","0.031864166259765625\n","0.044342041015625\n","0.034694671630859375\n","0.042476654052734375\n","0.03163909912109375\n","0.0418243408203125\n","0.030185699462890625\n","0.04074859619140625\n","0.031215667724609375\n","0.038516998291015625\n","0.02901458740234375\n","0.035953521728515625\n","0.030879974365234375\n","0.034549713134765625\n","0.03258514404296875\n","0.033969879150390625\n","0.029605865478515625\n","0.03438568115234375\n","0.028820037841796875\n","0.034946441650390625\n","0.025196075439453125\n","0.038326263427734375\n","0.027629852294921875\n","0.03270721435546875\n","0.027362823486328125\n","0.03369903564453125\n","0.0268707275390625\n","0.033721923828125\n","0.027248382568359375\n","0.0313262939453125\n","0.021869659423828125\n","0.031650543212890625\n","0.023372650146484375\n","0.03081512451171875\n","0.026371002197265625\n","0.032398223876953125\n","0.023090362548828125\n","0.02802276611328125\n","0.02558135986328125\n","0.030231475830078125\n","0.021945953369140625\n","0.02909088134765625\n","0.0247344970703125\n","0.0286102294921875\n","0.023036956787109375\n","0.028675079345703125\n","0.0205535888671875\n","0.027423858642578125\n","0.0208587646484375\n","0.026210784912109375\n","0.021144866943359375\n","0.02787017822265625\n","0.021575927734375\n","0.029216766357421875\n","0.020732879638671875\n","0.026523590087890625\n","0.024517059326171875\n","0.025310516357421875\n","0.020252227783203125\n","0.02447509765625\n","0.0202789306640625\n","0.025203704833984375\n","0.0197906494140625\n","0.0258941650390625\n","0.02239227294921875\n","0.023662567138671875\n","0.021244049072265625\n","0.024349212646484375\n","0.020130157470703125\n","0.023754119873046875\n","0.019435882568359375\n","0.0241546630859375\n","0.01947784423828125\n","0.023235321044921875\n","0.016143798828125\n","0.023433685302734375\n","0.01739501953125\n","0.02231597900390625\n","0.017894744873046875\n","0.022922515869140625\n","0.02083587646484375\n","0.02474212646484375\n","0.02251434326171875\n","0.021392822265625\n","0.01758575439453125\n","0.022747039794921875\n","0.015537261962890625\n","0.02065277099609375\n","0.017360687255859375\n","0.020465850830078125\n","0.017726898193359375\n","0.022495269775390625\n","0.014987945556640625\n","0.020709991455078125\n","0.01609039306640625\n","0.020999908447265625\n","0.016460418701171875\n","0.01959228515625\n","0.01551055908203125\n","0.02266693115234375\n","0.016361236572265625\n","0.02111053466796875\n","0.01638031005859375\n","0.02060699462890625\n","0.01663970947265625\n","0.020244598388671875\n","0.015842437744140625\n","0.020843505859375\n","0.014667510986328125\n","0.01929473876953125\n","0.015880584716796875\n","0.02040863037109375\n","0.015750885009765625\n","0.018482208251953125\n","0.0142059326171875\n","0.01972198486328125\n","0.013668060302734375\n","0.0186920166015625\n","0.015918731689453125\n","0.018817901611328125\n","0.0158233642578125\n","0.0180816650390625\n","0.01428985595703125\n","0.01825714111328125\n","0.0151824951171875\n","0.019374847412109375\n","0.014495849609375\n","0.01856231689453125\n","0.014423370361328125\n","0.017303466796875\n","0.013736724853515625\n","0.018497467041015625\n","0.015117645263671875\n","0.017253875732421875\n","0.014644622802734375\n","0.016506195068359375\n","0.0121612548828125\n","0.0164642333984375\n","0.014312744140625\n","0.016620635986328125\n","0.014163970947265625\n","0.0162811279296875\n","0.01226806640625\n","0.01753997802734375\n","0.011386871337890625\n","0.01616668701171875\n","0.011974334716796875\n","0.016300201416015625\n","0.013896942138671875\n","0.017597198486328125\n","0.01239013671875\n","0.0173187255859375\n","0.01284027099609375\n","0.015384674072265625\n","0.01336669921875\n","0.01540374755859375\n","0.011402130126953125\n","0.015346527099609375\n","0.013134002685546875\n","0.015522003173828125\n","0.013195037841796875\n","0.01406097412109375\n","0.011493682861328125\n","0.014446258544921875\n","0.010650634765625\n","0.014507293701171875\n","0.011463165283203125\n","0.014896392822265625\n","0.01290130615234375\n","0.014240264892578125\n","0.01092529296875\n","0.015430450439453125\n","0.0108642578125\n","0.01447296142578125\n","0.012973785400390625\n","0.01438140869140625\n","0.01300048828125\n","0.013576507568359375\n","0.011295318603515625\n","0.01415252685546875\n","0.0108642578125\n","0.01309967041015625\n","0.01047515869140625\n","0.01422119140625\n","0.010467529296875\n","0.0141448974609375\n","0.01103973388671875\n","0.014190673828125\n","0.0120849609375\n","0.0131988525390625\n","0.010616302490234375\n","0.013736724853515625\n","0.010257720947265625\n","0.013278961181640625\n","0.010540008544921875\n","0.0131683349609375\n","0.01039886474609375\n","0.013446807861328125\n","0.011157989501953125\n","0.01340484619140625\n","0.0098114013671875\n","0.013446807861328125\n","0.009975433349609375\n","0.012508392333984375\n","0.01080322265625\n","0.014354705810546875\n","0.009792327880859375\n","0.012340545654296875\n","0.0105438232421875\n","0.013671875\n","0.0110626220703125\n","0.012943267822265625\n","0.010120391845703125\n","0.013378143310546875\n","0.00951385498046875\n","0.0131072998046875\n","0.0108489990234375\n","0.0121612548828125\n","0.009967803955078125\n","0.01238250732421875\n","0.010311126708984375\n","0.0124664306640625\n","0.010467529296875\n","0.0120391845703125\n","0.0117034912109375\n","0.0121307373046875\n","0.009685516357421875\n","0.012298583984375\n","0.0098724365234375\n","0.0128173828125\n","0.009761810302734375\n","0.01177978515625\n","0.01038360595703125\n","0.011627197265625\n","0.00908660888671875\n","0.0117034912109375\n","0.0099334716796875\n","0.012271881103515625\n","0.009540557861328125\n","0.011627197265625\n","0.010009765625\n","0.0126190185546875\n","0.01012420654296875\n","0.01081085205078125\n","0.007740020751953125\n","0.012298583984375\n","0.008716583251953125\n","0.01180267333984375\n","0.00844573974609375\n","0.0113372802734375\n","0.008884429931640625\n","0.011505126953125\n","0.00984954833984375\n","0.01119232177734375\n","0.00861358642578125\n","0.01031494140625\n","0.008487701416015625\n","0.01244354248046875\n","0.0093231201171875\n","0.01107025146484375\n","0.008533477783203125\n","0.01091766357421875\n","0.008617401123046875\n","0.010406494140625\n","0.0083465576171875\n","0.01100921630859375\n","0.009456634521484375\n","0.01100921630859375\n","0.008056640625\n","0.01030731201171875\n","0.007488250732421875\n","0.010772705078125\n","0.007480621337890625\n","0.01053619384765625\n","0.008472442626953125\n","0.010589599609375\n","0.0088958740234375\n","0.01128387451171875\n","0.009918212890625\n","0.01088714599609375\n","0.008514404296875\n","0.01080322265625\n","0.0071258544921875\n","0.010009765625\n","0.008731842041015625\n","0.0098724365234375\n","0.0072479248046875\n","0.0101318359375\n","0.008106231689453125\n","0.01024627685546875\n","0.00862884521484375\n","0.00954437255859375\n","0.0081939697265625\n","0.00975799560546875\n","0.008026123046875\n","0.009552001953125\n","0.00832366943359375\n","0.0098419189453125\n","0.0070343017578125\n","0.0099334716796875\n","0.0074920654296875\n","0.00946044921875\n","0.00853729248046875\n","0.00921630859375\n","0.0077972412109375\n","0.00951385498046875\n","0.007396697998046875\n","0.0100250244140625\n","0.009342193603515625\n","0.00954437255859375\n","0.0079803466796875\n","0.008880615234375\n","0.008434295654296875\n","0.0090484619140625\n","0.007110595703125\n","0.00968170166015625\n","0.0068359375\n","0.009857177734375\n","0.007167816162109375\n","0.00902557373046875\n","0.006305694580078125\n","0.0089874267578125\n","0.006649017333984375\n","0.00882720947265625\n","0.008533477783203125\n","0.0087127685546875\n","0.007343292236328125\n","0.008880615234375\n","0.008068084716796875\n","0.00888824462890625\n","0.007770538330078125\n","0.009124755859375\n","0.00626373291015625\n","0.00911712646484375\n","0.00612640380859375\n","0.0088958740234375\n","0.00647735595703125\n","0.0089111328125\n","0.007434844970703125\n","0.0087890625\n","0.007080078125\n","0.00887298583984375\n","0.00897216796875\n","0.00885009765625\n","0.008426666259765625\n","0.00882720947265625\n","0.00696563720703125\n","0.0092010498046875\n","0.007831573486328125\n","0.0085296630859375\n","0.010101318359375\n","0.00839996337890625\n","0.00801849365234375\n","0.0079345703125\n","0.006923675537109375\n","0.0084991455078125\n","0.006603240966796875\n","0.00806427001953125\n","0.006256103515625\n","0.008270263671875\n","0.005626678466796875\n","0.00952911376953125\n","0.006160736083984375\n","0.0076446533203125\n","0.006496429443359375\n","0.00849151611328125\n","0.006252288818359375\n","0.0078582763671875\n","0.008190155029296875\n","0.00811004638671875\n","0.006290435791015625\n","0.00807952880859375\n","0.0070343017578125\n","0.007781982421875\n","0.00659942626953125\n","0.007843017578125\n","0.00652313232421875\n","0.00832366943359375\n","0.00632476806640625\n","0.00812530517578125\n","0.0057373046875\n","0.0083160400390625\n","0.006359100341796875\n","0.00809478759765625\n","0.006908416748046875\n","0.00762939453125\n","0.006832122802734375\n","0.0077972412109375\n","0.006381988525390625\n","0.00762939453125\n","0.00646209716796875\n","0.00800323486328125\n","0.0063323974609375\n","0.0072479248046875\n","0.006053924560546875\n","0.00753021240234375\n","0.0057525634765625\n","0.0079193115234375\n","0.006107330322265625\n","0.00705718994140625\n","0.006580352783203125\n","0.00762176513671875\n","0.005779266357421875\n","0.00726318359375\n","0.006378173828125\n","0.007659912109375\n","0.0057525634765625\n","0.00732421875\n","0.00592041015625\n","0.00737762451171875\n","0.00604248046875\n","0.00757598876953125\n","0.00582122802734375\n","0.00702667236328125\n","0.00572967529296875\n","0.00705718994140625\n","0.005855560302734375\n","0.00701904296875\n","0.007171630859375\n","0.0075225830078125\n","0.00605010986328125\n","0.00777435302734375\n","0.006053924560546875\n","0.0075225830078125\n","0.006076812744140625\n","0.00733184814453125\n","0.005329132080078125\n","0.00736236572265625\n","0.00611114501953125\n","0.0070953369140625\n","0.005405426025390625\n","0.00687408447265625\n","0.005764007568359375\n","0.00737762451171875\n","0.005706787109375\n","0.00720977783203125\n","0.005462646484375\n","0.00725555419921875\n","0.00591278076171875\n","0.0068359375\n","0.00504302978515625\n","0.00681304931640625\n","0.00567626953125\n","0.00745391845703125\n","0.00548553466796875\n","0.0076904296875\n","0.006725311279296875\n","0.0070037841796875\n","0.006229400634765625\n","0.0073699951171875\n","0.005184173583984375\n","0.00695037841796875\n","0.004901885986328125\n","0.00705718994140625\n","0.005374908447265625\n","0.00798797607421875\n","0.005619049072265625\n","0.00702667236328125\n","0.00490570068359375\n","0.0074462890625\n","0.005405426025390625\n","0.0065155029296875\n","0.0051116943359375\n","0.0071868896484375\n","0.00540924072265625\n","0.0068511962890625\n","0.004940032958984375\n","0.00740814208984375\n","0.006267547607421875\n","0.0067901611328125\n","0.0050201416015625\n","0.006439208984375\n","0.00514984130859375\n","0.0063018798828125\n","0.00565338134765625\n","(tensor(69.1735), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])\n"],"name":"stdout"}]}]}
{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.8"},"colab":{"name":"Bert-BiLSTM-CRF_tutorial.ipynb","provenance":[{"file_id":"https://github.com/pytorch/tutorials/blob/gh-pages/_downloads/b3265db81c2bf86cc3e2b0dcdaddc850/advanced_tutorial.ipynb","timestamp":1597973227736}],"collapsed_sections":[]}},"cells":[{"cell_type":"code","metadata":{"id":"Meh59-bZ93fZ","colab_type":"code","colab":{}},"source":["%matplotlib inline"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"B9bKjdsL93fc","colab_type":"text"},"source":["\n","# Bert-BiLSTM-CRF 命名实体识别\n","\n","by Qiao for NLP7 2020-08-21\n","\n","1) 首先请回顾BiLSTM-CRF的review,Bert只作为encoder,替换BiLSTM原本的embedding,下游任务并无变化。\n","\n","2) notebook续用Pytorch官方[BiLSTM-CRF](https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html)的教程。在此基础上加上bert相关处理。\n","\n","### 核心:\n","- 前向算法(forward)\n","- Viterbi算法\n","\n","======================================================\n","### 要点\n","\n","令 $y$ 为标注序列,$x$为token序列, 模型计算的是条件概率:\n","\n","\\begin{align}P(y|x) = \\frac{\\exp{(\\text{Score}(x, y)})}{\\sum_{y'} \\exp{(\\text{Score}(x, y')})}\\end{align}\n","\n","得分函数可定义为句子各个位置的发射分数$f(y_t|x)$(特征)以及先后位置之间的转移分数$g(y_t|x, y_{t-1})$ 之和:\n","\n","\\begin{align}\\text{Score}(x,y) = \\sum_{t=1}^{T} f(y_t|x) + \\sum_{t=2}^{T} g(y_t|x, y_{t-1})\\end{align}\n","\n","请回顾,在Bi-LSTM CRF中, $f(y_t|x)$可由第t个token的隐状态来表示。$g(y_t|x,y_{t-1})$ 由参数矩阵$\\mathbf{P} \\in R^{K \\times K}$中的$P_{t,t-1}$表示, $K$ 是标签集合的元素个数. 在代码中$P_{ij}$表示的是由标签$t_j$ 转移到标签$t_i$,\n","\n","实际我们优化的是$\\log(P(y|x))$, (或最小化Negative Probability):\n","\\begin{align}\n","\\log(P(y|x)) = & \\text{ Score}(x, y) - \\log \\bigg(\\sum_{y'} \\exp \\big (\\text{Score}(x, y') \\big) \\bigg) \\\\ \n","= &\\sum_{t=1}^{T} f(y_t|x) + \\sum_{t=2}^{T} g(y_t|x, y_{t-1}) - \\\\ \n","&- \\log \\bigg ( {\\sum_{y'} \\bigg \\{ \\exp \\big( \\sum_{t=1}^{T} f(y'_t|x) + \\sum_{t=2}^{T} g(y'_t|x, y'_{t-1}) \\big ) \\bigg \\}} \\bigg ) \\\\\n","\\end{align}\n","\n","\n","\n","\n","\n","\n"]},{"cell_type":"code","metadata":{"id":"x7J_JGYk93fc","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":34},"executionInfo":{"status":"ok","timestamp":1598018058243,"user_tz":-480,"elapsed":5092,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"cd202ba3-4efa-45cc-e1d8-9cb4a17b6e05"},"source":["import torch\n","import torch.autograd as autograd\n","import torch.nn as nn\n","import torch.optim as optim\n","!pip install -q transformers\n","torch.manual_seed(1)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["<torch._C.Generator at 0x7f294aa74120>"]},"metadata":{"tags":[]},"execution_count":2}]},{"cell_type":"code","metadata":{"id":"sNdGSyJCrx7V","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":101},"executionInfo":{"status":"ok","timestamp":1598018065368,"user_tz":-480,"elapsed":12175,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"71027e3e-7f05-40ed-c59d-5f7cb1fd5368"},"source":["def show_bert_doctrine():\n"," tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)\n"," bert = BertModel.from_pretrained(BERT_MODEL_NAME)\n"," for k, v in tokenizer(\"I am a boy\", return_tensors=\"pt\").items():\n"," print(k, v)\n"," if k == \"input_ids\":\n"," print(tokenizer.convert_ids_to_tokens(v.squeeze()))\n"," h = bert(**tokenizer(\"I am a boy\", return_tensors=\"pt\"))[0]\n"," print(h.shape)\n","show_bert_doctrine()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["input_ids tensor([[ 101, 146, 1821, 170, 2298, 102]])\n","['[CLS]', 'I', 'am', 'a', 'boy', '[SEP]']\n","token_type_ids tensor([[0, 0, 0, 0, 0, 0]])\n","attention_mask tensor([[1, 1, 1, 1, 1, 1]])\n","torch.Size([1, 6, 768])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"sdTOuNXPmebL","colab_type":"code","colab":{}},"source":["from transformers import BertTokenizer, BertModel, BertConfig\n","BERT_MODEL_NAME = \"bert-base-cased\"\n","\n","class BertEmbedding(nn.Module):\n","\n"," def __init__(self):\n"," super(BertEmbedding, self).__init__()\n"," self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)\n"," \n"," def fix_params(self):\n"," for param in self.bert.parameters():\n"," param.requires_grad = False\n"," \n"," def free_params(self):\n"," for param in self.bert.parameters():\n"," param.requires_grad = True\n","\n"," def forward(self, inputs):\n"," return self.bert(**inputs)[0][:,1:-1,:]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4o4RS7Zk93fe","colab_type":"text"},"source":["Helper functions to make the code more readable.\n","\n"]},{"cell_type":"code","metadata":{"id":"Uj_Idqzj93fe","colab_type":"code","colab":{}},"source":["def argmax(vec):\n"," # return the argmax as a python int\n"," _, idx = torch.max(vec, 1)\n"," return idx.item()\n","\n","def prepare_sequence(seq, tags, tokenizer, tag_to_ix):\n"," tags = tags.split()\n"," targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)\n","\n"," # prepare inputs for bert model and find start tokens (for word piece tokens)\n"," input_ids = tokenizer(seq, return_tensors=\"pt\")\n"," word_pieces = tokenizer.convert_ids_to_tokens(input_ids['input_ids'].squeeze())[1:-1]\n"," token_starts = torch.LongTensor([i for i, wp in enumerate(word_pieces) if not wp.startswith(\"##\")])\n"," return input_ids, targets, token_starts\n","\n","# Compute log sum exp in a numerically stable way for the forward algorithm\n","def log_sum_exp(vec):\n"," max_score = vec[0, argmax(vec)]\n"," max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])\n"," return max_score + \\\n"," torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ogWDAvsb93fg","colab_type":"text"},"source":["Create model\n","\n"]},{"cell_type":"code","metadata":{"id":"pdc-ZNri93fh","colab_type":"code","colab":{}},"source":["class BiLSTM_CRF(nn.Module):\n","\n"," def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=768):\n"," super(BiLSTM_CRF, self).__init__()\n"," self.embedding_dim = embedding_dim\n"," self.hidden_dim = hidden_dim\n"," self.tag_to_ix = tag_to_ix\n"," self.tagset_size = len(tag_to_ix)\n","\n"," self.word_embeds = BertEmbedding()\n"," self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,\n"," num_layers=1, bidirectional=True)\n","\n"," # Maps the output of the LSTM into tag space.\n"," self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)\n","\n"," # Matrix of transition parameters. Entry i,j is the score of\n"," # transitioning *to* i *from* j.\n"," self.transitions = nn.Parameter(\n"," torch.randn(self.tagset_size, self.tagset_size))\n","\n"," # These two statements enforce the constraint that we never transfer\n"," # to the start tag and we never transfer from the stop tag\n"," self.transitions.data[tag_to_ix[START_TAG], :] = -10000\n"," self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000\n","\n"," self.hidden = self.init_hidden()\n","\n"," def fix_bert(self):\n"," self.word_embeds.fix_params()\n","\n"," def free_bert(self):\n"," self.word_embeds.free_params()\n","\n"," def init_hidden(self):\n"," return (torch.randn(2, 1, self.hidden_dim // 2),\n"," torch.randn(2, 1, self.hidden_dim // 2))\n","\n"," def _forward_alg(self, feats):\n"," # Do the forward algorithm to compute the partition function\n"," init_alphas = torch.full((1, self.tagset_size), -10000.)\n"," # START_TAG has all of the score.\n"," init_alphas[0][self.tag_to_ix[START_TAG]] = 0.\n","\n"," # Wrap in a variable so that we will get automatic backprop\n"," forward_var = init_alphas\n","\n"," # Iterate through the sentence\n"," for feat in feats:\n"," # alphas_t = [] # The forward tensors at this timestep\n"," # for next_tag in range(self.tagset_size):\n"," # # broadcast the emission score: it is the same regardless of\n"," # # the previous tag\n"," # emit_score = feat[next_tag].view(\n"," # 1, -1).expand(1, self.tagset_size)\n"," # # the ith entry of trans_score is the score of transitioning to\n"," # # next_tag from i\n"," # trans_score = self.transitions[next_tag].view(1, -1)\n"," # # The ith entry of next_tag_var is the value for the\n"," # # edge (i -> next_tag) before we do log-sum-exp\n"," # next_tag_var = forward_var + trans_score + emit_score\n"," # # The forward variable for this tag is log-sum-exp of all the\n"," # # scores.\n"," # alphas_t.append(log_sum_exp(next_tag_var).view(1))\n"," # forward_var = torch.cat(alphas_t).view(1, -1)\n","\n"," forward_var = torch.logsumexp(feat.expand(self.tagset_size, -1) + \n"," self.transitions.T + forward_var.view(-1, 1), \n"," dim=0, \n"," keepdim=True)\n"," \n"," \n"," terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n"," alpha = log_sum_exp(terminal_var)\n"," return alpha\n","\n"," def _get_lstm_features(self, embeds):\n"," self.hidden = self.init_hidden()\n"," embeds = embeds.view(embeds.shape[1], embeds.shape[0], -1)\n"," lstm_out, self.hidden = self.lstm(embeds, self.hidden)\n"," lstm_out = lstm_out.view(lstm_out.shape[1], lstm_out.shape[0], self.hidden_dim)\n"," lstm_feats = self.hidden2tag(lstm_out)\n"," return lstm_feats\n","\n"," def _score_sentence(self, feats, tags):\n"," # Gives the score of a provided tag sequence\n"," score = torch.zeros(1)\n"," tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])\n"," for i, feat in enumerate(feats):\n"," score = score + \\\n"," self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]\n"," score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]\n"," return score\n","\n"," def _viterbi_decode(self, feats):\n"," backpointers = []\n","\n"," # Initialize the viterbi variables in log space\n"," init_vvars = torch.full((1, self.tagset_size), -10000.)\n"," init_vvars[0][self.tag_to_ix[START_TAG]] = 0\n","\n"," # forward_var at step i holds the viterbi variables for step i-1\n"," forward_var = init_vvars\n"," for feat in feats:\n"," # bptrs_t = [] # holds the backpointers for this step\n"," # viterbivars_t = [] # holds the viterbi variables for this step\n","\n"," # for next_tag in range(self.tagset_size):\n"," # # next_tag_var[i] holds the viterbi variable for tag i at the\n"," # # previous step, plus the score of transitioning\n"," # # from tag i to next_tag.\n"," # # We don't include the emission scores here because the max\n"," # # does not depend on them (we add them in below)\n"," # next_tag_var = forward_var + self.transitions[next_tag]\n"," # # print(next_tag_var)\n"," # best_tag_id = argmax(next_tag_var)\n"," # bptrs_t.append(best_tag_id)\n"," # viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))\n","\n"," # # Now add in the emission scores, and assign forward_var to the set\n"," # # of viterbi variables we just computed\n"," # forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)\n"," # backpointers.append(bptrs_t)\n","\n"," scores = self.transitions + forward_var\n"," forward_var, bptrs = torch.max(scores, dim=1)\n"," forward_var = forward_var.view(1, -1) + feat.view(1, -1)\n"," backpointers.append(bptrs.cpu().numpy().tolist())\n","\n"," # Transition to STOP_TAG\n"," terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n"," best_tag_id = argmax(terminal_var)\n"," path_score = terminal_var[0][best_tag_id]\n","\n"," # Follow the back pointers to decode the best path.\n"," best_path = [best_tag_id]\n"," for bptrs_t in reversed(backpointers):\n"," best_tag_id = bptrs_t[best_tag_id]\n"," best_path.append(best_tag_id)\n"," # Pop off the start tag (we dont want to return that to the caller)\n"," start = best_path.pop()\n"," assert start == self.tag_to_ix[START_TAG] # Sanity check\n"," best_path.reverse()\n"," return path_score, best_path\n","\n"," def neg_log_likelihood(self, input_ids, tags, token_starts):\n"," embeds = self.word_embeds(input_ids)[:, token_starts]\n"," feats = self._get_lstm_features(embeds).squeeze() \n"," forward_score = self._forward_alg(feats)\n"," gold_score = self._score_sentence(feats, tags)\n"," return forward_score - gold_score\n"," \n","\n"," def forward(self, input_ids, token_starts): # dont confuse this with _forward_alg above.\n"," embeds = self.word_embeds(input_ids)[:, token_starts]\n"," # Get the emission scores from the BiLSTM\n"," lstm_feats = self._get_lstm_features(embeds).squeeze()\n","\n"," # Find the best path, given the features.\n"," score, tag_seq = self._viterbi_decode(lstm_feats)\n","\n"," return score, tag_seq"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"AXWIi5ELcdwW","colab_type":"code","colab":{}},"source":[""],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"M89cYK8P93fi","colab_type":"text"},"source":["Run training\n","\n"]},{"cell_type":"code","metadata":{"id":"rXq14dpY93fj","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"status":"ok","timestamp":1598018477035,"user_tz":-480,"elapsed":423773,"user":{"displayName":"Chen Qiao","photoUrl":"","userId":"09795836144824686107"}},"outputId":"b9d66845-5b30-4275-89b1-4f25a25d40c8"},"source":["START_TAG = \"<START>\"\n","STOP_TAG = \"<STOP>\"\n","EMBEDDING_DIM = 768\n","HIDDEN_DIM = 768\n","\n","# Make up some training data\n","training_data = [(\n"," \"the wall street journal reported today that apple corporation made money\",\n"," \"B I I I O O O B I O O\"\n","), \n","\n","(\n"," \"georgia tech is a university in georgia\",\n"," \"B I O O O O B\"\n",")\n","\n","]\n","\n","tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)\n","\n","tag_to_ix = {\"B\": 0, \"I\": 1, \"O\": 2, START_TAG: 3, STOP_TAG: 4}\n","\n","model = BiLSTM_CRF(tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)\n","# model.fix_bert()\n","# model.free_bert()\n","optimizer = optim.SGD([param for param in model.parameters() if param.requires_grad], lr=0.01, weight_decay=1e-4)\n","\n","# Check predictions before training\n","with torch.no_grad():\n"," precheck_sent, targets, token_starts = prepare_sequence(training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)\n"," print(precheck_sent)\n"," print(targets)\n"," print(token_starts)\n"," print(model(precheck_sent, token_starts))\n","\n","\n","# Make sure prepare_sequence from earlier in the LSTM section is loaded\n","for epoch in range(\n"," 300): # again, normally you would NOT do 300 epochs, it is toy data\n"," for sentence, tags in training_data:\n"," # Step 1. Remember that Pytorch accumulates gradients.\n"," # We need to clear them out before each instance\n"," model.zero_grad()\n","\n"," # Step 2. Get our inputs ready for the network, that is,\n"," # turn them into Tensors of word indices.\n"," sentence_in, targets, token_starts = prepare_sequence(sentence, tags, tokenizer, tag_to_ix)\n"," # Step 3. Run our forward pass.\n"," loss = model.neg_log_likelihood(sentence_in, targets, token_starts)\n"," print(loss.item())\n"," # Step 4. Compute the loss, gradients, and update the parameters by\n"," # calling optimizer.step()\n"," loss.backward()\n"," optimizer.step()\n","\n","# Check predictions after training\n","with torch.no_grad():\n"," precheck_sent, targets, token_starts = prepare_sequence(training_data[0][0], training_data[0][1], tokenizer, tag_to_ix)\n"," print(model(precheck_sent, token_starts))\n","# We got it!"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{'input_ids': tensor([[ 101, 1103, 2095, 2472, 4897, 2103, 2052, 1115, 12075, 9715,\n"," 1189, 1948, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n","tensor([0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])\n","tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n","(tensor(13.2155), [0, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2])\n","18.524065017700195\n","9.69771957397461\n","14.04741096496582\n","8.564186096191406\n","12.153670310974121\n","6.623537540435791\n","10.821130752563477\n","4.869028091430664\n","6.799995422363281\n","3.2920217514038086\n","4.022052764892578\n","2.0438756942749023\n","2.6931686401367188\n","1.468423843383789\n","1.6920318603515625\n","1.061056137084961\n","1.3279285430908203\n","0.7923049926757812\n","0.9889354705810547\n","0.6197624206542969\n","0.8111076354980469\n","0.49001502990722656\n","0.5839614868164062\n","0.4056587219238281\n","0.4795722961425781\n","0.3584175109863281\n","0.40242767333984375\n","0.28954315185546875\n","0.35292816162109375\n","0.26021385192871094\n","0.31992340087890625\n","0.23885726928710938\n","0.28530120849609375\n","0.20218849182128906\n","0.2623939514160156\n","0.18845367431640625\n","0.23521804809570312\n","0.18128585815429688\n","0.2100372314453125\n","0.166534423828125\n","0.18750381469726562\n","0.1447906494140625\n","0.18436050415039062\n","0.14190673828125\n","0.1678924560546875\n","0.12019920349121094\n","0.15915298461914062\n","0.12767410278320312\n","0.156585693359375\n","0.11075401306152344\n","0.13356399536132812\n","0.112640380859375\n","0.13848114013671875\n","0.09398269653320312\n","0.11859512329101562\n","0.11187934875488281\n","0.12015533447265625\n","0.097015380859375\n","0.1128692626953125\n","0.08238029479980469\n","0.11193466186523438\n","0.0859375\n","0.10143661499023438\n","0.07348251342773438\n","0.09615325927734375\n","0.08072662353515625\n","0.09915542602539062\n","0.07670974731445312\n","0.09319305419921875\n","0.07427406311035156\n","0.08670425415039062\n","0.07334327697753906\n","0.08296585083007812\n","0.058940887451171875\n","0.08138656616210938\n","0.07139396667480469\n","0.07628631591796875\n","0.057964324951171875\n","0.07128524780273438\n","0.06273269653320312\n","0.07092666625976562\n","0.0547943115234375\n","0.06726837158203125\n","0.05399322509765625\n","0.07040786743164062\n","0.04906463623046875\n","0.06667709350585938\n","0.059417724609375\n","0.06259536743164062\n","0.04938507080078125\n","0.0604248046875\n","0.04944610595703125\n","0.0582275390625\n","0.045818328857421875\n","0.056118011474609375\n","0.05054473876953125\n","0.058383941650390625\n","0.046283721923828125\n","0.05562591552734375\n","0.041900634765625\n","0.05492401123046875\n","0.04074859619140625\n","0.056545257568359375\n","0.039409637451171875\n","0.04883575439453125\n","0.043308258056640625\n","0.048919677734375\n","0.04802703857421875\n","0.048126220703125\n","0.0400238037109375\n","0.04732513427734375\n","0.03691864013671875\n","0.045360565185546875\n","0.034305572509765625\n","0.04703521728515625\n","0.036441802978515625\n","0.04430389404296875\n","0.033931732177734375\n","0.04705810546875\n","0.0346527099609375\n","0.04831695556640625\n","0.032558441162109375\n","0.04573822021484375\n","0.03507232666015625\n","0.041988372802734375\n","0.031864166259765625\n","0.044342041015625\n","0.034694671630859375\n","0.042476654052734375\n","0.03163909912109375\n","0.0418243408203125\n","0.030185699462890625\n","0.04074859619140625\n","0.031215667724609375\n","0.038516998291015625\n","0.02901458740234375\n","0.035953521728515625\n","0.030879974365234375\n","0.034549713134765625\n","0.03258514404296875\n","0.033969879150390625\n","0.029605865478515625\n","0.03438568115234375\n","0.028820037841796875\n","0.034946441650390625\n","0.025196075439453125\n","0.038326263427734375\n","0.027629852294921875\n","0.03270721435546875\n","0.027362823486328125\n","0.03369903564453125\n","0.0268707275390625\n","0.033721923828125\n","0.027248382568359375\n","0.0313262939453125\n","0.021869659423828125\n","0.031650543212890625\n","0.023372650146484375\n","0.03081512451171875\n","0.026371002197265625\n","0.032398223876953125\n","0.023090362548828125\n","0.02802276611328125\n","0.02558135986328125\n","0.030231475830078125\n","0.021945953369140625\n","0.02909088134765625\n","0.0247344970703125\n","0.0286102294921875\n","0.023036956787109375\n","0.028675079345703125\n","0.0205535888671875\n","0.027423858642578125\n","0.0208587646484375\n","0.026210784912109375\n","0.021144866943359375\n","0.02787017822265625\n","0.021575927734375\n","0.029216766357421875\n","0.020732879638671875\n","0.026523590087890625\n","0.024517059326171875\n","0.025310516357421875\n","0.020252227783203125\n","0.02447509765625\n","0.0202789306640625\n","0.025203704833984375\n","0.0197906494140625\n","0.0258941650390625\n","0.02239227294921875\n","0.023662567138671875\n","0.021244049072265625\n","0.024349212646484375\n","0.020130157470703125\n","0.023754119873046875\n","0.019435882568359375\n","0.0241546630859375\n","0.01947784423828125\n","0.023235321044921875\n","0.016143798828125\n","0.023433685302734375\n","0.01739501953125\n","0.02231597900390625\n","0.017894744873046875\n","0.022922515869140625\n","0.02083587646484375\n","0.02474212646484375\n","0.02251434326171875\n","0.021392822265625\n","0.01758575439453125\n","0.022747039794921875\n","0.015537261962890625\n","0.02065277099609375\n","0.017360687255859375\n","0.020465850830078125\n","0.017726898193359375\n","0.022495269775390625\n","0.014987945556640625\n","0.020709991455078125\n","0.01609039306640625\n","0.020999908447265625\n","0.016460418701171875\n","0.01959228515625\n","0.01551055908203125\n","0.02266693115234375\n","0.016361236572265625\n","0.02111053466796875\n","0.01638031005859375\n","0.02060699462890625\n","0.01663970947265625\n","0.020244598388671875\n","0.015842437744140625\n","0.020843505859375\n","0.014667510986328125\n","0.01929473876953125\n","0.015880584716796875\n","0.02040863037109375\n","0.015750885009765625\n","0.018482208251953125\n","0.0142059326171875\n","0.01972198486328125\n","0.013668060302734375\n","0.0186920166015625\n","0.015918731689453125\n","0.018817901611328125\n","0.0158233642578125\n","0.0180816650390625\n","0.01428985595703125\n","0.01825714111328125\n","0.0151824951171875\n","0.019374847412109375\n","0.014495849609375\n","0.01856231689453125\n","0.014423370361328125\n","0.017303466796875\n","0.013736724853515625\n","0.018497467041015625\n","0.015117645263671875\n","0.017253875732421875\n","0.014644622802734375\n","0.016506195068359375\n","0.0121612548828125\n","0.0164642333984375\n","0.014312744140625\n","0.016620635986328125\n","0.014163970947265625\n","0.0162811279296875\n","0.01226806640625\n","0.01753997802734375\n","0.011386871337890625\n","0.01616668701171875\n","0.011974334716796875\n","0.016300201416015625\n","0.013896942138671875\n","0.017597198486328125\n","0.01239013671875\n","0.0173187255859375\n","0.01284027099609375\n","0.015384674072265625\n","0.01336669921875\n","0.01540374755859375\n","0.011402130126953125\n","0.015346527099609375\n","0.013134002685546875\n","0.015522003173828125\n","0.013195037841796875\n","0.01406097412109375\n","0.011493682861328125\n","0.014446258544921875\n","0.010650634765625\n","0.014507293701171875\n","0.011463165283203125\n","0.014896392822265625\n","0.01290130615234375\n","0.014240264892578125\n","0.01092529296875\n","0.015430450439453125\n","0.0108642578125\n","0.01447296142578125\n","0.012973785400390625\n","0.01438140869140625\n","0.01300048828125\n","0.013576507568359375\n","0.011295318603515625\n","0.01415252685546875\n","0.0108642578125\n","0.01309967041015625\n","0.01047515869140625\n","0.01422119140625\n","0.010467529296875\n","0.0141448974609375\n","0.01103973388671875\n","0.014190673828125\n","0.0120849609375\n","0.0131988525390625\n","0.010616302490234375\n","0.013736724853515625\n","0.010257720947265625\n","0.013278961181640625\n","0.010540008544921875\n","0.0131683349609375\n","0.01039886474609375\n","0.013446807861328125\n","0.011157989501953125\n","0.01340484619140625\n","0.0098114013671875\n","0.013446807861328125\n","0.009975433349609375\n","0.012508392333984375\n","0.01080322265625\n","0.014354705810546875\n","0.009792327880859375\n","0.012340545654296875\n","0.0105438232421875\n","0.013671875\n","0.0110626220703125\n","0.012943267822265625\n","0.010120391845703125\n","0.013378143310546875\n","0.00951385498046875\n","0.0131072998046875\n","0.0108489990234375\n","0.0121612548828125\n","0.009967803955078125\n","0.01238250732421875\n","0.010311126708984375\n","0.0124664306640625\n","0.010467529296875\n","0.0120391845703125\n","0.0117034912109375\n","0.0121307373046875\n","0.009685516357421875\n","0.012298583984375\n","0.0098724365234375\n","0.0128173828125\n","0.009761810302734375\n","0.01177978515625\n","0.01038360595703125\n","0.011627197265625\n","0.00908660888671875\n","0.0117034912109375\n","0.0099334716796875\n","0.012271881103515625\n","0.009540557861328125\n","0.011627197265625\n","0.010009765625\n","0.0126190185546875\n","0.01012420654296875\n","0.01081085205078125\n","0.007740020751953125\n","0.012298583984375\n","0.008716583251953125\n","0.01180267333984375\n","0.00844573974609375\n","0.0113372802734375\n","0.008884429931640625\n","0.011505126953125\n","0.00984954833984375\n","0.01119232177734375\n","0.00861358642578125\n","0.01031494140625\n","0.008487701416015625\n","0.01244354248046875\n","0.0093231201171875\n","0.01107025146484375\n","0.008533477783203125\n","0.01091766357421875\n","0.008617401123046875\n","0.010406494140625\n","0.0083465576171875\n","0.01100921630859375\n","0.009456634521484375\n","0.01100921630859375\n","0.008056640625\n","0.01030731201171875\n","0.007488250732421875\n","0.010772705078125\n","0.007480621337890625\n","0.01053619384765625\n","0.008472442626953125\n","0.010589599609375\n","0.0088958740234375\n","0.01128387451171875\n","0.009918212890625\n","0.01088714599609375\n","0.008514404296875\n","0.01080322265625\n","0.0071258544921875\n","0.010009765625\n","0.008731842041015625\n","0.0098724365234375\n","0.0072479248046875\n","0.0101318359375\n","0.008106231689453125\n","0.01024627685546875\n","0.00862884521484375\n","0.00954437255859375\n","0.0081939697265625\n","0.00975799560546875\n","0.008026123046875\n","0.009552001953125\n","0.00832366943359375\n","0.0098419189453125\n","0.0070343017578125\n","0.0099334716796875\n","0.0074920654296875\n","0.00946044921875\n","0.00853729248046875\n","0.00921630859375\n","0.0077972412109375\n","0.00951385498046875\n","0.007396697998046875\n","0.0100250244140625\n","0.009342193603515625\n","0.00954437255859375\n","0.0079803466796875\n","0.008880615234375\n","0.008434295654296875\n","0.0090484619140625\n","0.007110595703125\n","0.00968170166015625\n","0.0068359375\n","0.009857177734375\n","0.007167816162109375\n","0.00902557373046875\n","0.006305694580078125\n","0.0089874267578125\n","0.006649017333984375\n","0.00882720947265625\n","0.008533477783203125\n","0.0087127685546875\n","0.007343292236328125\n","0.008880615234375\n","0.008068084716796875\n","0.00888824462890625\n","0.007770538330078125\n","0.009124755859375\n","0.00626373291015625\n","0.00911712646484375\n","0.00612640380859375\n","0.0088958740234375\n","0.00647735595703125\n","0.0089111328125\n","0.007434844970703125\n","0.0087890625\n","0.007080078125\n","0.00887298583984375\n","0.00897216796875\n","0.00885009765625\n","0.008426666259765625\n","0.00882720947265625\n","0.00696563720703125\n","0.0092010498046875\n","0.007831573486328125\n","0.0085296630859375\n","0.010101318359375\n","0.00839996337890625\n","0.00801849365234375\n","0.0079345703125\n","0.006923675537109375\n","0.0084991455078125\n","0.006603240966796875\n","0.00806427001953125\n","0.006256103515625\n","0.008270263671875\n","0.005626678466796875\n","0.00952911376953125\n","0.006160736083984375\n","0.0076446533203125\n","0.006496429443359375\n","0.00849151611328125\n","0.006252288818359375\n","0.0078582763671875\n","0.008190155029296875\n","0.00811004638671875\n","0.006290435791015625\n","0.00807952880859375\n","0.0070343017578125\n","0.007781982421875\n","0.00659942626953125\n","0.007843017578125\n","0.00652313232421875\n","0.00832366943359375\n","0.00632476806640625\n","0.00812530517578125\n","0.0057373046875\n","0.0083160400390625\n","0.006359100341796875\n","0.00809478759765625\n","0.006908416748046875\n","0.00762939453125\n","0.006832122802734375\n","0.0077972412109375\n","0.006381988525390625\n","0.00762939453125\n","0.00646209716796875\n","0.00800323486328125\n","0.0063323974609375\n","0.0072479248046875\n","0.006053924560546875\n","0.00753021240234375\n","0.0057525634765625\n","0.0079193115234375\n","0.006107330322265625\n","0.00705718994140625\n","0.006580352783203125\n","0.00762176513671875\n","0.005779266357421875\n","0.00726318359375\n","0.006378173828125\n","0.007659912109375\n","0.0057525634765625\n","0.00732421875\n","0.00592041015625\n","0.00737762451171875\n","0.00604248046875\n","0.00757598876953125\n","0.00582122802734375\n","0.00702667236328125\n","0.00572967529296875\n","0.00705718994140625\n","0.005855560302734375\n","0.00701904296875\n","0.007171630859375\n","0.0075225830078125\n","0.00605010986328125\n","0.00777435302734375\n","0.006053924560546875\n","0.0075225830078125\n","0.006076812744140625\n","0.00733184814453125\n","0.005329132080078125\n","0.00736236572265625\n","0.00611114501953125\n","0.0070953369140625\n","0.005405426025390625\n","0.00687408447265625\n","0.005764007568359375\n","0.00737762451171875\n","0.005706787109375\n","0.00720977783203125\n","0.005462646484375\n","0.00725555419921875\n","0.00591278076171875\n","0.0068359375\n","0.00504302978515625\n","0.00681304931640625\n","0.00567626953125\n","0.00745391845703125\n","0.00548553466796875\n","0.0076904296875\n","0.006725311279296875\n","0.0070037841796875\n","0.006229400634765625\n","0.0073699951171875\n","0.005184173583984375\n","0.00695037841796875\n","0.004901885986328125\n","0.00705718994140625\n","0.005374908447265625\n","0.00798797607421875\n","0.005619049072265625\n","0.00702667236328125\n","0.00490570068359375\n","0.0074462890625\n","0.005405426025390625\n","0.0065155029296875\n","0.0051116943359375\n","0.0071868896484375\n","0.00540924072265625\n","0.0068511962890625\n","0.004940032958984375\n","0.00740814208984375\n","0.006267547607421875\n","0.0067901611328125\n","0.0050201416015625\n","0.006439208984375\n","0.00514984130859375\n","0.0063018798828125\n","0.00565338134765625\n","(tensor(69.1735), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])\n"],"name":"stdout"}]}]}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment