fasttext-checkpoint.ipynb 5.11 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 基于FastText的意图分类\n",
    "\n",
    "在这里我们训练一个FastText意图识别模型,并把训练好的模型存放在模型文件里。 意图识别实际上是文本分类任务,需要标注的数据:每一个句子需要对应的标签如闲聊型的,任务型的。但在这个项目中,我们并没有任何标注的数据,而且并不需要搭建闲聊机器人。所以这里搭建的FastText模型只是一个dummy模型,没有什么任何的作用。这个模块只是为了项目的完整性,也让大家明白FastText如何去使用,仅此而已。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fasttext\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取数据:导入在preprocessor.ipynb中生成的data/question_answer_pares.pkl文件,并将其保存在变量QApares中\n",
    "with open('data/question_answer_pares.pkl','rb') as f:\n",
    "    QApares = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这一文档的目的是为了用fasttext来进行意图识别,将问题分为任务型和闲聊型两类\n",
    "# 训练数据集是任务型还是闲聊型本应由人工打上标签,但这并不是我们的重点。我们的重点是教会大家如何用fasttext来进行意图识别\n",
    "# 所以这里我们为数据集随机打上0或1的标签,而不去管实际情况如何\n",
    "\n",
    "#fasttext的输入格式为:单词1 单词2 单词3 ... 单词n __label__标签号\n",
    "#我们将问题数据集整理为fasttext需要的输入格式并为其随机打上标签并将结果保存在data/fasttext/fasttext_train.txt和data/fasttext/fasttext_test.txt中\n",
    "with open('data/fasttext/fasttext_train.txt','w') as f:\n",
    "    for content in QApares[:int(0.7*len(QApares))].dropna().question_after_preprocessing:\n",
    "        f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))\n",
    "with open('data/fasttext/fasttext_test.txt','w') as f:\n",
    "    for content in QApares[int(0.7*len(QApares)):].dropna().question_after_preprocessing:\n",
    "        f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用fasttext进行意图识别,并将模型保存在classifier中\n",
    "classifier = fasttext.train_supervised('data/fasttext/fasttext_train.txt',     #训练数据文件路径\n",
    "                                       label=\"__label__\",      #类别前缀\n",
    "                                       dim=100,       #向量维度\n",
    "                                       epoch=5,       #训练轮次\n",
    "                                       lr=0.1,        #学习率\n",
    "                                       wordNgrams=2,      #n-gram个数\n",
    "                                       loss='softmax',    #损失函数类型\n",
    "                                       thread=5,          #线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出\n",
    "                                       verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(('__label__0',), array([0.50808138]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#使用训练好的fasttext模型进行预测\n",
    "classifier.predict('今天 月亮 真 圆 啊')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30000, 0.49946666666666667, 0.49946666666666667)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#使用训练好的fasttext模型对测试集文件进行评估\n",
    "classifier.test('data/fasttext/fasttext_test.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#保存模型\n",
    "classifier.save_model('model/fasttext.ftz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mpnn",
   "language": "python",
   "name": "mpnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}