{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 搭建一个简单的问答系统 (Building a Simple QA System)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本次项目的目标是搭建一个基于检索式的简易的问答系统,这是一个最经典的方法也是最有效的方法。  \n",
    "\n",
    "```不要单独创建一个文件,所有的都在这里面编写,不要试图改已经有的函数名字 (但可以根据需求自己定义新的函数)```\n",
    "\n",
    "```预估完成时间```: 5-10小时"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 检索式的问答系统\n",
    "问答系统所需要的数据已经提供,对于每一个问题都可以找得到相应的答案,所以可以理解为每一个样本数据是 ``<问题、答案>``。 那系统的核心是当用户输入一个问题的时候,首先要找到跟这个问题最相近的已经存储在库里的问题,然后直接返回相应的答案即可(但实际上也可以抽取其中的实体或者关键词)。 举一个简单的例子:\n",
    "\n",
    "假设我们的库里面已有存在以下几个<问题,答案>:\n",
    "- <\"贪心学院主要做什么方面的业务?”, “他们主要做人工智能方面的教育”>\n",
    "- <“国内有哪些做人工智能教育的公司?”, “贪心学院”>\n",
    "- <\"人工智能和机器学习的关系什么?\", \"其实机器学习是人工智能的一个范畴,很多人工智能的应用要基于机器学习的技术\">\n",
    "- <\"人工智能最核心的语言是什么?\", ”Python“>\n",
    "- .....\n",
    "\n",
    "假设一个用户往系统中输入了问题 “贪心学院是做什么的?”, 那这时候系统先去匹配最相近的“已经存在库里的”问题。 那在这里很显然是 “贪心学院是做什么的”和“贪心学院主要做什么方面的业务?”是最相近的。 所以当我们定位到这个问题之后,直接返回它的答案 “他们主要做人工智能方面的教育”就可以了。 所以这里的核心问题可以归结为计算两个问句(query)之间的相似度。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 项目中涉及到的任务描述\n",
    "问答系统看似简单,但其中涉及到的内容比较多。 在这里先做一个简单的解释,总体来讲,我们即将要搭建的模块包括:\n",
    "\n",
    "- 文本的读取: 需要从相应的文件里读取```(问题,答案)```\n",
    "- 文本预处理: 清洗文本很重要,需要涉及到```停用词过滤```等工作\n",
    "- 文本的表示: 如果表示一个句子是非常核心的问题,这里会涉及到```tf-idf```, ```Glove```以及```BERT Embedding```\n",
    "- 文本相似度匹配: 在基于检索式系统中一个核心的部分是计算文本之间的```相似度```,从而选择相似度最高的问题然后返回这些问题的答案\n",
    "- 倒排表: 为了加速搜索速度,我们需要设计```倒排表```来存储每一个词与出现的文本\n",
    "- 词义匹配:直接使用倒排表会忽略到一些意思上相近但不完全一样的单词,我们需要做这部分的处理。我们需要提前构建好```相似的单词```然后搜索阶段使用\n",
    "- 拼写纠错:我们不能保证用户输入的准确,所以第一步需要做用户输入检查,如果发现用户拼错了,我们需要及时在后台改正,然后按照修改后的在库里面搜索\n",
    "- 文档的排序: 最后返回结果的排序根据文档之间```余弦相似度```有关,同时也跟倒排表中匹配的单词有关\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 项目中需要的数据:\n",
    "1. ```dev-v2.0.json```: 这个数据包含了问题和答案的pair, 但是以JSON格式存在,需要编写parser来提取出里面的问题和答案。 \n",
    "2. ```glove.6B```: 这个文件需要从网上下载,下载地址为:https://nlp.stanford.edu/projects/glove/, 请使用d=200的词向量\n",
    "3. ```spell-errors.txt``` 这个文件主要用来编写拼写纠错模块。 文件中第一列为正确的单词,之后列出来的单词都是常见的错误写法。 但这里需要注意的一点是我们没有给出他们之间的概率,也就是p(错误|正确),所以我们可以认为每一种类型的错误都是```同等概率```\n",
    "4. ```vocab.txt``` 这里列了几万个英文常见的单词,可以用这个词库来验证是否有些单词被拼错\n",
    "5. ```testdata.txt``` 这里搜集了一些测试数据,可以用来测试自己的spell corrector。这个文件只是用来测试自己的程序。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在本次项目中,你将会用到以下几个工具:\n",
    "- ```sklearn```。具体安装请见:http://scikit-learn.org/stable/install.html  sklearn包含了各类机器学习算法和数据处理工具,包括本项目需要使用的词袋模型,均可以在sklearn工具包中找得到。 \n",
    "- ```jieba```,用来做分词。具体使用方法请见 https://github.com/fxsjy/jieba\n",
    "- ```bert embedding```: https://github.com/imgarylai/bert-embedding\n",
    "- ```nltk```:https://www.nltk.org/index.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 第一部分:对于训练数据的处理:读取文件和预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- ```文本的读取```: 需要从文本中读取数据,此处需要读取的文件是```dev-v2.0.json```,并把读取的文件存入一个列表里(list)\n",
    "- ```文本预处理```: 对于问题本身需要做一些停用词过滤等文本方面的处理\n",
    "- ```可视化分析```: 对于给定的样本数据,做一些可视化分析来更好地理解数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1.1节: 文本的读取\n",
    "把给定的文本数据读入到```qlist```和```alist```当中,这两个分别是列表,其中```qlist```是问题的列表,```alist```是对应的答案列表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Beyoncé\n",
      "Frédéric_Chopin\n",
      "Sino-Tibetan_relations_during_the_Ming_dynasty\n",
      "IPod\n",
      "The_Legend_of_Zelda:_Twilight_Princess\n",
      "Spectre_(2015_film)\n",
      "2008_Sichuan_earthquake\n",
      "New_York_City\n",
      "To_Kill_a_Mockingbird\n",
      "Solar_energy\n",
      "Kanye_West\n",
      "Buddhism\n",
      "American_Idol\n",
      "Dog\n",
      "2008_Summer_Olympics_torch_relay\n",
      "Genome\n",
      "Comprehensive_school\n",
      "Republic_of_the_Congo\n",
      "Prime_minister\n",
      "Institute_of_technology\n",
      "Wayback_Machine\n",
      "Dutch_Republic\n",
      "Symbiosis\n",
      "Canadian_Armed_Forces\n",
      "Cardinal_(Catholicism)\n",
      "Iranian_languages\n",
      "Lighting\n",
      "Separation_of_powers_under_the_United_States_Constitution\n",
      "Architecture\n",
      "Human_Development_Index\n",
      "Southern_Europe\n",
      "BBC_Television\n",
      "Arnold_Schwarzenegger\n",
      "Plymouth\n",
      "Heresy\n",
      "Warsaw_Pact\n",
      "Materialism\n",
      "Christian\n",
      "Sony_Music_Entertainment\n",
      "Oklahoma_City\n",
      "Hunter-gatherer\n",
      "United_Nations_Population_Fund\n",
      "Russian_Soviet_Federative_Socialist_Republic\n",
      "Alexander_Graham_Bell\n",
      "Pub\n",
      "Internet_service_provider\n",
      "Comics\n",
      "Saint_Helena\n",
      "Aspirated_consonant\n",
      "Hydrogen\n",
      "Space_Race\n",
      "Web_browser\n",
      "BeiDou_Navigation_Satellite_System\n",
      "Canon_law\n",
      "Communications_in_Somalia\n",
      "Catalan_language\n",
      "Boston\n",
      "Universal_Studios\n",
      "Estonian_language\n",
      "Paper\n",
      "Adult_contemporary_music\n",
      "Daylight_saving_time\n",
      "Royal_Institute_of_British_Architects\n",
      "National_Archives_and_Records_Administration\n",
      "Tristan_da_Cunha\n",
      "University_of_Kansas\n",
      "Nanjing\n",
      "Arena_Football_League\n",
      "Dialect\n",
      "Bern\n",
      "Westminster_Abbey\n",
      "Political_corruption\n",
      "Classical_music\n",
      "Slavs\n",
      "Southampton\n",
      "Treaty\n",
      "Josip_Broz_Tito\n",
      "Marshall_Islands\n",
      "Szlachta\n",
      "Virgil\n",
      "Alps\n",
      "Gene\n",
      "Guinea-Bissau\n",
      "List_of_numbered_streets_in_Manhattan\n",
      "Brain\n",
      "Near_East\n",
      "Zhejiang\n",
      "Ministry_of_Defence_(United_Kingdom)\n",
      "High-definition_television\n",
      "Wood\n",
      "Somalis\n",
      "Middle_Ages\n",
      "Phonology\n",
      "Computer\n",
      "Black_people\n",
      "The_Times\n",
      "New_Delhi\n",
      "Bird_migration\n",
      "Atlantic_City,_New_Jersey\n",
      "Immunology\n",
      "MP3\n",
      "House_music\n",
      "Letter_case\n",
      "Chihuahua_(state)\n",
      "Imamah_(Shia_doctrine)\n",
      "Pitch_(music)\n",
      "England_national_football_team\n",
      "Houston\n",
      "Copper\n",
      "Identity_(social_science)\n",
      "Himachal_Pradesh\n",
      "Communication\n",
      "Grape\n",
      "Computer_security\n",
      "Orthodox_Judaism\n",
      "Animal\n",
      "Beer\n",
      "Race_and_ethnicity_in_the_United_States_Census\n",
      "United_States_dollar\n",
      "Imperial_College_London\n",
      "Hanover\n",
      "Emotion\n",
      "Everton_F.C.\n",
      "Old_English\n",
      "Aircraft_carrier\n",
      "Federal_Aviation_Administration\n",
      "Lancashire\n",
      "Mesozoic\n",
      "Videoconferencing\n",
      "Gregorian_calendar\n",
      "Xbox_360\n",
      "Military_history_of_the_United_States\n",
      "Hard_rock\n",
      "Great_Plains\n",
      "Infrared\n",
      "Biodiversity\n",
      "ASCII\n",
      "Digestion\n",
      "Gymnastics\n",
      "FC_Barcelona\n",
      "Federal_Bureau_of_Investigation\n",
      "Mary_(mother_of_Jesus)\n",
      "Melbourne\n",
      "John,_King_of_England\n",
      "Macintosh\n",
      "Anti-aircraft_warfare\n",
      "Sanskrit\n",
      "Valencia\n",
      "General_Electric\n",
      "United_States_Army\n",
      "Franco-Prussian_War\n",
      "Adolescence\n",
      "Antarctica\n",
      "Eritrea\n",
      "Uranium\n",
      "Order_of_the_British_Empire\n",
      "Circadian_rhythm\n",
      "Elizabeth_II\n",
      "Sexual_orientation\n",
      "Dell\n",
      "Capital_punishment_in_the_United_States\n",
      "Age_of_Enlightenment\n",
      "Nintendo_Entertainment_System\n",
      "Athanasius_of_Alexandria\n",
      "Seattle\n",
      "Memory\n",
      "Multiracial_American\n",
      "Ashkenazi_Jews\n",
      "Pharmaceutical_industry\n",
      "Umayyad_Caliphate\n",
      "Asphalt\n",
      "Queen_Victoria\n",
      "Freemasonry\n",
      "Israel\n",
      "Hellenistic_period\n",
      "Bill_%26_Melinda_Gates_Foundation\n",
      "Montevideo\n",
      "Poultry\n",
      "Dutch_language\n",
      "Buckingham_Palace\n",
      "Incandescent_light_bulb\n",
      "Arsenal_F.C.\n",
      "Clothing\n",
      "Chicago_Cubs\n",
      "Korean_War\n",
      "Copyright_infringement\n",
      "Greece\n",
      "Royal_Dutch_Shell\n",
      "Mammal\n",
      "East_India_Company\n",
      "Hokkien\n",
      "Professional_wrestling\n",
      "Film_speed\n",
      "Mexico_City\n",
      "Napoleon\n",
      "Germans\n",
      "Southeast_Asia\n",
      "Brigham_Young_University\n",
      "Department_store\n",
      "Intellectual_property\n",
      "Florida\n",
      "Queen_(band)\n",
      "Presbyterianism\n",
      "Thuringia\n",
      "Predation\n",
      "Marvel_Comics\n",
      "British_Empire\n",
      "Botany\n",
      "Madonna_(entertainer)\n",
      "Law_of_the_United_States\n",
      "Myanmar\n",
      "Jews\n",
      "Cotton\n",
      "Data_compression\n",
      "The_Sun_(United_Kingdom)\n",
      "Pesticide\n",
      "Somerset\n",
      "Yale_University\n",
      "Late_Middle_Ages\n",
      "Ann_Arbor,_Michigan\n",
      "Gothic_architecture\n",
      "Cubism\n",
      "Political_philosophy\n",
      "Alloy\n",
      "Norfolk_Island\n",
      "Edmund_Burke\n",
      "Samoa\n",
      "Pope_Paul_VI\n",
      "Electric_motor\n",
      "Switzerland\n",
      "Mali\n",
      "Raleigh,_North_Carolina\n",
      "Nutrition\n",
      "Crimean_War\n",
      "Nonprofit_organization\n",
      "Literature\n",
      "Avicenna\n",
      "Chinese_characters\n",
      "Bermuda\n",
      "Nigeria\n",
      "Utrecht\n",
      "Molotov%E2%80%93Ribbentrop_Pact\n",
      "Capacitor\n",
      "History_of_science\n",
      "Digimon\n",
      "Glacier\n",
      "Comcast\n",
      "Tuberculosis\n",
      "Affirmative_action_in_the_United_States\n",
      "FA_Cup\n",
      "New_Haven,_Connecticut\n",
      "Alsace\n",
      "Carnival\n",
      "Baptists\n",
      "Child_labour\n",
      "North_Carolina\n",
      "Heian_period\n",
      "On_the_Origin_of_Species\n",
      "Dissolution_of_the_Soviet_Union\n",
      "Crucifixion_of_Jesus\n",
      "Supreme_court\n",
      "Textual_criticism\n",
      "Gramophone_record\n",
      "Turner_Classic_Movies\n",
      "Hindu_philosophy\n",
      "Political_party\n",
      "A_cappella\n",
      "Dominican_Order\n",
      "Eton_College\n",
      "Cork_(city)\n",
      "Galicia_(Spain)\n",
      "USB\n",
      "Sichuan\n",
      "Unicode\n",
      "Detroit\n",
      "London\n",
      "Culture\n",
      "Sahara\n",
      "Rule_of_law\n",
      "Tibet\n",
      "Exhibition_game\n",
      "Northwestern_University\n",
      "Strasbourg\n",
      "Oklahoma\n",
      "History_of_India\n",
      "Gamal_Abdel_Nasser\n",
      "Pope_John_XXIII\n",
      "Time\n",
      "European_Central_Bank\n",
      "St._John%27s,_Newfoundland_and_Labrador\n",
      "John_von_Neumann\n",
      "PlayStation_3\n",
      "Royal_assent\n",
      "Group_(mathematics)\n",
      "Central_African_Republic\n",
      "Asthma\n",
      "LaserDisc\n",
      "George_VI\n",
      "Federalism\n",
      "Annelid\n",
      "God\n",
      "War_on_Terror\n",
      "Labour_Party_(UK)\n",
      "Estonia\n",
      "Alaska\n",
      "Karl_Popper\n",
      "Mandolin\n",
      "Insect\n",
      "Race_(human_categorization)\n",
      "Paris\n",
      "Apollo\n",
      "United_States_presidential_election,_2004\n",
      "Liberal_Party_of_Australia\n",
      "Samurai\n",
      "Software_testing\n",
      "States_of_Germany\n",
      "Glass\n",
      "Planck_constant\n",
      "Renewable_energy_commercialization\n",
      "Palermo\n",
      "Green\n",
      "Zinc\n",
      "Neoclassical_architecture\n",
      "Serbo-Croatian\n",
      "CBC_Television\n",
      "Appalachian_Mountains\n",
      "IBM\n",
      "Energy\n",
      "East_Prussia\n",
      "Ottoman_Empire\n",
      "Philosophy_of_space_and_time\n",
      "Neolithic\n",
      "Friedrich_Hayek\n",
      "Diarrhea\n",
      "Madrasa\n",
      "Miami\n",
      "Philadelphia\n",
      "John_Kerry\n",
      "Rajasthan\n",
      "Guam\n",
      "Empiricism\n",
      "Idealism\n",
      "Czech_language\n",
      "Education\n",
      "Tennessee\n",
      "Post-punk\n",
      "Canadian_football\n",
      "Seven_Years%27_War\n",
      "Richard_Feynman\n",
      "Muammar_Gaddafi\n",
      "Cyprus\n",
      "Steven_Spielberg\n",
      "Elevator\n",
      "Neptune\n",
      "Railway_electrification_system\n",
      "Spanish_language_in_the_United_States\n",
      "Charleston,_South_Carolina\n",
      "The_Blitz\n",
      "Endangered_Species_Act\n",
      "Vacuum\n",
      "Han_dynasty\n",
      "Quran\n",
      "Geography_of_the_United_States\n",
      "Compact_disc\n",
      "Transistor\n",
      "Modern_history\n",
      "51st_state\n",
      "Antenna_(radio)\n",
      "Flowering_plant\n",
      "Hyderabad\n",
      "Santa_Monica,_California\n",
      "Washington_University_in_St._Louis\n",
      "Central_Intelligence_Agency\n",
      "Pain\n",
      "Database\n",
      "Tucson,_Arizona\n",
      "Armenia\n",
      "Bacteria\n",
      "Printed_circuit_board\n",
      "Greeks\n",
      "Premier_League\n",
      "Roman_Republic\n",
      "Pacific_War\n",
      "San_Diego\n",
      "Muslim_world\n",
      "Iran\n",
      "British_Isles\n",
      "Association_football\n",
      "Georgian_architecture\n",
      "Liberia\n",
      "Alfred_North_Whitehead\n",
      "Antibiotics\n",
      "Windows_8\n",
      "Swaziland\n",
      "Translation\n",
      "Airport\n",
      "Kievan_Rus%27\n",
      "Super_Nintendo_Entertainment_System\n",
      "Sumer\n",
      "Tuvalu\n",
      "Immaculate_Conception\n",
      "Namibia\n",
      "Russian_language\n",
      "United_States_Air_Force\n",
      "Light-emitting_diode\n",
      "Great_power\n",
      "Bird\n",
      "Qing_dynasty\n",
      "Indigenous_peoples_of_the_Americas\n",
      "Red\n",
      "Egypt\n",
      "Mosaic\n",
      "University\n",
      "Religion_in_ancient_Rome\n",
      "YouTube\n",
      "Separation_of_church_and_state_in_the_United_States\n",
      "Protestantism\n",
      "Bras%C3%ADlia\n",
      "Economy_of_Greece\n",
      "Party_leaders_of_the_United_States_House_of_Representatives\n",
      "Armenians\n",
      "Jehovah%27s_Witnesses\n",
      "Dwight_D._Eisenhower\n",
      "The_Bronx\n",
      "Financial_crisis_of_2007%E2%80%9308\n",
      "Portugal\n",
      "Humanism\n",
      "Geological_history_of_Earth\n",
      "Police\n",
      "Genocide\n",
      "Saint_Barth%C3%A9lemy\n",
      "Tajikistan\n",
      "University_of_Notre_Dame\n",
      "Anthropology\n",
      "Montana\n",
      "Punjab,_Pakistan\n",
      "Richmond,_Virginia\n",
      "Infection\n",
      "Hunting\n",
      "Kathmandu\n",
      "Myocardial_infarction\n",
      "Matter\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "def read_corpus():\n",
    "    \"\"\"\n",
    "    读取给定的语料库,并把问题列表和答案列表分别写入到 qlist, alist 里面。 在此过程中,不用对字符换做任何的处理(这部分需要在 Part 2.3里处理)\n",
    "    qlist = [\"问题1\", “问题2”, “问题3” ....]\n",
    "    alist = [\"答案1\", \"答案2\", \"答案3\" ....]\n",
    "    务必要让每一个问题和答案对应起来(下标位置一致)\n",
    "    \"\"\"\n",
    "    # TODO 需要完成的代码部分 ...\n",
    "    \n",
    "    f_path = \"train-v2.0.json\"\n",
    "    with open(f_path,'r',encoding='utf-8') as f:\n",
    "        data = json.load(f)\n",
    "    data = data[\"data\"]\n",
    "    qlist = []\n",
    "    alist = []\n",
    "    for d in data:\n",
    "        print(d[\"title\"])\n",
    "        for x in d[\"paragraphs\"]:\n",
    "            for qa in x[\"qas\"]:\n",
    "                answer_key = \"answers\"\n",
    "                if qa[\"is_impossible\"]:\n",
    "                    answer_key = \"plausible_answers\"\n",
    "                qlist.append(qa[\"question\"])\n",
    "                alist.append(qa[answer_key][0][\"text\"])\n",
    "\n",
    "    \n",
    "    assert len(qlist) == len(alist)  # 确保长度一样\n",
    "    return qlist, alist\n",
    "\n",
    "qlist,alist = read_corpus()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1.2 理解数据(可视化分析/统计信息)\n",
    "对数据的理解是任何AI工作的第一步, 需要对数据有个比较直观的认识。在这里,简单地统计一下:\n",
    "\n",
    "- 在```qlist```出现的总单词个数\n",
    "- 按照词频画一个```histogram``` plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "76475\n"
     ]
    }
   ],
   "source": [
    "# TODO: 统计一下在qlist中总共出现了多少个单词? 总共出现了多少个不同的单词(unique word)?\n",
    "#       这里需要做简单的分词,对于英文我们根据空格来分词即可,其他过滤暂不考虑(只需分词)\n",
    "\n",
    "def statistics(l):\n",
    "    result = {}\n",
    "    for sentence in l:\n",
    "        words = sentence.split(\" \")\n",
    "        for word in words:\n",
    "            if word in result:\n",
    "                result[word] += 1\n",
    "            else:\n",
    "                result[word] = 1\n",
    "    sorted_items = sorted(result.items(),key=lambda x: -x[1])\n",
    "    return sorted_items\n",
    "\n",
    "sorted_items = statistics(qlist)\n",
    "word_total = len(sorted_items)\n",
    "print (word_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 35949), (2, 13166), (3, 5803), (4, 3747), (5, 2349)]\n",
      "[35949, 13166, 5803, 3747, 2349]\n",
      "88365\n",
      "88365\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEKCAYAAAARnO4WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEDFJREFUeJzt3X2MZXV9x/H3x91FHjSgZaQKjIsNwaJVgdFoUStYK4qi\nNtJiSmMb4zapjdKaqKip9g+bNPHZqnGLVHzCsCiK1IcuPhGTFmSByPJUrKDyYNEau2CJu8C3f9yz\ndFx3Zs7M3DOzc3/vV3Iz55x7zv19fzN3PnPmd8/93VQVkqTJ95DVLkCStDIMfElqhIEvSY0w8CWp\nEQa+JDXCwJekRhj4ktQIA1+SGmHgS1Ij1q92AbMdeuihtXHjxtUuQ5LWjG3btv20qqb67LtPBf7G\njRu58sorV7sMSVozkvyg774O6UhSIwx8SWqEgS9JjTDwJakRBr4kNWLQq3SS3ArcDdwP3FdVM0O2\nJ0ma20pclnlSVf10BdqRJM3DIR1JasTQgV/ApUm2Jdk0cFuSpHkMPaTzzKq6PcmjgK1Jbqyqy2bv\n0P0h2AQwPT09cDlr334btjy4vHPX6atYiaS1ZtAz/Kq6vft6F3AR8LS97LO5qmaqamZqqtd0EJKk\nJRgs8JMclOThu5eBPwC2D9WeJGl+Qw7pHAZclGR3O5+uqq8M2J4kaR6DBX5VfR948lCPL0laHC/L\nlKRGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJ\naoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RG\nGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDVi8MBPsi7J1UkuGbotSdLcVuIM/3XADSvQ\njiRpHoMGfpIjgFOBc4ZsR5K0sKHP8N8LvAF4YOB2JEkLWD/UAyd5EXBXVW1L8px59tsEbAKYnp4e\nqpyJtN+GLQ8u79x1+ipWMj6T2Ke9aaWf2rcMeYZ/InBakluBzwAnJ/nknjtV1eaqmqmqmampqQHL\nkaS2DRb4VXV2VR1RVRuBM4CvV9WZQ7UnSZqf1+FLUiMGG8Ofraq+CXxzJdqSJO2dZ/iS1AgDX5Ia\nYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREG\nviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBL\nUiMMfElqRK/AT/I7QxciSRpW3zP8DyW5IslfJjl40IokSYPoFfhV9SzgT4AjgW1JPp3keYNWJkka\nq95j+FV1M/BW4I3A7wHvT3Jjkj8cqjhJ0vj0HcN/UpL3ADcAJwMvrqrf7pbfM2B9kqQxWd9zvw8A\n5wBvrqp7d2+sqjuSvHVvByTZH7gMeGjXzoVV9bZl1itJWqK+gX8qcG9V3Q+Q5CHA/lX1v1X1iTmO\n+SVwclXdk2QD8O0kX66qf19+2ZKkxeo7hn8pcMCs9QO7bXOqkXu61Q3drRZdoSRpLPoG/v6zwptu\n+cCFDkqyLsk1wF3A1qq6fGllSpKWq++Qzi+SHF9VVwEkOQG4d4Fj6IaAnpLkEOCiJE+squ2z90my\nCdgEMD09vajiJ8F+G7Y8uLxz1+kL7iNJS9U38M8CtiS5Awjwm8Af922kqn6e5BvAKcD2Pe7bDGwG\nmJmZcchHkgbSK/Cr6jtJHg8c0226qap2zXdMkilgVxf2BwDPA/5hWdVKkpas7xk+wFOBjd0xxyeh\nqj4+z/6PBs5Lso7RawUXVNUlS65UkrQsvQI/ySeA3wKuAe7vNhcwZ+BX1XeB45ZboCRpPPqe4c8A\nx1aVY+yStEb1vSxzO6MXaiVJa1TfM/xDgeuTXMHoHbQAVNVpg1QlSRq7voH/9iGLkCQNr+9lmd9K\n8ljg6Kq6NMmBwLphS5MkjVPf6ZFfDVwIfKTbdDjw+aGKkiSNX98XbV8DnAjsgAc/DOVRQxUlSRq/\nvoH/y6rauXslyXqc+VKS1pS+gf+tJG8GDug+y3YL8MXhypIkjVvfwH8T8BPgWuAvgC8x+nxbSdIa\n0fcqnQeAf+pukqQ1qO9cOrewlzH7qnrc2CuSJA1iMXPp7LY/cDrwyPGXI0kaSq8x/Kr671m326vq\nvYw+2FyStEb0HdI5ftbqQxid8S9mLn1J0irrG9rvmrV8H3Ar8Edjr0aSNJi+V+mcNHQhkqRh9R3S\n+Zv57q+qd4+nHEnSUBZzlc5TgYu79RcDVwA3D1GUJGn8+gb+EcDxVXU3QJK3A/9SVWcOVZgkabz6\nTq1wGLBz1vrObpskaY3oe4b/ceCKJBd16y8FzhumJEnSEPpepfOOJF8GntVt+vOqunq4siRJ49Z3\nSAfgQGBHVb0PuC3JUQPVJEkaQN+POHwb8Ebg7G7TBuCTQxUlSRq/vmf4LwNOA34BUFV3AA8fqihJ\n0vj1DfydVVV0UyQnOWi4kiRJQ+gb+Bck+QhwSJJXA5fih6FI0prS9yqdd3afZbsDOAb426raOmhl\nkqSxWjDwk6wDLu0mUDPkJWmNWnBIp6ruBx5IcvAK1CNJGkjfd9reA1ybZCvdlToAVfXaQaqSJI1d\n38D/XHfrLcmRjKZkOIzR1T2buzdtSZJWwbyBn2S6qn5YVUuZN+c+4PVVdVWShwPbkmytquuXVKkk\naVkWGsP//O6FJJ9dzANX1Z1VdVW3fDdwA3D4oiuUJI3FQoGfWcuPW2ojSTYCxwGXL/UxJEnLs9AY\nfs2x3FuShwGfBc6qqh17uX8TsAlgenp6KU2sCftt2LKofXbuOn3Jj9/32LlqWmzbi62pz/dC+6bl\nPEe1+hY6w39ykh1J7gae1C3vSHJ3kl8L7z0l2cAo7D9VVXt90beqNlfVTFXNTE1NLb4HkqRe5j3D\nr6p1S33gJAE+Ctzgh5xL0upbzHz4i3Ui8KfAyUmu6W4vHLA9SdI8+l6Hv2hV9W1+9UVfSdIqGvIM\nX5K0DzHwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+JDXCwJekRhj4\nktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph4EtSIwx8SWqEgS9J\njTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0YLPCTnJvkriTbh2pDktTfkGf4\nHwNOGfDxJUmLMFjgV9VlwM+GenxJ0uKsX+0CkmwCNgFMT0+vcjWw34YtDy7v3HX6ih8/hNk1LefY\n2f3ps325+jzWXN/jxdbX92e12O/HXMcux1DP0ZV+7vb5WSz2+zquuvesbbHP/X3ld39Pq/6ibVVt\nrqqZqpqZmppa7XIkaWKteuBLklaGgS9JjRjysszzgX8DjklyW5JXDdWWJGlhg71oW1WvGOqxJUmL\n55COJDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhph\n4EtSIwx8SWqEgS9JjTDwJakRBr4kNcLAl6RGGPiS1AgDX5IaYeBLUiMMfElqhIEvSY0w8CWpEQa+\nJDXCwJekRhj4ktQIA1+SGmHgS1IjDHxJaoSBL0mNMPAlqRGDBn6SU5LclOR7Sd40ZFuSpPkNFvhJ\n1gEfBF4AHAu8IsmxQ7UnSZrfkGf4TwO+V1Xfr6qdwGeAlwzYniRpHkMG/uHAj2at39ZtkyStgvWr\nXUCSTcCmbvWeJDct4vBDgZ+Ov6qRZPWOn+fYBfs8VN2L3b7YfRaw136Pq+2l1Lec70fP/fe5n/U4\n257ncQf9WS/FEL8Te1hOlj22745DBv7twJGz1o/otv2KqtoMbF5KA0murKqZpZW3NrXYZ2iz3y32\nGdrs90r1ecghne8ARyc5Ksl+wBnAxQO2J0max2Bn+FV1X5K/Ar4KrAPOrarrhmpPkjS/Qcfwq+pL\nwJcGbGJJQ0FrXIt9hjb73WKfoc1+r0ifU1Ur0Y4kaZU5tYIkNWJNBn4rUzYkOTLJN5Jcn+S6JK/r\ntj8yydYkN3dfH7HatY5bknVJrk5ySbfeQp8PSXJhkhuT3JDkGZPe7yR/3T23tyc5P8n+k9jnJOcm\nuSvJ9lnb5uxnkrO7fLspyfPHVceaC/zGpmy4D3h9VR0LPB14TdfXNwFfq6qjga9165PmdcANs9Zb\n6PP7gK9U1eOBJzPq/8T2O8nhwGuBmap6IqOLO85gMvv8MeCUPbbttZ/d7/gZwBO6Yz7U5d6yrbnA\np6EpG6rqzqq6qlu+m1EAHM6ov+d1u50HvHR1KhxGkiOAU4FzZm2e9D4fDDwb+ChAVe2sqp8z4f1m\ndOHIAUnWAwcCdzCBfa6qy4Cf7bF5rn6+BPhMVf2yqm4Bvsco95ZtLQZ+k1M2JNkIHAdcDhxWVXd2\nd/0YOGyVyhrKe4E3AA/M2jbpfT4K+Anwz91Q1jlJDmKC+11VtwPvBH4I3An8T1X9KxPc5z3M1c/B\nMm4tBn5zkjwM+CxwVlXtmH1fjS6zmphLrZK8CLirqrbNtc+k9bmzHjge+HBVHQf8gj2GMiat392Y\n9UsY/bF7DHBQkjNn7zNpfZ7LSvVzLQZ+rykbJkWSDYzC/lNV9blu838leXR3/6OBu1arvgGcCJyW\n5FZGw3UnJ/kkk91nGJ3F3VZVl3frFzL6AzDJ/f594Jaq+klV7QI+B/wuk93n2ebq52AZtxYDv5kp\nG5KE0ZjuDVX17ll3XQy8slt+JfCFla5tKFV1dlUdUVUbGf1sv15VZzLBfQaoqh8DP0pyTLfpucD1\nTHa/fwg8PcmB3XP9uYxep5rkPs82Vz8vBs5I8tAkRwFHA1eMpcWqWnM34IXAfwD/CbxltesZsJ/P\nZPRv3neBa7rbC4HfYPSq/s3ApcAjV7vWgfr/HOCSbnni+ww8Bbiy+3l/HnjEpPcb+DvgRmA78Ang\noZPYZ+B8Rq9T7GL039yr5usn8JYu324CXjCuOnynrSQ1Yi0O6UiSlsDAl6RGGPiS1AgDX5IaYeBL\nUiMMfDWlm330+XtsOyvJh+c55p7hK5OGZ+CrNeczekPXbGd026WJZuCrNRcCp3bv0t49Kd1jgKuT\nfC3JVUmuTfJrM7Amec7u+fm79X9M8mfd8glJvpVkW5Kv7n7LvLQvMfDVlKr6GaO3qb+g23QGcAFw\nL/CyqjoeOAl4V/d2/wV18x19AHh5VZ0AnAu8Y9y1S8s16IeYS/uo3cM6X+i+vgoI8PdJns1oWubD\nGU1X++Mej3cM8ERga/c3Yh2jt9FL+xQDXy36AvCeJMcDB1bVtm5oZgo4oap2dbN17r/Hcffxq/8V\n774/wHVV9Yxhy5aWxyEdNaeq7gG+wWjoZfeLtQczmod/V5KTgMfu5dAfAMd2sxgewmh2RxhNcDWV\n5BkwGuJJ8oRBOyEtgWf4atX5wEX8/xU7nwK+mORaRjNW3rjnAVX1oyQXMJrZ8Rbg6m77ziQvB97f\nfVThekaf2nXd4L2QFsHZMiWpEQ7pSFIjDHxJaoSBL0mNMPAlqREGviQ1wsCXpEYY+JLUCANfkhrx\nfy/bKevU4s4kAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0xaece9e8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# TODO: 统计一下qlist中出现1次,2次,3次... 出现的单词个数, 然后画一个plot. 这里的x轴是单词出现的次数(1,2,3,..), y轴是单词个数。\n",
    "#       从左到右分别是 出现1次的单词数,出现2次的单词数,出现3次的单词数... \n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "max_freq = sorted_items[0][1]\n",
    "x = { i+1:0 for i in range(max_freq)}\n",
    "for items in sorted_items:\n",
    "    x[items[1]] += 1\n",
    "\n",
    "print(list(x.items())[:5])\n",
    "print(list(x.values())[:5])\n",
    "print(len(x))\n",
    "print(len(list(x.values())))\n",
    "limited = 100\n",
    "plt.hist(x=np.array(list(x.values())[:limited]), bins=[i+1 for i in range(limited)], color='#0504aa')\n",
    "plt.xlabel('Value')\n",
    "plt.ylabel('Frequency')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO: 从上面的图中能观察到什么样的现象? 这样的一个图的形状跟一个非常著名的函数形状很类似,能所出此定理吗? \n",
    "#       hint: [XXX]'s law\n",
    "# \n",
    "# "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### 1.3 文本预处理\n",
    "此部分需要做文本方面的处理。 以下是可以用到的一些方法:\n",
    "\n",
    "- 1. 停用词过滤 (去网上搜一下 \"english stop words list\",会出现很多包含停用词库的网页,或者直接使用NLTK自带的)   \n",
    "- 2. 转换成lower_case: 这是一个基本的操作   \n",
    "- 3. 去掉一些无用的符号: 比如连续的感叹号!!!, 或者一些奇怪的单词。\n",
    "- 4. 去掉出现频率很低的词:比如出现次数少于10,20.... (想一下如何选择阈值)\n",
    "- 5. 对于数字的处理: 分词完只有有些单词可能就是数字比如44,415,把所有这些数字都看成是一个单词,这个新的单词我们可以定义为 \"#number\"\n",
    "- 6. lemmazation: 在这里不要使用stemming, 因为stemming的结果有可能不是valid word。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['when did beyonce start becoming popular', 'what areas did beyonce compete in when she was growing up', 'when did beyonce leave destinys child and become a solo singer', 'in what city and state did beyonce  grow up ', 'in which decade did beyonce become famous']\n"
     ]
    }
   ],
   "source": [
    "# TODO: 需要做文本方面的处理。 从上述几个常用的方法中选择合适的方法给qlist做预处理(不一定要按照上面的顺序,不一定要全部使用)\n",
    "\n",
    "import re\n",
    "def clean(sentence):\n",
    "    result = []\n",
    "    for word in sentence.split(\" \"):\n",
    "        # 如果是停用词 则 continue\n",
    "\n",
    "        #  去标点\n",
    "        if re.match(r\"^[\\!@#$%^&*\\(\\)\\\\\\{\\}\\[\\]:\\\"\\';\\<\\>\\,\\./\\-\\+=_【】;‘:“《》,。?、\\?]+$\",word):\n",
    "            continue\n",
    "        word = re.sub(r\"[\\!@#$%^&*\\(\\)\\\\\\{\\}\\[\\]:\\\"\\';\\<\\>\\,\\./\\-\\+=_【】;‘:“《》,。?、\\?]+\",\"\",word)\n",
    "\n",
    "        # 小写\n",
    "        word = word.lower()\n",
    "        \n",
    "        # 去低频词\n",
    "\n",
    "        # 去数字\n",
    "        word = re.sub(r\"[0-9]+\",\"#number\",word)\n",
    "\n",
    "        # 词性还原 lemmazation\n",
    "\n",
    "        result.append(word)\n",
    "    return \" \".join(result)\n",
    "for i in range(len(qlist)):\n",
    "    q = qlist[i]\n",
    "    qlist[i] = clean(q)\n",
    "    \n",
    "print(qlist[:5])\n",
    "qlist =  qlist   # 更新后的问题列表"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 第二部分: 文本的表示\n",
    "当我们做完必要的文本处理之后就需要想办法表示文本了,这里有几种方式\n",
    "\n",
    "- 1. 使用```tf-idf vector```\n",
    "- 2. 使用embedding技术如```word2vec```, ```bert embedding```等\n",
    "\n",
    "下面我们分别提取这三个特征来做对比。 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.1 使用tf-idf表示向量\n",
    "把```qlist```中的每一个问题的字符串转换成```tf-idf```向量, 转换之后的结果存储在```X```矩阵里。 ``X``的大小是: ``N* D``的矩阵。 这里``N``是问题的个数(样本个数),\n",
    "``D``是词典库的大小"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 0. 0. 0.]]\n",
      "[0. 0. 0. ... 0. 0. 0.]\n",
      "2.220026903744897\n"
     ]
    }
   ],
   "source": [
    "# TODO \n",
    "from sklearn import feature_extraction\n",
    "from sklearn.feature_extraction.text import CountVectorizer \n",
    "from sklearn.feature_extraction.text import TfidfTransformer \n",
    "\n",
    "vectorizer =  CountVectorizer()    # 定义一个tf-idf的vectorizer\n",
    "transformer = TfidfTransformer()\n",
    "\n",
    "X = vectorizer.fit_transform(qlist[:1000])\n",
    "tfidf = transformer.fit_transform(X)\n",
    "X_tfidf = tfidf.toarray()  # 结果存放在X矩阵里\n",
    "\n",
    "print(X_tfidf[:5])\n",
    "print(X_tfidf[0])\n",
    "print(sum(X_tfidf[0]))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2 使用wordvec + average pooling\n",
    "词向量方面需要下载: https://nlp.stanford.edu/projects/glove/ (请下载``glove.6B.zip``),并使用``d=200``的词向量(200维)。国外网址如果很慢,可以在百度上搜索国内服务器上的。 每个词向量获取完之后,即可以得到一个句子的向量。 我们通过``average pooling``来实现句子的向量。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO 基于Glove向量获取句子向量\n",
    "def readWord2Vec(fd):\n",
    "    word2id = dict()\n",
    "    id2word = dict()\n",
    "    word2vec = []\n",
    "    for line in fd.readlines():\n",
    "        l = line.strip()\n",
    "        ele = l.split()\n",
    "        if len(ele) != 201:\n",
    "            continue\n",
    "            print(line)\n",
    "        word = ele[0]\n",
    "        if word not in word2id:\n",
    "            word2id[word] = len(word2id)\n",
    "            id2word[len(id2word)] = word\n",
    "            word2vec.append(ele[1:])\n",
    "    word2vec = np.array(word2vec,dtype='float32')\n",
    "    return word2id,id2word,word2vec\n",
    "\n",
    "def getQueryW2V(query):\n",
    "    words = query.split(\" \")\n",
    "    sentenceVec = np.zeros(200).reshape((1,200))\n",
    "    count = 0\n",
    "    for word in words:\n",
    "        count += 1\n",
    "        if word not in word2id:\n",
    "            continue\n",
    "        sentenceVec += word2vec[word2id[word]].reshape((1,200))\n",
    "    if count != 0:\n",
    "        sentenceVec = sentenceVec/count\n",
    "    return sentenceVec\n",
    "\n",
    "glovefile = open(\"glove.6B.200d.txt\",\"r\",encoding=\"utf-8\") \n",
    "word2id,id2word,word2vec = readWord2Vec(glovefile)\n",
    "print(list(word2id.items())[:5])\n",
    "sentenceVecs = []\n",
    "for q in qlist[:1000]:\n",
    "    sentenceVec = getQueryW2V(q)\n",
    "    sentenceVecs.append(sentenceVec)\n",
    "print(np.sum(sentenceVecs[:5]))\n",
    "emb  =  word2vec # 这是 D*H的矩阵,这里的D是词典库的大小, H是词向量的大小。 这里面我们给定的每个单词的词向量,\n",
    "        # 这需要从文本中读取\n",
    "    \n",
    "X_w2v =   sentenceVecs # 初始化完emb之后就可以对每一个句子来构建句子向量了,这个过程使用average pooling来实现\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.3 使用BERT + average pooling\n",
    "最近流行的BERT也可以用来学出上下文相关的词向量(contex-aware embedding), 在很多问题上得到了比较好的结果。在这里,我们不做任何的训练,而是直接使用已经训练好的BERT embedding。 具体如何训练BERT将在之后章节里体会到。 为了获取BERT-embedding,可以直接下载已经训练好的模型从而获得每一个单词的向量。可以从这里获取: https://github.com/imgarylai/bert-embedding , 请使用```bert_12_768_12```\t当然,你也可以从其他source获取也没问题,只要是合理的词向量。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO 基于BERT的句子向量计算\n",
    "from bert_embedding import BertEmbedding\n",
    "bert_embedding = BertEmbedding(model='bert_12_768_12')\n",
    "result_bert = bert_embedding(qlist[:1000])\n",
    "X_bert = [] # 每一个句子的向量结果存放在X_bert矩阵里。行数为句子的总个数,列数为一个句子embedding大小。 \n",
    "for ele in result_bert:\n",
    "    sentence_embedding = ele[1]\n",
    "    X_bert.append(sentence_embedding)\n",
    "print(X_bert[:5])   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### 第三部分: 相似度匹配以及搜索\n",
    "在这部分里,我们需要把用户每一个输入跟知识库里的每一个问题做一个相似度计算,从而得出最相似的问题。但对于这个问题,时间复杂度其实很高,所以我们需要结合倒排表来获取相似度最高的问题,从而获得答案。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.1 tf-idf + 余弦相似度\n",
    "我们可以直接基于计算出来的``tf-idf``向量,计算用户最新问题与库中存储的问题之间的相似度,从而选择相似度最高的问题的答案。这个方法的复杂度为``O(N)``, ``N``是库中问题的个数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tfidf_vocab = vectorizer.vocabulary_\n",
    "idf = transformer.idf_\n",
    "def get_top_results_tfidf_noindex(query):\n",
    "# def get_top_results_tfidf_noindex(query,X,y):\n",
    "    # TODO 需要编写\n",
    "    \"\"\"\n",
    "    给定用户输入的问题 query, 返回最有可能的TOP 5问题。这里面需要做到以下几点:\n",
    "    1. 对于用户的输入 query 首先做一系列的预处理(上面提到的方法),然后再转换成tf-idf向量(利用上面的vectorizer)\n",
    "    2. 计算跟每个库里的问题之间的相似度\n",
    "    3. 找出相似度最高的top5问题的答案\n",
    "    \"\"\"\n",
    "    q_emb = getTFIDF(query,tfidf_vocab,idf)\n",
    "    all_cosine = {}\n",
    "    queue = Q.PriorityQueue()\n",
    "    for i in range(len(X_tfidf)):\n",
    "        x = X_tfidf[i]\n",
    "        cos = cosine(q_emb,x)\n",
    "        if cos not in all_cosine:\n",
    "            all_cosine[cos] = []\n",
    "            queue.put(cos)\n",
    "        all_cosine[cos].append(i)\n",
    "\n",
    "    top_idxs = []  # top_idxs存放相似度最高的(存在qlist里的)问题的下标 \n",
    "                   # hint: 请使用 priority queue来找出top results. 思考为什么可以这么做? \n",
    "    top = 5\n",
    "    while len(top_idxs) < top and not queue.empty():\n",
    "        cos = queue.get()\n",
    "        top_idxs.extend(all_cosine[cos])\n",
    "    result = []\n",
    "    for idx in top_idxs:\n",
    "        result.append(alist[idx])\n",
    "    return result  # 返回相似度最高的问题对应的答案,作为TOP5答案    \n",
    "\n",
    "def getTFIDF(query,vocab,idf):\n",
    "    result = [ 0. for _ in range(len(vocab)) ]\n",
    "    for word in query.split(\" \"):\n",
    "        if word not in vocab:\n",
    "            continue\n",
    "        result[vocab[word]] += 1\n",
    "    result = np.array(result)\n",
    "    result = result * idf\n",
    "    denominator = np.sum(result**2)**0.5\n",
    "    return np.array(result/denominator,dtype='float32')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO: 编写几个测试用例,并输出结果\n",
    "print (get_top_results_tfidf_noindex(\"\"))\n",
    "print (get_top_results_tfidf_noindex(\"\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你会发现上述的程序很慢,没错! 是因为循环了所有库里的问题。为了优化这个过程,我们需要使用一种数据结构叫做```倒排表```。 使用倒排表我们可以把单词和出现这个单词的文档做关键。 之后假如要搜索包含某一个单词的文档,即可以非常快速的找出这些文档。 在这个QA系统上,我们首先使用倒排表来快速查找包含至少一个单词的文档,然后再进行余弦相似度的计算,即可以大大减少```时间复杂度```。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2 倒排表的创建\n",
    "倒排表的创建其实很简单,最简单的方法就是循环所有的单词一遍,然后记录每一个单词所出现的文档,然后把这些文档的ID保存成list即可。我们可以定义一个类似于```hash_map```, 比如 ``inverted_index = {}``, 然后存放包含每一个关键词的文档出现在了什么位置,也就是,通过关键词的搜索首先来判断包含这些关键词的文档(比如出现至少一个),然后对于candidates问题做相似度比较。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO 请创建倒排表\n",
    "def getInversedIndexedTable(l):\n",
    "    result = {}\n",
    "    for i in range(len(l)):\n",
    "        sentence = l[i]\n",
    "        for word in sentence.split(\" \"):\n",
    "            if word not in result:\n",
    "                result[word] = set()\n",
    "            result[word].add(i)\n",
    "    return result\n",
    "inverted_idx = getInversedIndexedTable(qlist)  # 定一个一个简单的倒排表,是一个map结构。 循环所有qlist一遍就可以"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.3 语义相似度\n",
    "这里有一个问题还需要解决,就是语义的相似度。可以这么理解: 两个单词比如car, auto这两个单词长得不一样,但从语义上还是类似的。如果只是使用倒排表我们不能考虑到这些单词之间的相似度,这就导致如果我们搜索句子里包含了``car``, 则我们没法获取到包含auto的所有的文档。所以我们希望把这些信息也存下来。那这个问题如何解决呢? 其实也不难,可以提前构建好相似度的关系,比如对于``car``这个单词,一开始就找好跟它意思上比较类似的单词比如top 10,这些都标记为``related words``。所以最后我们就可以创建一个保存``related words``的一个``map``. 比如调用``related_words['car']``即可以调取出跟``car``意思上相近的TOP 10的单词。 \n",
    "\n",
    "那这个``related_words``又如何构建呢? 在这里我们仍然使用``Glove``向量,然后计算一下俩俩的相似度(余弦相似度)。之后对于每一个词,存储跟它最相近的top 10单词,最终结果保存在``related_words``里面。 这个计算需要发生在离线,因为计算量很大,复杂度为``O(V*V)``, V是单词的总数。 \n",
    "\n",
    "这个计算过程的代码请放在``related.py``的文件里,然后结果保存在``related_words.txt``里。 我们在使用的时候直接从文件里读取就可以了,不用再重复计算。所以在此notebook里我们就直接读取已经计算好的结果。 作业提交时需要提交``related.py``和``related_words.txt``文件,这样在使用的时候就不再需要做这方面的计算了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO 读取语义相关的单词\n",
    "import json\n",
    "def get_related_words(file):\n",
    "    with open(file,\"r\",encoding=\"utf-8\") as f:\n",
    "        related_words = json.load(f)\n",
    "    return related_words\n",
    "\n",
    "related_words = get_related_words('related_words.txt') # 直接放在文件夹的根目录下,不要修改此路径。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4 利用倒排表搜索\n",
    "在这里,我们使用倒排表先获得一批候选问题,然后再通过余弦相似度做精准匹配,这样一来可以节省大量的时间。搜索过程分成两步:\n",
    "\n",
    "- 使用倒排表把候选问题全部提取出来。首先,对输入的新问题做分词等必要的预处理工作,然后对于句子里的每一个单词,从``related_words``里提取出跟它意思相近的top 10单词, 然后根据这些top词从倒排表里提取相关的文档,把所有的文档返回。 这部分可以放在下面的函数当中,也可以放在外部。\n",
    "- 然后针对于这些文档做余弦相似度的计算,最后排序并选出最好的答案。\n",
    "\n",
    "可以适当定义自定义函数,使得减少重复性代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_top_results_tfidf(query):\n",
    "    \"\"\"\n",
    "    给定用户输入的问题 query, 返回最有可能的TOP 5问题。这里面需要做到以下几点:\n",
    "    1. 利用倒排表来筛选 candidate (需要使用related_words). \n",
    "    2. 对于候选文档,计算跟输入问题之间的相似度\n",
    "    3. 找出相似度最高的top5问题的答案\n",
    "    \"\"\"\n",
    "    candidates = set()\n",
    "    for word in query.split(\" \"):\n",
    "        for idx in inverted_idx[word]:\n",
    "            candidates.add(idx)\n",
    "        for related_word in related_words[word]:\n",
    "            for idx in inverted_idx[related_word]:\n",
    "                candidates.add(idx)\n",
    "    \n",
    "    q_emb = getTFIDF(query,tfidf_vocab,idf) # 获取query tfidf embedding\n",
    "    all_cosine = {}\n",
    "    queue = Q.PriorityQueue()\n",
    "    for i in candidates:\n",
    "        x = X_tfidf[i] # 获取candidate embedding\n",
    "        cos = cosine(q_emb,x)\n",
    "        if cos not in all_cosine:\n",
    "            all_cosine[cos] = []\n",
    "            queue.put(cos)\n",
    "        all_cosine[cos].append(i)\n",
    "    top_idxs = []  # top_idxs存放相似度最高的(存在qlist里的)问题的下表 \n",
    "                   # hint: 利用priority queue来找出top results. 思考为什么可以这么做? \n",
    "    \n",
    "    top = 5\n",
    "    while len(top_idxs) < top and not queue.empty():\n",
    "        cos = queue.get()\n",
    "        top_idxs.extend(all_cosine[cos])\n",
    "    result = []\n",
    "    for idx in top_idxs:\n",
    "        result.append(alist[idx])\n",
    " \n",
    "    return result  # 返回相似度最高的问题对应的答案,作为TOP5答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_top_results_w2v(query):\n",
    "    \"\"\"\n",
    "    给定用户输入的问题 query, 返回最有可能的TOP 5问题。这里面需要做到以下几点:\n",
    "    1. 利用倒排表来筛选 candidate (需要使用related_words). \n",
    "    2. 对于候选文档,计算跟输入问题之间的相似度\n",
    "    3. 找出相似度最高的top5问题的答案\n",
    "    \"\"\"\n",
    "    candidates = set()\n",
    "    for word in query.split(\" \"):\n",
    "        for idx in inverted_idx[word]:\n",
    "            candidates.add(idx)\n",
    "        for related_word in related_words[word]:\n",
    "            for idx in inverted_idx[related_word]:\n",
    "                candidates.add(idx)\n",
    "    \n",
    "    q_emb = getQueryW2V(query) # 获取query w2v embedding\n",
    "    all_cosine = {}\n",
    "    queue = Q.PriorityQueue()\n",
    "    for i in candidates:\n",
    "        x = X_w2v[i] # 获取candidate embedding\n",
    "        cos = cosine(q_emb,x)\n",
    "        if cos not in all_cosine:\n",
    "            all_cosine[cos] = []\n",
    "            queue.put(cos)\n",
    "        all_cosine[cos].append(i)\n",
    "    top_idxs = []  # top_idxs存放相似度最高的(存在qlist里的)问题的下表 \n",
    "                   # hint: 利用priority queue来找出top results. 思考为什么可以这么做? \n",
    "    \n",
    "    top = 5\n",
    "    while len(top_idxs) < top and not queue.empty():\n",
    "        cos = queue.get()\n",
    "        top_idxs.extend(all_cosine[cos])\n",
    "    result = []\n",
    "    for idx in top_idxs:\n",
    "        result.append(alist[idx])\n",
    " \n",
    "    return result  # 返回相似度最高的问题对应的答案,作为TOP5答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_top_results_bert(query):\n",
    "    \"\"\"\n",
    "    给定用户输入的问题 query, 返回最有可能的TOP 5问题。这里面需要做到以下几点:\n",
    "    1. 利用倒排表来筛选 candidate (需要使用related_words). \n",
    "    2. 对于候选文档,计算跟输入问题之间的相似度\n",
    "    3. 找出相似度最高的top5问题的答案\n",
    "    \"\"\"\n",
    "    candidates = set()\n",
    "    for word in query.split(\" \"):\n",
    "        for idx in inverted_idx[word]:\n",
    "            candidates.add(idx)\n",
    "        for related_word in related_words[word]:\n",
    "            for idx in inverted_idx[related_word]:\n",
    "                candidates.add(idx)\n",
    "    \n",
    "    q_emb = bert_embedding([query])[0] # 获取bert embedding\n",
    "    all_cosine = {}\n",
    "    queue = Q.PriorityQueue()\n",
    "    for i in candidates:\n",
    "        x = X_bert[i] # 获取candidate embedding\n",
    "        cos = cosine(q_emb,x)\n",
    "        if cos not in all_cosine:\n",
    "            all_cosine[cos] = []\n",
    "            queue.put(cos)\n",
    "        all_cosine[cos].append(i)\n",
    "    top_idxs = []  # top_idxs存放相似度最高的(存在qlist里的)问题的下表 \n",
    "                   # hint: 利用priority queue来找出top results. 思考为什么可以这么做? \n",
    "    \n",
    "    top = 5\n",
    "    while len(top_idxs) < top and not queue.empty():\n",
    "        cos = queue.get()\n",
    "        top_idxs.extend(all_cosine[cos])\n",
    "    result = []\n",
    "    for idx in top_idxs:\n",
    "        result.append(alist[idx])\n",
    " \n",
    "    return result # 返回相似度最高的问题对应的答案,作为TOP5答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO: 编写几个测试用例,并输出结果\n",
    "\n",
    "test_query1 = \"\"\n",
    "test_query2 = \"\"\n",
    "\n",
    "print (get_top_results_tfidf(test_query1))\n",
    "print (get_top_results_w2v(test_query1))\n",
    "print (get_top_results_bert(test_query1))\n",
    "\n",
    "print (get_top_results_tfidf(test_query2))\n",
    "print (get_top_results_w2v(test_query2))\n",
    "print (get_top_results_bert(test_query2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. 拼写纠错\n",
    "其实用户在输入问题的时候,不能期待他一定会输入正确,有可能输入的单词的拼写错误的。这个时候我们需要后台及时捕获拼写错误,并进行纠正,然后再通过修正之后的结果再跟库里的问题做匹配。这里我们需要实现一个简单的拼写纠错的代码,然后自动去修复错误的单词。\n",
    "\n",
    "这里使用的拼写纠错方法是课程里讲过的方法,就是使用noisy channel model。 我们回想一下它的表示:\n",
    "\n",
    "$c^* = \\text{argmax}_{c\\in candidates} ~~p(c|s) = \\text{argmax}_{c\\in candidates} ~~p(s|c)p(c)$\n",
    "\n",
    "这里的```candidates```指的是针对于错误的单词的候选集,这部分我们可以假定是通过edit_distance来获取的(比如生成跟当前的词距离为1/2的所有的valid 单词。 valid单词可以定义为存在词典里的单词。 ```c```代表的是正确的单词, ```s```代表的是用户错误拼写的单词。 所以我们的目的是要寻找出在``candidates``里让上述概率最大的正确写法``c``。 \n",
    "\n",
    "$p(s|c)$,这个概率我们可以通过历史数据来获得,也就是对于一个正确的单词$c$, 有百分之多少人把它写成了错误的形式1,形式2...  这部分的数据可以从``spell_errors.txt``里面找得到。但在这个文件里,我们并没有标记这个概率,所以可以使用uniform probability来表示。这个也叫做channel probability。\n",
    "\n",
    "$p(c)$,这一项代表的是语言模型,也就是假如我们把错误的$s$,改造成了$c$, 把它加入到当前的语句之后有多通顺?在本次项目里我们使用bigram来评估这个概率。 举个例子: 假如有两个候选 $c_1, c_2$, 然后我们希望分别计算出这个语言模型的概率。 由于我们使用的是``bigram``, 我们需要计算出两个概率,分别是当前词前面和后面词的``bigram``概率。 用一个例子来表示:\n",
    "\n",
    "给定: ``We are go to school tomorrow``, 对于这句话我们希望把中间的``go``替换成正确的形式,假如候选集里有个,分别是``going``, ``went``, 这时候我们分别对这俩计算如下的概率:\n",
    "$p(going|are)p(to|going)$和 $p(went|are)p(to|went)$, 然后把这个概率当做是$p(c)$的概率。 然后再跟``channel probability``结合给出最终的概率大小。\n",
    "\n",
    "那这里的$p(are|going)$这些bigram概率又如何计算呢?答案是训练一个语言模型! 但训练一个语言模型需要一些文本数据,这个数据怎么找? 在这次项目作业里我们会用到``nltk``自带的``reuters``的文本类数据来训练一个语言模型。当然,如果你有资源你也可以尝试其他更大的数据。最终目的就是计算出``bigram``概率。 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.1 训练一个语言模型\n",
    "在这里,我们使用``nltk``自带的``reuters``数据来训练一个语言模型。 使用``add-one smoothing``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nltk.corpus import reuters\n",
    "\n",
    "\n",
    "# 循环所有的语料库并构建bigram probability. bigram[word1][word2]: 在word1出现的情况下下一个是word2的概率。 \n",
    "import re\n",
    "import numpy as np\n",
    "import json\n",
    "import os\n",
    "\n",
    "def clean1(sentence):\n",
    "    result = []\n",
    "    for word in sentence:\n",
    "        # 如果是停用词 则 continue\n",
    "\n",
    "        #  去标点\n",
    "        if re.match(r\"^[\\!@#$%^&*\\(\\)\\\\\\{\\}\\[\\]:\\\"\\';\\<\\>\\,\\./\\-\\+=_【】;‘:“《》,。?、\\?]+$\",word):\n",
    "            continue\n",
    "        word = re.sub(r\"[\\!@#$%^&*\\(\\)\\\\\\{\\}\\[\\]:\\\"\\';\\<\\>\\,\\./\\-\\+=_【】;‘:“《》,。?、\\?]+\",\"\",word)\n",
    "\n",
    "        # 小写\n",
    "        word = word.lower()\n",
    "\n",
    "        # 去低频词\n",
    "\n",
    "        # 去数字\n",
    "        word = re.sub(r\"[0-9]+\",\"#number\",word)\n",
    "\n",
    "        # 词性还原 lemmazation\n",
    "\n",
    "        result.append(word)\n",
    "    return result\n",
    "\n",
    "def create(corpus):\n",
    "    dual = {}\n",
    "    single = {}\n",
    "    single[\"<START>\"] = len(corpus)\n",
    "    single[\"<END>\"] = len(corpus) \n",
    "    for sen in corpus:\n",
    "        sen = clean(sen)\n",
    "        for i in range(len(sen)):\n",
    "            word = sen[i]\n",
    "            if word not in single:\n",
    "                single[word] = 0\n",
    "            single[word] += 1\n",
    "            if i == 0 :\n",
    "                key = \"{pre}##{cur}\".format(pre=\"<START>\",cur=word)\n",
    "            else:\n",
    "                key = \"{pre}##{cur}\".format(pre=sen[i-1],cur=word)\n",
    "            if key not in dual:\n",
    "                dual[key] = 0\n",
    "            dual[key] += 1\n",
    "        if len(sen) != 0:\n",
    "            key = \"{pre}##{cur}\".format(pre=sen[-1],cur=\"<END>\")\n",
    "            if key not in dual:\n",
    "                dual[key] = 0\n",
    "            dual[key] += 1\n",
    "    return single,dual\n",
    "\n",
    "def getProb1(query,single,dual):\n",
    "    if type(query) == type(str()):\n",
    "        query = clean(query.split(\" \"))\n",
    "    vocab_size = len(single)\n",
    "    result = 0\n",
    "    if len(query) != 0:\n",
    "        numerator = 0\n",
    "        if query[0] in single:\n",
    "            numerator += single[query[0]]\n",
    "        result += numerator/vocab_size\n",
    "    for i in range(1,len(query)):\n",
    "        word = query[i]\n",
    "        denominator = 0\n",
    "        if i == 0:\n",
    "            pre = \"<START>\"\n",
    "        else:\n",
    "            pre = query[i-1]\n",
    "        if pre in single:\n",
    "            denominator += single[pre]\n",
    "        key = \"{pre}##{cur}\".format(pre=pre,cur=word)\n",
    "        numerator = 0\n",
    "        if key in dual:\n",
    "            numerator += dual[key]\n",
    "        if denominator == 0 or numerator == 0:\n",
    "            prob = (numerator + 1.0) / (denominator + vocab_size)\n",
    "        else:\n",
    "            prob = numerator / denominator\n",
    "        result += np.log10(prob)\n",
    "    if result != 0:\n",
    "        result = result/len(query)\n",
    "        result = 10**result\n",
    "    return result\n",
    "\n",
    "single_path = \"single.json\"\n",
    "dual_path = \"dual.json\"\n",
    "if not os.path.exists(single_path) or not os.path.exists(dual_path):\n",
    "    # 读取语料库的数据\n",
    "    categories = reuters.categories()\n",
    "    corpus = reuters.sents(categories=categories)\n",
    "    single,dual = create(corpus)\n",
    "    with open(single_path,\"w\",encoding=\"utf-8\") as f:\n",
    "        json.dump(single,f,ensure_ascii=False)\n",
    "    with open(dual_path,\"w\",encoding=\"utf-8\") as f:\n",
    "        json.dump(dual,f,ensure_ascii=False)\n",
    "else:\n",
    "    with open(single_path,\"r\",encoding=\"utf-8\") as f:\n",
    "        single = json.load(f)\n",
    "    with open(dual_path,\"r\",encoding=\"utf-8\") as f:\n",
    "        dual = json.load(f)\n",
    "sens =  [\"They told Reuter correspondents in Asian capitals a U . S . Move against Japan might boost protectionist sentiment in the U . S . And lead to curbs on American imports of their products .\",\"x y z\", \"world true good\",\"car vechicle movies film music\",\n",
    "            \"i like movie\"]\n",
    "for sen in sens:\n",
    "    print(sen,getProb1(sen,single,dual))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.2 构建Channel Probs\n",
    "基于``spell_errors.txt``文件构建``channel probability``, 其中$channel[c][s]$表示正确的单词$c$被写错成$s$的概率。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO 构建channel probability  \n",
    "channel = {}\n",
    "\n",
    "with open(\"spell-errors.txt\",\"r\",encoding=\"utf-8\") as f:\n",
    "    for line in f.readlines():\n",
    "        eles = line.strip().split(\":\")\n",
    "        if len(eles) <= 1:\n",
    "            print(line)\n",
    "            continue\n",
    "        c = eles[0]\n",
    "        errors_dict = {}\n",
    "        errors = eles[1].split(\",\")\n",
    "        for s in errors:\n",
    "            s = s.strip()\n",
    "            errors_dict[s] = 1/len(errors)\n",
    "        channel[c] = errors_dict\n",
    "\n",
    "\n",
    "print(channel)   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.3 根据错别字生成所有候选集合\n",
    "给定一个错误的单词,首先生成跟这个单词距离为1或者2的所有的候选集合。 这部分的代码我们在课程上也讲过,可以参考一下。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    " # encoding: utf-8\n",
    "\n",
    "vocab = set()\n",
    "file_path = \"spell-errors.txt\"\n",
    "with open(file_path,\"r\",encoding=\"utf-8\") as f:\n",
    "    for line in f.readlines():\n",
    "        eles = line.strip().split(\":\")\n",
    "        if len(eles) <= 1:\n",
    "            continue\n",
    "        vocab.add(eles[0])\n",
    "\n",
    "def generate_candidates(word):\n",
    "    # 基于拼写错误的单词,生成跟它的编辑距离为1或者2的单词,并通过词典库的过滤。\n",
    "    # 只留写法上正确的单词。 \n",
    "    \n",
    "    candidates = {word}\n",
    "    for _ in [1,2]:\n",
    "        store = set()\n",
    "        for cand in candidates:\n",
    "            tmp = generate_candidates_with_one_ED(cand)\n",
    "            for t in tmp:\n",
    "                if t != word:\n",
    "                    store.add(t)\n",
    "        candidates = store\n",
    "    result = set()\n",
    "    for cand in candidates:\n",
    "        if cand in vocab:\n",
    "            result.add(cand)\n",
    "    return result\n",
    "        \n",
    "\n",
    "def generate_candidates_with_one_ED(word):\n",
    "    candidates = set()\n",
    "    # 增\n",
    "    alternative = \" \" + word + \" \"\n",
    "    for i in range(1,len(alternative)-1):\n",
    "        for j in range(ord(\"a\"),ord(\"a\") + 26):\n",
    "            c = chr(j)\n",
    "            candidate = alternative[:i] + c + alternative[i:]\n",
    "            candidate = candidate.strip()\n",
    "            candidates.add(candidate)\n",
    "    # 删\n",
    "    for i in range(1,len(alternative)-1):\n",
    "        candidate = alternative[:i] + alternative[i+1:]\n",
    "        candidate = candidate.strip()\n",
    "        candidates.add(candidate)\n",
    "    # 改\n",
    "    for i in range(1,len(alternative)-1):\n",
    "        for j in range(ord(\"a\"),ord(\"a\") + 26):\n",
    "            c = chr(j)\n",
    "            if c == alternative[i]:\n",
    "                continue\n",
    "            candidate = alternative[:i] + c + alternative[i+1:]\n",
    "            candidate = candidate.strip()\n",
    "            candidates.add(candidate)\n",
    "    return candidates\n",
    "\n",
    "words = [\"reserve\",\n",
    "         \"reverse\"]\n",
    "for word in words:\n",
    "    print(word,word in vocab)\n",
    "    cands = generate_candidates(word)\n",
    "    print(cands)   \n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.4 给定一个输入,如果有错误需要纠正\n",
    "\n",
    "给定一个输入``query``, 如果这里有些单词是拼错的,就需要把它纠正过来。这部分的实现可以简单一点: 对于``query``分词,然后把分词后的每一个单词在词库里面搜一下,假设搜不到的话可以认为是拼写错误的! 人如果拼写错误了再通过``channel``和``bigram``来计算最适合的候选。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def spell_corrector(line):\n",
    "    # 1. 首先做分词,然后把``line``表示成``tokens``\n",
    "    # 2. 循环每一token, 然后判断是否存在词库里。如果不存在就意味着是拼写错误的,需要修正。 \n",
    "    #    修正的过程就使用上述提到的``noisy channel model``, 然后从而找出最好的修正之后的结果。\n",
    "    tokens = [ token for token in line.split(\" \") if str(token).isalnum() ]\n",
    "    optimal = []\n",
    "    for i in range(len(tokens)):\n",
    "        token = tokens[i]\n",
    "        if token not in vocab:\n",
    "            cands = generate_candidates(token)\n",
    "            if len(cands) != 0:\n",
    "                info = {\"token\":token,\"idx\":i,\"correct\":\"\",\"prob\":0.0}\n",
    "                for cand in cands:\n",
    "                    # p(s|c)\n",
    "                    if token not in channel[cand]:\n",
    "                        psc = 1 / (len(vocab) + len(channel[cand]))\n",
    "                    else:\n",
    "                        psc = channel[cand][token]\n",
    "                    pc = getProb1(\" \".join(tokens[:i]+[cand]+tokens[i+1:]),single,dual)\n",
    "                    prob = 10**(np.log10(psc)+np.log(pc))\n",
    "                    if info[\"prob\"] < prob:\n",
    "                        info[\"prob\"] = prob\n",
    "                        info[\"correct\"] = cand\n",
    "                optimal.append(info)\n",
    "\n",
    "    for info in optimal:\n",
    "        tokens[info[\"idx\"]] = info[\"correct\"]\n",
    "    print(optimal)\n",
    "    newline = \" \".join(tokens)\n",
    "    return newline   # 修正之后的结果,假如用户输入没有问题,那这时候``newline = line``\n",
    "\n",
    "sens = [\"am rienind the book \",\"everything will become goad\"]\n",
    "for sen in sens :\n",
    "    newsen = spell_corrector(sen)\n",
    "    print(newsen)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.5 基于拼写纠错算法,实现用户输入自动矫正\n",
    "首先有了用户的输入``query``, 然后做必要的处理把句子转换成tokens的形状,然后对于每一个token比较是否是valid, 如果不是的话就进行下面的修正过程。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_query1 = \"\"  # 拼写错误的\n",
    "test_query2 = \"\"  # 拼写错误的\n",
    "\n",
    "test_query1 = spell_corector(test_query1)\n",
    "test_query2 = spell_corector(test_query2)\n",
    "\n",
    "print (get_top_results_tfidf(test_query1))\n",
    "print (get_top_results_w2v(test_query1))\n",
    "print (get_top_results_bert(test_query1))\n",
    "\n",
    "print (get_top_results_tfidf(test_query2))\n",
    "print (get_top_results_w2v(test_query2))\n",
    "print (get_top_results_bert(test_query2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 附录 \n",
    "在本次项目中我们实现了一个简易的问答系统。基于这个项目,我们其实可以有很多方面的延伸。\n",
    "- 在这里,我们使用文本向量之间的余弦相似度作为了一个标准。但实际上,我们也可以基于基于包含关键词的情况来给一定的权重。比如一个单词跟related word有多相似,越相似就意味着相似度更高,权重也会更大。 \n",
    "- 另外 ,除了根据词向量去寻找``related words``也可以提前定义好同义词库,但这个需要大量的人力成本。 \n",
    "- 在这里,我们直接返回了问题的答案。 但在理想情况下,我们还是希望通过问题的种类来返回最合适的答案。 比如一个用户问:“明天北京的天气是多少?”, 那这个问题的答案其实是一个具体的温度(其实也叫做实体),所以需要在答案的基础上做进一步的抽取。这项技术其实是跟信息抽取相关的。 \n",
    "- 对于词向量,我们只是使用了``average pooling``, 除了average pooling,我们也还有其他的经典的方法直接去学出一个句子的向量。\n",
    "- 短文的相似度分析一直是业界和学术界一个具有挑战性的问题。在这里我们使用尽可能多的同义词来提升系统的性能。但除了这种简单的方法,可以尝试其他的方法比如WMD,或者适当结合parsing相关的知识点。 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "好了,祝你好运! "
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}