Commit 450bd27a by 20200203141

project 1 initial commit

parent 6da00683
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" languageLevel="JDK_14" default="false" project-jdk-name="Python 3.6 (py36)" project-jdk-type="Python SDK" />
<component name="PyCharmProfessionalAdvertiser">
<option name="shown" value="true" />
</component>
@@ -5,4 +5,7 @@
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PackageRequirementsSettings">
<option name="requirementsPath" value="" />
</component>
</module>
\ No newline at end of file
sklearn
jieba
bert-embedding
nltk
matplotlib
\ No newline at end of file
@@ -103,10 +103,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import json\n",
"\n",
"def read_corpus():\n",
" \"\"\"\n",
" 读取给定的语料库,并把问题列表和答案列表分别写入到 qlist, alist 里面。 在此过程中,不用对字符换做任何的处理(这部分需要在 Part 2.3里处理)\n",
......@@ -115,11 +118,30 @@
" 务必要让每一个问题和答案对应起来(下标位置一致)\n",
" \"\"\"\n",
" # TODO 需要完成的代码部分 ...\n",
" \n",
" \n",
" \n",
" qlist = []\n",
" alist = []\n",
" filename = 'train-v2.0.json'\n",
" datas = json.load(open(filename,'r'))\n",
" data = datas['data']\n",
" for d in data:\n",
" paragraph = d['paragraphs']\n",
" for p in paragraph:\n",
" qas = p['qas']\n",
" for qa in qas:\n",
" #print(qa)\n",
" #处理is_impossible为True时answers空\n",
" if(not qa['is_impossible']):\n",
" qlist.append(qa['question'])\n",
" alist.append(qa['answers'][0]['text'])\n",
" #print(qlist[0])\n",
" #print(alist[0])\n",
" assert len(qlist) == len(alist) # 确保长度一样\n",
" return qlist, alist"
" # reduce the q/a data size to avoid out of memory\n",
" qlist = qlist[0:len(qlist)//4]\n",
" alist = alist[0:len(alist)//4]\n",
" return qlist, alist\n",
"\n",
"qlist,alist = read_corpus()"
]
},
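For reference, read_corpus() above walks the standard SQuAD v2.0 layout of train-v2.0.json; a sketch of that structure as a Python literal (field values are placeholders, not data from the file):

    squad_like = {
        "data": [
            {"title": "...",
             "paragraphs": [
                 {"context": "...",
                  "qas": [
                      {"question": "...",
                       "is_impossible": False,
                       "answers": [{"text": "...", "answer_start": 0}]}
                  ]}
             ]}
        ]
    }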
{
@@ -135,25 +157,70 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"23027\n"
]
}
],
"source": [
"# TODO: 统计一下在qlist中总共出现了多少个单词? 总共出现了多少个不同的单词(unique word)?\n",
"# 这里需要做简单的分词,对于英文我们根据空格来分词即可,其他过滤暂不考虑(只需分词)\n",
"words_qlist = dict()\n",
"for q in qlist:\n",
" #以空格为分词,都转为小写\n",
" words = q.strip().split(' ')\n",
" for w in words:\n",
" if w.lower() in words_qlist:\n",
" words_qlist[w.lower()] += 1\n",
" else:\n",
" words_qlist[w.lower()] = 1\n",
"word_total = len(words_qlist)\n",
"\n",
"print (word_total)"
]
},
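The same counts can be cross-checked with collections.Counter, which is the idiomatic way to tally tokens; a minimal sketch assuming qlist holds the raw question strings:

    from collections import Counter

    # lowercased whitespace tokens over all questions
    word_counter = Counter(w.lower() for q in qlist for w in q.strip().split(' '))
    print(len(word_counter))           # number of unique words
    print(sum(word_counter.values()))  # total number of tokens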
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAASw0lEQVR4nO3dfbBcdX3H8feXJBADQQMXbCA4N1hQQClPTS1QhabUQBmxih3qgIFEMnUUhNECkRmhowwUKA1adIZKEEoGmmoEpjNFeawDA2gSHpJrQBAoXogkJFCwGEjIt3/sueQa7hO7d7O79/d+zTB397dn7/2dw2E/7Dl79hOZiSSpPNu1egKSpNYwACSpUAaAJBXKAJCkQhkAklSo8a2eAEBXV1d2d3e3ehqS1FGWLVv2YmbuVu/z2yIAuru7Wbp0aaunIUkdJSL+p5HnewhIkgplAEhSoQwASSpUW5wDkKT+Nm7cSG9vLxs2bGj1VNrCxIkTmTZtGhMmTBjV32sASGo7vb29TJ48me7ubiKi1dNpqcxk3bp19Pb2Mn369FH93R4CktR2NmzYwK677lr8iz9ARLDrrrs25d2QASCpLfniv0WztoUBIEmFGvYcQETsBVwP/AGwGbg6M6+MiF2Afwe6gWeAv8nMl6rnzAfmAm8CZ2bmj5sye0lFuPS2x3jptY2j9vumTJrAObM+OGq/751asGAB8+bNY9KkSS2bA4zsJPAm4CuZuTwiJgPLIuJ24FTgzsy8JCLOA84Dzo2I/YGTgAOAPYA7ImLfzHxzsD/wm1c2MH/JikbXpWGt3ikkDeyl1zZy8ac+PGq/r9WvNwsWLODkk09u/wDIzNXA6ur2qxGxCtgTOAE4qlrsOuAe4Nxq/KbMfB14OiKeBGYA9w/2N97cnKP6L7derd4pJLWP66+/nssvv5yI4MADD+Sb3/wmc+bMYe3atey2225ce+21vO997+PUU0/l+OOP58QTTwRgp5124re//S333HMPF154IV1dXaxcuZJDDz2UG264gW9/+9s8//zzHH300XR1dXHHHXcwd+5cli5dSkQwZ84czj777G2yju/oY6AR0Q0cDDwIvLcKBzJzdUTsXi22J/BAv6f1VmNb/655wDyAd++x9zueuCQ1S09PDxdddBH33XcfXV1drF+/ntmzZ/O5z32O2bNns3DhQs4880xuvvnmIX/PQw89RE9PD3vssQdHHHEE9913H2eeeSZXXHEFd999N11dXSxbtoznnnuOlStXAvDyyy83fwUrIz4JHBE7AT8EzsrMV4ZadICxtxUPZ+bVmXlYZh72rndNHOk0JKnp7rrrLk488US6uroA2GWXXbj//vv57Gc/C8App5zCvffeO+zvmTFjBtOmTWO77bbjoIMO4plnnnnbMnvvvTdPPfUUZ5xxBrfddhs777zzqK7LUEYUABExgdqL/6LMXFINvxARU6vHpwJrqvFeYK9+T58GPD8605Wk5svMYT962ff4+PHj2bx581vPe+ONN95aZocddnjr9rhx49i0adPbfs+UKVN45JFHOOqoo7jqqqv4/Oc/PxqrMCLDBkDU1vIaYFVmXtHvoVuB2dXt2cAt/cZPiogdImI6sA/ws9GbsiQ118yZM1m8eDHr1q0DYP369Rx++OHcdNNNACxatIgjjzwSqH2d/bJlywC45ZZb2Lhx+E8rTZ48mVdffRWAF198kc2bN/PpT3+ab3zjGyxfvrwZqzSgkZwDOAI4BVgREQ9XY18DLgEWR8Rc4FngMwCZ2RMRi4FfUPsE0ReH+gSQJA1nyqQJo/ohjSmThv5OnQMOOIDzzz+fj33sY4wbN46DDz6Yb33rW8yZM4fLLrvsrZPAAKeffjonnHACM2bMYObMmey4447D/v158+Zx7LHHMnXqVBYsWMBpp5321ruIiy++uPEVHKHIfNvh+W1u6h8ekKuf7Gn1NJi/ZEVbfBpJKt2qVavYb7/9Wj2NtjLQNomIZZl5WL2/0yuBJalQBoAkFcoAkNSW2uHwdLto1rYwACS1nYkTJ7Ju3TpDgC19ABMnjv71UhbCSGo706ZNo7e3l7Vr17Z6Km2hrxFstBkAktrOhAkTRr39Sm/nISBJKpQBIEmFMgAkqVAGgCQVygCQpEKN5NtAF0bEmohY2W/soIh4ICIejoilETGj32PzI+LJiHg8Ij7erIlLkhozkncA3wdmbTV2KfAPmXkQ8PXqPlv1Ac8CvhMR40ZrspKk0TNsAGTmT4H1Ww8DfbU172ZL4ctbfcCZ+TTQ1wcsSWoz9V4Idhbw44i4nFqIHF6Nj6gPGOwElqRWq/ck8BeAszNzL+Bsao1hMMI+YLATWJJard4AmA30dQP/B1sO89gHLEkdot4AeB74WHX7z4Enqtv2AUtShxj2HEBE3AgcBXRFRC9wAXA6cGVEjAc2UB3Ltw9YkjrHsAGQmX87yEOHDrL8RcBFjUxKktR8XgksSYUyACSpUAaAJBXKAJCkQhkAklQoA0CSCmUASFKhDABJKpQBIEmFMgAkqVAGgCQVygCQpELVVQpfjZ9RFb/3RMSl/cYthZekDjCSSsjvA/8CXN83EBFHU+v/PTAzX4+I3avx/qXwewB3RMS+fiW0JLWfekvhvwBckpmvV8usqcYthZekDlHvOYB9gT+LiAcj4r8j4o+r8T2BX/dbbshS+IhYGhFLf/e7DXVOQ5JUr3oDYDwwBfgI8PfA4ogILIWXpI5RbwD0Akuy5mfAZqALS+ElqWPUGwA3UyuDJyL2BbYHXsRSeEnqGPWWwi8EFlYfDX0DmJ2ZCVgKL0kdopFS+JMHWd5SeEnqAF4JLEmFMgAkqVAGgCQVygCQpEIZAJJUKANAkgplAEhSoQwASSqUASBJhTIAJKlQBoAkFaruTuDqsa9GREZEV78xO4ElqQOM5B3A94FZWw9GxF7AMcCz/cb6dwLPAr4TEeNGZaaSpFFVbycwwD8D5/D7jV92AktSh6jrHEBEfAJ4LjMf2eohO4ElqUO84wCIiEnA+cDXB3p4gDE7gSWpDQ1bCDOA9wPTgUdqPfBMA5ZHxAzsBJakjvGO3wFk5orM3D0zuzOzm9qL/iGZ+RvsBJakjjGSj4HeCNwPfCAieiNi7mDLZmYP0NcJfBt2AktS22qkE7jv8e6t7tsJLEkdwCuBJalQBoAkFcoAkKRCGQCSVCgDQJIKZQBIUqEMAEkqlAEgSYUyACSpUAaAJBXKAJCkQhkAklSoukrhI+KyiHgsIh6NiB9FxHv6PWYpvCR1gHpL4W8HPpSZBwK/BOaDpfCS1EnqKoXPzJ9k5qbq7gPUmr/AUnhJ6hijcQ5gDvBf1W1L4SWpQzQUABFxPrAJWNQ3NMBilsJLUhuqpxQegIiYDRwPzMzMvhd5S+ElqUPU9Q4gImYB5wKfyMzX+j1kKbwkdYhh3wFUpfBHAV0R0QtcQO1TPzsAt0cEwAOZ+XeZ2RMRfaXwm7AUXpLaVr2l8NcMsbyl8JLUAbwSWJIKZQBIUqEMAEkqlAEgSYUyACSpUAaAJBXKA
JCkQhkAklQoA0CSCmUASFKhDABJKlS9ncC7RMTtEfFE9XNKv8fsBJakDlBvJ/B5wJ2ZuQ9wZ3XfTmBJ6iB1dQJT6/69rrp9HfDJfuN2AktSB6j3HMB7M3M1QPVz92rcTmBJ6hCjfRLYTmBJ6hD1BsALETEVoPq5phq3E1iSOkS9AXArMLu6PRu4pd+4ncCS1AHq7QS+BFgcEXOBZ4HPANgJLEmdo95OYICZgyxvJ7AkdQCvBJakQhkAklQoA0CSCmUASFKhDABJKpQBIEmFMgAkqVAGgCQVygCQpEIZAJJUKANAkgplAEhSoRoKgIg4OyJ6ImJlRNwYEROHKoyXJLWPugMgIvYEzgQOy8wPAeOoFcIPWBgvSWovjR4CGg+8KyLGA5OotX8NVhgvSWojdQdAZj4HXE6tEGY18L+Z+RMGL4z/PZbCS1JrNXIIaAq1/9ufDuwB7BgRJ4/0+ZbCS1JrNXII6C+ApzNzbWZuBJYAhzN4YbwkqY00EgDPAh+JiEkREdQqIlcxeGG8JKmNDNsJPJjMfDAifgAsp1YA/xBwNbATAxTGS5LaS90BAJCZFwAXbDX8OoMUxkuS2odXAktSoQwASSqUASBJhTIAJKlQBoAkFcoAkKRCGQCSVCgDQJIKZQBIUqEMAEkqlAEgSYVqtBP4PRHxg4h4LCJWRcSf2gksSZ2h0XcAVwK3ZeYHgT+i9nXQdgJLUgdopBFsZ+CjwDUAmflGZr6MncCS1BEaeQewN7AWuDYiHoqI70XEjtgJLEkdoZEAGA8cAnw3Mw8G/o93cLjHTmBJaq1GAqAX6M3MB6v7P6AWCHYCS1IHqDsAMvM3wK8j4gPV0EzgF9gJLEkdoaFKSOAMYFFEbA88BZxGLVTsBJakNtdoJ/DDwGEDPGQnsCS1Oa8ElqRCGQCSVCgDQJIKZQBIUqEMAEkqlAEgSYUyACSpUAaAJBXKAJCkQhkAklQoA0CSCmUASFKhGg6AiBhXNYL9Z3XfUnhJ6gCj8Q7gy9TK4PtYCi9JHaChAIiIacBfAd/rN2wpvCR1gEbfASwAzgE29xuzFF6SOkDdARARxwNrMnNZPc+3FF6SWquRRrAjgE9ExHHARGDniLiBqhQ+M1dbCi9J7auRUvj5mTktM7uBk4C7MvNkLIWXpI7QjOsALgGOiYgngGOq+5KkNtNQKXyfzLwHuKe6vQ5L4SWp7XklsCQVygCQpEIZAJJUKANAkgplAEhSoQwASSqUASBJhTIAJKlQBoAkFcoAkKRCGQCSVKhG+gD2ioi7I2JVRPRExJercTuBJakDNPIOYBPwlczcD/gI8MWI2B87gSWpIzTSB7A6M5dXt1+lVgy/J3YCS1JHGJVzABHRDRwMPIidwJLUERoOgIjYCfghcFZmvjLS59kJLEmt1VAARMQEai/+izJzSTX8QtUFjJ3AktS+GvkUUADXAKsy84p+D9kJLEkdoJFKyCOAU4AVEfFwNfY1ah3AiyNiLvAs8JmGZihJaoq6AyAz7wVikIftBJakNueVwJJUKANAkgrVyDmAMWfKpAnMX7Ki5XM4Z9YHWzoHSWUwAPpphxfeVgeQpHJ4CEiSCmUASFKhDABJKpQBIEmFMgAkqVAGgCQVyo+BakCX3vYYL722saVz8JoIqbkMAA3opdc2cvGnPtzSOXhNhNRcTQuAiJgFXAmMA76XmZc062+NJe1wNXLfPCSNbU0JgIgYB1wFHAP0Aj+PiFsz8xfN+HtjiYc81M48NDi2NOsdwAzgycx8CiAibqJWFm8AqOO0w4teu5gyaYKHBseQZgXAnsCv+93vBf6k/wIRMQ+YV919PSJWNmkunaYLeLHVk2gTXZe4Lfq0zX5xbqsn4H7R3wcaeXKzAmCgopj8vTuZVwNXA0TE0sw8rElz6Shuiy3cFlu4LbZwW2wREUsbeX6zrgPoBfbqd38a8HyT/pYkqQ7NCoCfA/tExPSI2B44iVpZvCSpTTTlEFBmboqILwE/pvYx0IWZ2TPEU65uxjw6lNtiC7fFFm6LLdwWWzS0LSIzh19KkjTm+F1AklQoA0CSCtXyAIiIWRHxeEQ8GRHntXo+21pEPBMRKyLi4b6PdEXELhFxe0Q8Uf2c0up5NkNELIyINf2vARlq3SNifrWfPB4RH2/NrJtjkG1xYUQ8V+0bD0fEcf0eG5PbIiL2ioi7I2JVRPRExJer8eL2iyG2xejtF5nZsn+onSD+FbA3sD3wCLB/K+fUgm3wDNC11dilwHnV7fOAf2z1PJu07h8FDgFWDrfuwP7V/rEDML3ab8a1eh2avC0uBL46wLJjdlsAU4FDqtuTgV9W61vcfjHEthi1/aLV7wDe+sqIzHwD6PvKiNKdAFxX3b4O+GTrptI8mflTYP1Ww4Ot+wnATZn5emY+DTxJbf8ZEwbZFoMZs9siM1dn5vLq9qvAKmrfLFDcfjHEthjMO94WrQ6Agb4yYqgVHIsS+ElELKu+HgPgvZm5Gmo7AbB7y2a37Q227qXuK1+KiEerQ0R9hz2K2BYR0Q0cDDxI4fvFVtsCRmm/aHUADPuVEQU4IjMPAY4FvhgRH231hNpUifvKd4H3AwcBq4F/qsbH/LaIiJ2AHwJnZeYrQy06wNhY3xajtl+0OgCK/8qIzHy++rkG+BG1t2wvRMRUgOrnmtbNcJsbbN2L21cy84XMfDMzNwP/ypa382N6W0TEBGoveIsyc0k1XOR+MdC2GM39otUBUPRXRkTEjhExue828JfASmrbYHa12GzgltbMsCUGW/dbgZMiYoeImA7sA/ysBfPbZvpe8Cp/TW3fgDG8LSIigGuAVZl5Rb+HitsvBtsWo7pftMGZ7uOond3+FXB+q+ezjdd9b2pn7R8BevrWH9gVuBN4ovq5S6vn2qT1v5HaW9iN1P7vZe5Q6w6cX+0njwPHtnr+22Bb/BuwAni0+o976ljfFsCR1A5bPAo8XP1zXIn7xRDbYtT2C78KQpIK1epDQJKkFjEAJKlQBoAkFcoAkKRCGQCSVCgDQJIKZQBIUqH+H2N9CyF5yuo5AAAAAElFTkSuQmCC\n"
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# TODO: 统计一下qlist中出现1次,2次,3次... 出现的单词个数, 然后画一个plot. 这里的x轴是单词出现的次数(1,2,3,..), y轴是单词个数。\n",
"# 从左到右分别是 出现1次的单词数,出现2次的单词数,出现3次的单词数... \n",
"\n"
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"#counts:key出现N次,value:出现N次词有多少\n",
"counts = dict()\n",
"for w,c in words_qlist.items():\n",
" if c in counts:\n",
" counts[c] += 1\n",
" else:\n",
" counts[c] = 1\n",
"#print(counts)\n",
"#以histogram画图\n",
"fig,ax = plt.subplots()\n",
"ax.hist(counts.values(),bins = np.arange(0,250,25),histtype='step',alpha=0.6,label=\"counts\")\n",
"ax.legend()\n",
"ax.set_xlim(0,250)\n",
"ax.set_yticks(np.arange(0,220,20))\n",
"plt.show()\n"
]
},
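The histogram above bins the counts-of-counts values; if the plot should literally show x = occurrence count and y = number of words with that count, as the TODO describes, a bar plot over the counts dict is a closer match. A sketch reusing words_qlist:

    from collections import Counter
    import matplotlib.pyplot as plt

    # counts_of_counts[k] = how many distinct words occur exactly k times
    counts_of_counts = Counter(words_qlist.values())
    ks = sorted(counts_of_counts)[:50]  # first 50 frequencies, for readability
    plt.bar(ks, [counts_of_counts[k] for k in ks])
    plt.xlabel('word frequency')
    plt.ylabel('number of words')
    plt.show()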
{
@@ -187,13 +254,90 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# TODO: 需要做文本方面的处理。 从上述几个常用的方法中选择合适的方法给qlist做预处理(不一定要按照上面的顺序,不一定要全部使用)\n",
"import nltk\n",
"from nltk.corpus import stopwords\n",
"import codecs\n",
"import re\n",
"\n",
"def tokenizer(ori_list):\n",
" #分词时处理标点符号\n",
" SYMBOLS = re.compile('[\\s;\\\"\\\",.!?\\\\/\\[\\]\\{\\}\\(\\)-]+')\n",
" new_list = []\n",
" for q in ori_list:\n",
" words = SYMBOLS.split(q.lower().strip())\n",
" new_list.append(' '.join(words))\n",
" return new_list\n",
"\n",
"def removeStopWord(ori_list):\n",
" new_list = []\n",
" #nltk中stopwords包含what等,但是在QA问题中,这算关键词,所以不看作关键词\n",
" restored = ['what','when','which','how','who','where']\n",
" english_stop_words = list(set(stopwords.words('english')))#['what','when','which','how','who','where','a','an','the'] #\n",
" for w in restored:\n",
" english_stop_words.remove(w)\n",
" for q in ori_list:\n",
" sentence = ' '.join([w for w in q.strip().split(' ') if w not in english_stop_words])\n",
" new_list.append(sentence)\n",
" return new_list\n",
"\n",
"def removeLowFrequence(ori_list,vocabulary,thres = 10):\n",
" #根据thres筛选词表,小于thres的词去掉\n",
" new_list = []\n",
" for q in ori_list:\n",
" sentence = ' '.join([w for w in q.strip().split(' ') if vocabulary[w] >= thres])\n",
" new_list.append(sentence)\n",
" return new_list\n",
"\n",
"def replaceDigits(ori_list,replace = '#number'):\n",
" #将数字统一替换为replace,默认#number\n",
" DIGITS = re.compile('\\d+')\n",
" new_list = []\n",
" for q in ori_list:\n",
" q = DIGITS.sub(replace,q)\n",
" new_list.append(q)\n",
" return new_list\n",
"\n",
"def createVocab(ori_list):\n",
" count = 0\n",
" vocab_count = dict()\n",
" for q in ori_list:\n",
" words = q.strip().split(' ')\n",
" count += len(words)\n",
" for w in words:\n",
" if w in vocab_count:\n",
" vocab_count[w] += 1\n",
" else:\n",
" vocab_count[w] = 1\n",
" return vocab_count,count\n",
"def writeFile(oriList,filename):\n",
" with codecs.open(filename,'w','utf8') as Fout:\n",
" for q in oriList:\n",
" Fout.write(q + u'\\n')\n",
"\n",
"qlist = # 更新后的问题列表"
"def writeVocab(vocabulary,filename):\n",
" sortedList = sorted(vocabulary.items(),key = lambda d:d[1])\n",
" with codecs.open(filename,'w','utf8') as Fout:\n",
" for (w,c) in sortedList:\n",
" Fout.write(w + u':' + str(c) + u'\\n')\n",
"new_list = tokenizer(qlist)\n",
"#writeFile(qlist,'ori.txt')\n",
"\n",
"#new_list = removeStopWord(new_list)\n",
"#writeFile(new_list,'removeStop.txt')\n",
"new_list = replaceDigits(new_list)\n",
"#writeFile(new_list,'removeDigts.txt')\n",
"vocabulary,count = createVocab(new_list)\n",
"new_list = removeLowFrequence(new_list,vocabulary,5)\n",
"#writeFile(new_list,'lowFrequence.txt')\n",
"#重新统计词频\n",
"vocab_count,count = createVocab(new_list)\n",
"writeVocab(vocab_count,\"train.vocab\")\n",
"qlist = new_list # 更新后的问题列表"
]
},
{
@@ -220,14 +364,75 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4212\n",
"[0.0301549 0.04699083 0.00771518 ... 0. 0. 0. ]\n",
"(21705, 4212)\n"
]
}
],
"source": [
"# TODO \n",
"vectorizer = # 定义一个tf-idf的vectorizer\n",
"# TODO\n",
"import numpy as np\n",
"\n",
"def computeTF(vocab,c):\n",
" #计算每次词的词频\n",
" #vocabCount已经统计好的每词的次数\n",
" #c是统计好的总次数\n",
" TF = np.ones(len(vocab))\n",
" word2id = dict()\n",
" id2word = dict()\n",
" for word,fre in vocab.items():\n",
" TF[len(word2id)] = 1.0 * fre / c\n",
" id2word[len(word2id)] = word\n",
" word2id[word] = len(word2id)\n",
" return TF,word2id,id2word\n",
"\n",
"def computeIDF(word2id,qlist):\n",
" #IDF计算,没有类别,以句子为一个类\n",
" IDF = np.ones(len(word2id))\n",
" for q in qlist:\n",
" words = set(q.strip().split())\n",
" for w in words:\n",
" IDF[word2id[w]] += 1\n",
" IDF /= len(qlist)\n",
" IDF = -1.0 * np.log2(IDF)\n",
" return IDF\n",
"\n",
"X_tfidf = # 结果存放在X矩阵里"
"def computeSentenceEach(sentence,tfidf,word2id):\n",
" #给定句子,计算句子TF-IDF\n",
" #tfidf是一个1*M的矩阵,M为词表大小\n",
" #不在词表中的词不统计\n",
" sentence_tfidf = np.zeros(len(word2id))\n",
" for w in sentence.strip().split(' '):\n",
" if w not in word2id:\n",
" continue\n",
" sentence_tfidf[word2id[w]] = tfidf[word2id[w]]\n",
" return sentence_tfidf\n",
"\n",
"def computeSentence(qlist,word2id,tfidf):\n",
" #对所有句子分别求tfidf\n",
" X_tfidf = np.zeros((len(qlist),len(word2id)))\n",
" for i,q in enumerate(qlist):\n",
" X_tfidf[i] = computeSentenceEach(q,tfidf,word2id)\n",
" #print(X_tfidf[i])\n",
" return X_tfidf\n",
"\n",
"TF,word2id,id2word = computeTF(vocab_count,count)\n",
"print(len(word2id))\n",
"IDF = computeIDF(word2id,qlist)\n",
"#用TF,IDF计算最终的tf-idf\n",
"vectorizer = np.multiply(TF,IDF)# 定义一个tf-idf的vectorizer\n",
"#print(vectorizer)\n",
"X_tfidf = computeSentence(qlist,word2id,vectorizer) # 结果存放在X矩阵里\n",
"print(X_tfidf[0])\n",
"print(X_tfidf.shape)\n"
]
},
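Since sklearn is already listed in requirements.txt, the hand-rolled TF-IDF above can be cross-checked against sklearn's TfidfVectorizer; a minimal sketch (sklearn's weighting and normalization differ slightly, so the values will not match exactly):

    from sklearn.feature_extraction.text import TfidfVectorizer

    # \S+ keeps single-character tokens and the '#number' placeholder intact
    sk_vectorizer = TfidfVectorizer(token_pattern=r'\S+')
    X_tfidf_sk = sk_vectorizer.fit_transform(qlist)  # sparse matrix of shape (len(qlist), vocab size)
    print(X_tfidf_sk.shape)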
{
@@ -245,15 +450,72 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\gensim\\utils.py:1209: UserWarning: detected Windows; aliasing chunkize to chunkize_serial\n",
" warnings.warn(\"detected Windows; aliasing chunkize to chunkize_serial\")\n",
"c:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\smart_open\\smart_open_lib.py:410: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function\n",
" 'See the migration notes for details: %s' % _MIGRATION_NOTES_URL\n"
]
},
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'glove.6B.200d.txt'",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-17-38d1c61f36e2>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 29\u001B[0m \u001B[1;31m#print(X_w2v)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mX_w2v\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 31\u001B[1;33m \u001B[0memb\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mloadEmbedding\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m'glove.6B.200d.txt'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;31m# 这是 D*H的矩阵,这里的D是词典库的大小, H是词向量的大小。 这里面我们给定的每个单词的词向量,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 32\u001B[0m \u001B[1;31m# 这需要从文本中读取\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 33\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m<ipython-input-17-38d1c61f36e2>\u001B[0m in \u001B[0;36mloadEmbedding\u001B[1;34m(filename)\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[1;31m#加载glove模型,转化为word2vec,再加载word2vec模型\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 7\u001B[0m \u001B[0mword2vec_temp_file\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;34m'word2vec_temp.txt'\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mglove2word2vec\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfilename\u001B[0m\u001B[1;33m,\u001B[0m\u001B[0mword2vec_temp_file\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 9\u001B[0m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mKeyedVectors\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_word2vec_format\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mword2vec_temp_file\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 10\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\gensim\\scripts\\glove2word2vec.py\u001B[0m in \u001B[0;36mglove2word2vec\u001B[1;34m(glove_input_file, word2vec_output_file)\u001B[0m\n\u001B[0;32m 102\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 103\u001B[0m \"\"\"\n\u001B[1;32m--> 104\u001B[1;33m \u001B[0mnum_lines\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnum_dims\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mget_glove_info\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mglove_input_file\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 105\u001B[0m \u001B[0mlogger\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0minfo\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"converting %i vectors from %s to %s\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnum_lines\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mglove_input_file\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mword2vec_output_file\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 106\u001B[0m \u001B[1;32mwith\u001B[0m \u001B[0msmart_open\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mword2vec_output_file\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m'wb'\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0mfout\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\gensim\\scripts\\glove2word2vec.py\u001B[0m in \u001B[0;36mget_glove_info\u001B[1;34m(glove_file_name)\u001B[0m\n\u001B[0;32m 79\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 80\u001B[0m \"\"\"\n\u001B[1;32m---> 81\u001B[1;33m \u001B[1;32mwith\u001B[0m \u001B[0msmart_open\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mglove_file_name\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0mf\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 82\u001B[0m \u001B[0mnum_lines\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0msum\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;36m1\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0m_\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mf\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 83\u001B[0m \u001B[1;32mwith\u001B[0m \u001B[0msmart_open\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mglove_file_name\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0mf\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\smart_open\\smart_open_lib.py\u001B[0m in \u001B[0;36msmart_open\u001B[1;34m(uri, mode, **kw)\u001B[0m\n\u001B[0;32m 464\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 465\u001B[0m return open(uri, mode, ignore_ext=ignore_extension,\n\u001B[1;32m--> 466\u001B[1;33m transport_params=transport_params, **scrubbed_kwargs)\n\u001B[0m\u001B[0;32m 467\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 468\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\smart_open\\smart_open_lib.py\u001B[0m in \u001B[0;36mopen\u001B[1;34m(uri, mode, buffering, encoding, errors, newline, closefd, opener, ignore_ext, transport_params)\u001B[0m\n\u001B[0;32m 307\u001B[0m \u001B[0mbuffering\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mbuffering\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 308\u001B[0m \u001B[0mencoding\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mencoding\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 309\u001B[1;33m \u001B[0merrors\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0merrors\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 310\u001B[0m )\n\u001B[0;32m 311\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mfobj\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\smart_open\\smart_open_lib.py\u001B[0m in \u001B[0;36m_shortcut_open\u001B[1;34m(uri, mode, ignore_ext, buffering, encoding, errors)\u001B[0m\n\u001B[0;32m 523\u001B[0m \u001B[1;31m#\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 524\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0msix\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPY3\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 525\u001B[1;33m \u001B[1;32mreturn\u001B[0m \u001B[0m_builtin_open\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparsed_uri\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0muri_path\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mbuffering\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mbuffering\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mopen_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 526\u001B[0m \u001B[1;32melif\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[0mopen_kwargs\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 527\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_builtin_open\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparsed_uri\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0muri_path\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mbuffering\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mbuffering\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'glove.6B.200d.txt'"
]
}
],
"source": [
"# TODO 基于Glove向量获取句子向量\n",
"emb = # 这是 D*H的矩阵,这里的D是词典库的大小, H是词向量的大小。 这里面我们给定的每个单词的词向量,\n",
"from gensim.models import KeyedVectors\n",
"from gensim.scripts.glove2word2vec import glove2word2vec\n",
"\n",
"def loadEmbedding(filename):\n",
" #加载glove模型,转化为word2vec,再加载word2vec模型\n",
" word2vec_temp_file = 'word2vec_temp.txt'\n",
" glove2word2vec(filename,word2vec_temp_file)\n",
" model = KeyedVectors.load_word2vec_format(word2vec_temp_file)\n",
" return model\n",
"\n",
"def computeGloveSentenceEach(sentence,embedding):\n",
" #查找句子中每个词的embedding,将所有embedding进行加和求均值\n",
" emb = np.zeros(200)\n",
" words = sentence.strip().split(' ')\n",
" for w in words:\n",
" if w not in embedding:\n",
" #没有lookup的即为unknown\n",
" w = 'unknown'\n",
" #emb += embedding.get_vector(w)\n",
" emb += embedding[w]\n",
" return emb / len(words)\n",
"\n",
"def computeGloveSentence(qlist,embedding):\n",
" #对每一个句子进行求均值的embedding\n",
" X_w2v = np.zeros((len(qlist),200))\n",
" for i,q in enumerate(qlist):\n",
" X_w2v[i] = computeGloveSentenceEach(q,embedding)\n",
" #print(X_w2v)\n",
" return X_w2v\n",
"emb = loadEmbedding('glove.6B.200d.txt')# 这是 D*H的矩阵,这里的D是词典库的大小, H是词向量的大小。 这里面我们给定的每个单词的词向量,\n",
" # 这需要从文本中读取\n",
" \n",
"X_w2v = # 初始化完emb之后就可以对每一个句子来构建句子向量了,这个过程使用average pooling来实现\n"
"\n",
"X_w2v = computeGloveSentence(qlist,emb)# 初始化完emb之后就可以对每一个句子来构建句子向量了,这个过程使用average pooling来实现\n"
]
},
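The FileNotFoundError above only means glove.6B.200d.txt is missing from the working directory; it is part of the GloVe 6B archive from the Stanford NLP site and has to be downloaded separately. With gensim 4.x (an assumption about the installed version) the glove2word2vec conversion step can be skipped; a sketch:

    from gensim.models import KeyedVectors

    # gensim >= 4.0 reads the headerless GloVe text format directly
    emb = KeyedVectors.load_word2vec_format('glove.6B.200d.txt', binary=False, no_header=True)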
{
@@ -267,12 +529,39 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading C:\\Users\\Qing\\.mxnet\\models\\bert_12_768_12_wiki_multilingual_cased-b0f57a20.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_wiki_multilingual_cased-b0f57a20.zip...\n"
]
}
],
"source": [
"# TODO 基于BERT的句子向量计算\n",
"\n",
"X_bert = # 每一个句子的向量结果存放在X_bert矩阵里。行数为句子的总个数,列数为一个句子embedding大小。 "
"from bert_embedding import BertEmbedding\n",
"sentence_embedding = np.ones((len(qlist),768))\n",
"#加载Bert模型,model,dataset_name,须指定\n",
"bert_embedding = BertEmbedding(model='bert_12_768_12', dataset_name='wiki_multilingual_cased')\n",
"#查询所有句子的Bert embedding\n",
"#all_embedding = []\n",
"#for q in qlist:\n",
"# all_embedding.append(bert_embedding([q],'sum'))\n",
"all_embedding = bert_embedding(qlist,'sum')\n",
"for i in range(len(all_embedding)):\n",
" #print(all_embedding[i][1])\n",
" sentence_embedding[i] = np.sum(all_embedding[i][1],axis = 0) / len(q.strip().split(' '))\n",
" if i == 0:\n",
" print(sentence_embedding[i])\n",
"\n",
"X_bert = sentence_embedding # 每一个句子的向量结果存放在X_bert矩阵里。行数为句子的总个数,列数为一个句子embedding大小。"
]
},
{
@@ -306,6 +595,13 @@
"metadata": {},
"outputs": [],
"source": [
"import queue as Q\n",
"#优先级队列实现大顶堆Heap,每次输出都是相似度最大值\n",
"que = Q.PriorityQueue()\n",
"def cosineSimilarity(vec1,vec2):\n",
" #定义余弦相似度\n",
" return np.dot(vec1,vec2.T)/(np.sqrt(np.sum(vec1**2))*np.sqrt(np.sum(vec2**2)))\n",
"\n",
"def get_top_results_tfidf_noindex(query):\n",
" # TODO 需要编写\n",
" \"\"\"\n",
......@@ -314,11 +610,23 @@
" 2. 计算跟每个库里的问题之间的相似度\n",
" 3. 找出相似度最高的top5问题的答案\n",
" \"\"\"\n",
" \n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下标 \n",
" # hint: 请使用 priority queue来找出top results. 思考为什么可以这么做? \n",
" \n",
" return alist[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案 "
" top = 5\n",
" query_tfidf = computeSentenceEach(query.lower(),vectorizer,word2id)\n",
" for i,vec in enumerate(X_tfidf):\n",
" result = cosineSimilarity(vec,query_tfidf)\n",
" #print(result)\n",
" que.put((-1 * result,i))\n",
" i = 0\n",
"\n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下标\n",
" # hint: 请使用 priority queue来找出top results. 思考为什么可以这么做?\n",
" while(i < top and not que.empty()):\n",
" top_idxs.append(que.get()[1])\n",
" i += 1\n",
" print(top_idxs)\n",
" return np.array(alist)[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案\n",
"results = get_top_results_tfidf_noindex('In what city and state did Beyonce grow up')\n",
"print(results)"
]
},
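The priority-queue hint works because only the five best scores ever need to be ordered, not the whole corpus; heapq.nlargest expresses the same idea without a module-level queue. A sketch using the cosineSimilarity helper defined above:

    import heapq

    def top_k_by_cosine(query_vec, matrix, k=5):
        # keep only the k highest-scoring (similarity, index) pairs
        scored = ((cosineSimilarity(row, query_vec), i) for i, row in enumerate(matrix))
        return [i for _, i in heapq.nlargest(k, scored)]

    # top_idxs = top_k_by_cosine(query_tfidf, X_tfidf)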
{
@@ -349,12 +657,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# TODO 请创建倒排表\n",
"inverted_idx = {} # 定一个一个简单的倒排表,是一个map结构。 循环所有qlist一遍就可以"
"word_doc = dict()\n",
"#key:word,value:包含该词的句子序号的列表\n",
"for i,q in enumerate(qlist):\n",
" words = q.strip().split(' ')\n",
" for w in set(words):\n",
" if w not in word_doc:\n",
" #没在word_doc中的,建立一个空list\n",
" word_doc[w] = set([])\n",
" word_doc[w] = word_doc[w] | set([i])\n",
"inverted_idx = word_doc # 定一个一个简单的倒排表,是一个map结构。 循环所有qlist一遍就可以\n"
]
},
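A quick sanity check on the inverted index: intersect the posting sets of a query's tokens to see how many candidate questions survive. A sketch with illustrative tokens:

    query_tokens = ['when', 'did', 'beyonce']  # illustrative query tokens
    postings = [inverted_idx[w] for w in query_tokens if w in inverted_idx]
    candidates = set.intersection(*postings) if postings else set()
    print(len(candidates))  # number of questions containing every indexed token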
{
@@ -371,14 +688,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# TODO 读取语义相关的单词\n",
"def get_related_words(file):\n",
" \n",
" return related_words\n",
" return []\n",
"\n",
"related_words = get_related_words('related_words.txt') # 直接放在文件夹的根目录下,不要修改此路径。"
]
@@ -398,27 +715,61 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import queue as Q\n",
"def cosineSimilarity(vec1,vec2):\n",
" #定义余弦相似度\n",
" return np.dot(vec1,vec2.T)/(np.sqrt(np.sum(vec1**2))*np.sqrt(np.sum(vec2**2)))\n",
"def getCandidate(query):\n",
" #根据查询句子中每个词所在的序号列表,求交集\n",
" searched = set()\n",
" for w in query.strip().split(' '):\n",
" if w not in word2id or w not in inverted_idx:\n",
" continue\n",
" #搜索原词所在的序号列表\n",
" if len(searched) == 0:\n",
" searched = set(inverted_idx[w])\n",
" else:\n",
" searched = searched & set(inverted_idx[w])\n",
" #搜索相似词所在的列表\n",
" if w in related_words:\n",
" for similar in related_words[w]:\n",
" searched = searched & set(inverted_idx[similar])\n",
" return searched\n",
"\n",
"def get_top_results_tfidf(query):\n",
" \"\"\"\n",
" 给定用户输入的问题 query, 返回最有可能的TOP 5问题。这里面需要做到以下几点:\n",
" 1. 利用倒排表来筛选 candidate (需要使用related_words). \n",
" 1. 利用倒排表来筛选 candidate (需要使用related_words).\n",
" 2. 对于候选文档,计算跟输入问题之间的相似度\n",
" 3. 找出相似度最高的top5问题的答案\n",
" \"\"\"\n",
" \n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下表 \n",
" # hint: 利用priority queue来找出top results. 思考为什么可以这么做? \n",
" \n",
" return alist[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案"
" top = 5\n",
" query_tfidf = computeSentenceEach(query,vectorizer,word2id)\n",
" results = Q.PriorityQueue()\n",
" searched = getCandidate(query)\n",
" #print(len(searched))\n",
" for candidate in searched:\n",
" #计算candidate与query的余弦相似度\n",
" result = cosineSimilarity(query_tfidf,X_tfidf[candidate])\n",
" #优先级队列中保存相似度和对应的candidate序号\n",
" #-1保证降序\n",
" results.put((-1 * result,candidate))\n",
" i = 0\n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下表\n",
" # hint: 利用priority queue来找出top results. 思考为什么可以这么做?\n",
" while i < top and not results.empty():\n",
" top_idxs.append(results.get()[1])\n",
" i += 1\n",
" return np.array(alist)[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -429,16 +780,25 @@
"    2. compute the similarity between the query and every candidate document\n",
"    3. return the answers of the top-5 most similar questions\n",
"    \"\"\"\n",
"    \n",
"    top_idxs = [] # top_idxs holds the indices (into qlist) of the most similar questions \n",
"    # hint: use a priority queue to find the top results. Think about why this works. \n",
"    \n",
"    return alist[top_idxs] # return the answers of the most similar questions as the TOP-5 answers"
" top = 5\n",
" query_emb = computeGloveSentenceEach(query,emb)\n",
" results = Q.PriorityQueue()\n",
" searched = getCandidate(query)\n",
" for candidate in searched:\n",
" result = cosineSimilarity(query_emb,X_w2v[candidate])\n",
" results.put((-1 * result,candidate))\n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下表\n",
" # hint: 利用priority queue来找出top results. 思考为什么可以这么做?\n",
" i = 0\n",
" while i < top and not results.empty():\n",
" top_idxs.append(results.get()[1])\n",
" i += 1\n",
" return np.array(alist)[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -450,30 +810,51 @@
"    3. return the answers of the top-5 most similar questions\n",
"    \"\"\"\n",
"    \n",
"    top_idxs = [] # top_idxs holds the indices (into qlist) of the most similar questions \n",
"    # hint: use a priority queue to find the top results. Think about why this works. \n",
"    \n",
"    return alist[top_idxs] # return the answers of the most similar questions as the TOP-5 answers"
" top = 5\n",
" query_emb = np.sum(bert_embedding([query],'sum')[0][1],axis = 0) / len(query.strip().split())\n",
" results = Q.PriorityQueue()\n",
" searched = getCandidate(query)\n",
" for candidate in searched:\n",
" result = cosineSimilarity(query_emb,X_bert[candidate])\n",
" #print(result)\n",
" results.put((-1 * result,candidate))\n",
" top_idxs = [] # top_idxs存放相似度最高的(存在qlist里的)问题的下表\n",
" # hint: 利用priority queue来找出top results. 思考为什么可以这么做?\n",
" i = 0\n",
" while i < top and not results.empty():\n",
" top_idxs.append(results.get()[1])\n",
" i += 1\n",
"\n",
" return np.array(alist)[top_idxs] # 返回相似度最高的问题对应的答案,作为TOP5答案"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['in the late 1990s']\n",
"['births and deaths']\n"
]
}
],
"source": [
"# TODO: 编写几个测试用例,并输出结果\n",
"\n",
"test_query1 = \"\"\n",
"test_query2 = \"\"\n",
"test_query1 = \"When did Beyonce start becoming popular\"\n",
"test_query2 = \"What counted for more of the population change\"\n",
"\n",
"print (get_top_results_tfidf(test_query1))\n",
"print (get_top_results_w2v(test_query1))\n",
"print (get_top_results_bert(test_query1))\n",
"#print (get_top_results_w2v(test_query1))\n",
"#print (get_top_results_bert(test_query1))\n",
"\n",
"print (get_top_results_tfidf(test_query2))\n",
"print (get_top_results_w2v(test_query2))\n",
"print (get_top_results_bert(test_query2))"
"#print (get_top_results_w2v(test_query2))\n",
"#print (get_top_results_bert(test_query2))"
]
},
{
@@ -509,9 +890,30 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "LookupError",
"evalue": "\n**********************************************************************\n Resource \u001B[93mreuters\u001B[0m not found.\n Please use the NLTK Downloader to obtain the resource:\n\n \u001B[31m>>> import nltk\n >>> nltk.download('reuters')\n \u001B[0m\n For more information see: https://www.nltk.org/data.html\n\n Attempted to load \u001B[93mcorpora/reuters\u001B[0m\n\n Searched in:\n - 'C:\\\\Users\\\\Qing/nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\share\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\lib\\\\nltk_data'\n - 'C:\\\\Users\\\\Qing\\\\AppData\\\\Roaming\\\\nltk_data'\n - 'C:\\\\nltk_data'\n - 'D:\\\\nltk_data'\n - 'E:\\\\nltk_data'\n**********************************************************************\n",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mLookupError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\corpus\\util.py\u001B[0m in \u001B[0;36m__load\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 82\u001B[0m \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 83\u001B[1;33m \u001B[0mroot\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mnltk\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdata\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mfind\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"{}/{}\"\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msubdir\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mzip_name\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 84\u001B[0m \u001B[1;32mexcept\u001B[0m \u001B[0mLookupError\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\data.py\u001B[0m in \u001B[0;36mfind\u001B[1;34m(resource_name, paths)\u001B[0m\n\u001B[0;32m 584\u001B[0m \u001B[0mresource_not_found\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;34m\"\\n%s\\n%s\\n%s\\n\"\u001B[0m \u001B[1;33m%\u001B[0m \u001B[1;33m(\u001B[0m\u001B[0msep\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmsg\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0msep\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 585\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mLookupError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mresource_not_found\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 586\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mLookupError\u001B[0m: \n**********************************************************************\n Resource \u001B[93mreuters\u001B[0m not found.\n Please use the NLTK Downloader to obtain the resource:\n\n \u001B[31m>>> import nltk\n >>> nltk.download('reuters')\n \u001B[0m\n For more information see: https://www.nltk.org/data.html\n\n Attempted to load \u001B[93mcorpora/reuters.zip/reuters/\u001B[0m\n\n Searched in:\n - 'C:\\\\Users\\\\Qing/nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\share\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\lib\\\\nltk_data'\n - 'C:\\\\Users\\\\Qing\\\\AppData\\\\Roaming\\\\nltk_data'\n - 'C:\\\\nltk_data'\n - 'D:\\\\nltk_data'\n - 'E:\\\\nltk_data'\n**********************************************************************\n",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001B[1;31mLookupError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-17-da502f8a2af8>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 2\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 3\u001B[0m \u001B[1;31m# 读取语料库的数据\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 4\u001B[1;33m \u001B[0mcategories\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mreuters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mcategories\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 5\u001B[0m \u001B[0mcorpus\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mreuters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msents\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcategories\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mcategories\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\corpus\\util.py\u001B[0m in \u001B[0;36m__getattr__\u001B[1;34m(self, attr)\u001B[0m\n\u001B[0;32m 118\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mAttributeError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"LazyCorpusLoader object has no attribute '__bases__'\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 119\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 120\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__load\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 121\u001B[0m \u001B[1;31m# This looks circular, but its not, since __load() changes our\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 122\u001B[0m \u001B[1;31m# __class__ to something new:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\corpus\\util.py\u001B[0m in \u001B[0;36m__load\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 83\u001B[0m \u001B[0mroot\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mnltk\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdata\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mfind\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"{}/{}\"\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msubdir\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mzip_name\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 84\u001B[0m \u001B[1;32mexcept\u001B[0m \u001B[0mLookupError\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 85\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 86\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 87\u001B[0m \u001B[1;31m# Load the corpus.\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\corpus\\util.py\u001B[0m in \u001B[0;36m__load\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 78\u001B[0m \u001B[1;32melse\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 79\u001B[0m \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 80\u001B[1;33m \u001B[0mroot\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mnltk\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdata\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mfind\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"{}/{}\"\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msubdir\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__name\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 81\u001B[0m \u001B[1;32mexcept\u001B[0m \u001B[0mLookupError\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 82\u001B[0m \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\qing\\appdata\\local\\conda\\conda\\envs\\py36\\lib\\site-packages\\nltk\\data.py\u001B[0m in \u001B[0;36mfind\u001B[1;34m(resource_name, paths)\u001B[0m\n\u001B[0;32m 583\u001B[0m \u001B[0msep\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;34m\"*\"\u001B[0m \u001B[1;33m*\u001B[0m \u001B[1;36m70\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 584\u001B[0m \u001B[0mresource_not_found\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;34m\"\\n%s\\n%s\\n%s\\n\"\u001B[0m \u001B[1;33m%\u001B[0m \u001B[1;33m(\u001B[0m\u001B[0msep\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmsg\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0msep\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 585\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mLookupError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mresource_not_found\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 586\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 587\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mLookupError\u001B[0m: \n**********************************************************************\n Resource \u001B[93mreuters\u001B[0m not found.\n Please use the NLTK Downloader to obtain the resource:\n\n \u001B[31m>>> import nltk\n >>> nltk.download('reuters')\n \u001B[0m\n For more information see: https://www.nltk.org/data.html\n\n Attempted to load \u001B[93mcorpora/reuters\u001B[0m\n\n Searched in:\n - 'C:\\\\Users\\\\Qing/nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\share\\\\nltk_data'\n - 'c:\\\\users\\\\qing\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\py36\\\\lib\\\\nltk_data'\n - 'C:\\\\Users\\\\Qing\\\\AppData\\\\Roaming\\\\nltk_data'\n - 'C:\\\\nltk_data'\n - 'D:\\\\nltk_data'\n - 'E:\\\\nltk_data'\n**********************************************************************\n"
]
}
],
"source": [
"from nltk.corpus import reuters\n",
"\n",
@@ -520,7 +922,54 @@
"corpus = reuters.sents(categories=categories)\n",
"\n",
"# Loop over the whole corpus and build the bigram probabilities. bigram[word1][word2]: the probability that word2 follows word1. \n",
"\n",
"from nltk.corpus import reuters\n",
"import numpy as np\n",
"import codecs\n",
"# 读取语料库的数据\n",
"categories = reuters.categories()\n",
"corpus = reuters.sents(categories=categories)\n",
"#print(corpus[0])\n",
"# 循环所有的语料库并构建bigram probability. bigram[word1][word2]: 在word1出现的情况下下一个是word2的概率。\n",
"new_corpus = []\n",
"for sent in corpus:\n",
" #句子前后加入<s>,</s>表示开始和结束\n",
" new_corpus.append(['<s> '] + sent + [' </s>'])\n",
"print(new_corpus[0])\n",
"word2id = dict()\n",
"id2word = dict()\n",
"for sent in new_corpus:\n",
" for w in sent:\n",
" w = w.lower()\n",
" if w in word2id:\n",
" continue\n",
" id2word[len(word2id)] = w\n",
" word2id[w] = len(word2id)\n",
"vocab_size = len(word2id)\n",
"count_uni = np.zeros(vocab_size)\n",
"count_bi = np.zeros((vocab_size,vocab_size))\n",
"#writeVocab(word2id,\"lm_vocab.txt\")\n",
"for sent in new_corpus:\n",
" for i,w in enumerate(sent):\n",
" w = w.lower()\n",
" count_uni[word2id[w]] += 1\n",
" if i < len(sent) - 1:\n",
" count_bi[word2id[w],word2id[sent[i + 1].lower()]] += 1\n",
"print(\"unigram done\")\n",
"bigram = np.zeros((vocab_size,vocab_size))\n",
"#计算bigram LM,有bigram统计值的加一除以|vocab|+uni统计值,没有统计值,\n",
"#1 除以 |vocab|+uni统计值\n",
"for i in range(vocab_size):\n",
" for j in range(vocab_size):\n",
" if count_bi[i,j] == 0:\n",
" bigram[i,j] = 1.0 / (vocab_size + count_uni[i])\n",
" else:\n",
" bigram[i,j] = (1.0 + count_bi[i,j]) / (vocab_size + count_uni[i])\n",
"def checkLM(word1,word2):\n",
" if word1.lower() in word2id and word2.lower() in word2id:\n",
" return bigram[word2id[word1.lower()],word2id[word2.lower()]]\n",
" else:\n",
" return 0.0\n",
"print(checkLM('I','like'))\n",
"\n"
]
},
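The smoothing used here is add-one (Laplace) smoothing on the bigram counts: P(w2 | w1) = (count(w1, w2) + 1) / (count(w1) + |V|), and 1 / (count(w1) + |V|) for unseen pairs, which is the same formula with count(w1, w2) = 0. For example, with |V| = 4 and count(w1) = 3, an unseen bigram gets probability 1/7 instead of 0.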
@@ -540,13 +989,17 @@
"source": [
"# TODO 构建channel probability \n",
"channel = {}\n",
"\n",
"#读取文件,格式为w1:w2,w3..\n",
"#w1为正确词,w2,w3...为错误词\n",
"#没有给出不同w2-wn的概率,暂时按等概率处理\n",
"for line in open('spell-errors.txt'):\n",
" # TODO\n",
"\n",
"# TODO\n",
"\n",
"print(channel) "
" (correct,error) = line.strip().split(':')\n",
" errors = error.split(',')\n",
" errorProb = dict()\n",
" for e in errors:\n",
" errorProb[e.strip()] = 1.0 / len(errors)\n",
" channel[correct.strip()] = errorProb"
]
},
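For reference, the parsing above assumes each line of spell-errors.txt maps a correct word to a comma-separated list of misspellings; a worked example with illustrative words (not taken from the actual file):

    # hypothetical input line:
    #   looking: loking, luing, look, locking
    # the loop above then produces equal probabilities over the listed misspellings:
    #   channel['looking'] == {'loking': 0.25, 'luing': 0.25, 'look': 0.25, 'locking': 0.25}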
{
@@ -563,9 +1016,51 @@
"metadata": {},
"outputs": [],
"source": [
"def filter(words):\n",
" #将不在词表中的词过滤\n",
" new_words = []\n",
" for w in words:\n",
" if w in word2id:\n",
" new_words.append(w)\n",
" return set(new_words)\n",
"\n",
"def generate_candidates1(word):\n",
" #生成DTW距离为1的词,\n",
" #对于英语来说,插入,替换,删除26个字母\n",
" chars = 'abcdefghijklmnopqrstuvwxyz'\n",
" words = set([])\n",
" #insert 1\n",
" words = set(word[0:i] + chars[j] + word[i:] for i in range(len(word)) for j in range(len(chars)))\n",
" #sub 1\n",
" words = words | set(word[0:i] + chars[j] + word[i+1:] for i in range(len(word)) for j in range(len(chars)))\n",
" #delete 1\n",
" words = words | set(word[0:i] + word[i + 1:] for i in range(len(chars)))\n",
" #交换相邻\n",
" #print(set(word[0:i - 1] + word[i] + word[i - 1] + word[i + 1:] for i in range(1,len(word))))\n",
" words = words | set(word[0:i - 1] + word[i] + word[i - 1] + word[i + 1:] for i in range(1,len(word)))\n",
" #将不在词表中的词去掉\n",
" words = filter(words)\n",
" #去掉word本身\n",
" if word in words:\n",
" words.remove(word)\n",
" return words\n",
"\n",
"def generate_candidates(word):\n",
" # 基于拼写错误的单词,生成跟它的编辑距离为1或者2的单词,并通过词典库的过滤。\n",
" # 只留写法上正确的单词。 \n",
" # 只留写法上正确的单词。\n",
" words = generate_candidates1(word)\n",
" words2 = set([])\n",
" for word in words:\n",
" #将距离为1词,再分别计算距离为1的词,\n",
" #作为距离为2的词候选\n",
" words2 = generate_candidates1(word)\n",
" #过滤掉不在词表中的词\n",
" words2 = filter(words)\n",
" #距离为1,2的词合并列表\n",
" words = words | words2\n",
" return words\n",
"words = generate_candidates('strat')\n",
"print(words)\n",
" \n",
" \n"
]
@@ -586,13 +1081,42 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import queue as Q\n",
"def word_corrector(word,context):\n",
" word = word.lower()\n",
" candidate = generate_candidates(word)\n",
" if len(candidate) == 0:\n",
" return word\n",
" correctors = Q.PriorityQueue()\n",
" for w in candidate:\n",
" if w in channel and word in channel[w] and w in word2id and context[0].lower() in word2id and context[1].lower() in word2id:\n",
" probility = np.log(channel[w][word] + 0.0001) + np.log(bigram[word2id[context[0].lower()],word2id[w]]) + np.log(bigram[word2id[context[1].lower()],word2id[w]])\n",
" correctors.put((-1 * probility,w))\n",
" if correctors.empty():\n",
" return word\n",
" return correctors.get()[1]\n",
"word = word_corrector('strat',('to','in'))\n",
"print(word)\n",
"def spell_corrector(line):\n",
" # 1. 首先做分词,然后把``line``表示成``tokens``\n",
" # 2. 循环每一token, 然后判断是否存在词库里。如果不存在就意味着是拼写错误的,需要修正。 \n",
" # 修正的过程就使用上述提到的``noisy channel model``, 然后从而找出最好的修正之后的结果。 \n",
" \n",
" return newline # 修正之后的结果,假如用户输入没有问题,那这时候``newline = line``\n"
" # 2. 循环每一token, 然后判断是否存在词库里。如果不存在就意味着是拼写错误的,需要修正。\n",
" # 修正的过程就使用上述提到的``noisy channel model``, 然后从而找出最好的修正之后的结果。\n",
" new_words = []\n",
" words = ['<s>'] + line.strip().lower().split(' ') + ['</s>']\n",
" for i,word in enumerate(words):\n",
" if i == len(words) - 1:\n",
" break\n",
" word = word.lower()\n",
" if word not in word2id:\n",
" #认为错误,需要修正,句子前后加了<s>,</s>\n",
" #不在词表中词,肯定位于[1,len - 2]之间\n",
" new_words.append(word_corrector(word,(words[i - 1].lower(),words[i + 1].lower())))\n",
" else:\n",
" new_words.append(word)\n",
" newline = ' '.join(new_words[1:])\n",
" return newline # 修正之后的结果,假如用户输入没有问题,那这时候``newline = line``\n",
"sentence = spell_corrector('When did Beyonce strat becoming popular')\n",
"print(sentence)\n"
]
},
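End to end, the intended flow is to run the spelling correction first and then retrieve on the corrected query; a usage sketch combining the pieces defined above. Note that the bigram cell reuses the names word2id and id2word for the Reuters vocabulary, so in a fresh kernel the retrieval cells need to run before it (or those names should be renamed):

    raw_query = 'When did Beyonce strat becoming popular'
    corrected = spell_corrector(raw_query)      # fix out-of-vocabulary tokens first
    answers = get_top_results_tfidf(corrected)  # then retrieve the top-5 answers
    print(corrected)
    print(answers)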
{