From 5f89ddb4b294759db8953033929934dcf4fdefbb Mon Sep 17 00:00:00 2001 From: 20210516036 <gaosandajiang@163.com> Date: Sun, 4 Jul 2021 16:20:52 +0800 Subject: [PATCH] Upload New File --- codes/fasttext.ipynb | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 codes/fasttext.ipynb diff --git a/codes/fasttext.ipynb b/codes/fasttext.ipynb new file mode 100644 index 0000000..2239b5d --- /dev/null +++ b/codes/fasttext.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 基于FastText的意图分类\n", + "\n", + "在这里我们训练一个FastText意图识别模型,并把训练好的模型存放在模型文件里。 意图识别实际上是文本分类任务,需要标注的数据:每一个句子需要对应的标签如闲聊型的,任务型的。但在这个项目中,我们并没有任何标注的数据,而且并不需要搭建闲聊机器人。所以这里搭建的FastText模型只是一个dummy模型,没有什么任何的作用。这个模块只是为了项目的完整性,也让大家明白FastText如何去使用,仅此而已。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import fasttext\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# 读取数据:导入在preprocessor.ipynb中生成的data/question_answer_pares.pkl文件,并将其保存在变量QApares中\n", + "with open('data/question_answer_pares.pkl','rb') as f:\n", + " QApares = pickle.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# 这一文档的目的是为了用fasttext来进行意图识别,将问题分为任务型和闲聊型两类\n", + "# 训练数据集是任务型还是闲聊型本应由人工打上标签,但这并不是我们的重点。我们的重点是教会大家如何用fasttext来进行意图识别\n", + "# 所以这里我们为数据集随机打上0或1的标签,而不去管实际情况如何\n", + "\n", + "#fasttext的输入格式为:单词1 单词2 单词3 ... 单词n __label__标签号\n", + "#我们将问题数据集整理为fasttext需要的输入格式并为其随机打上标签并将结果保存在data/fasttext/fasttext_train.txt和data/fasttext/fasttext_test.txt中\n", + "with open('data/fasttext/fasttext_train.txt','w') as f:\n", + " for content in QApares[:int(0.7*len(QApares))].dropna().question_after_preprocessing:\n", + " f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))\n", + "with open('data/fasttext/fasttext_test.txt','w') as f:\n", + " for content in QApares[int(0.7*len(QApares)):].dropna().question_after_preprocessing:\n", + " f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#使用fasttext进行意图识别,并将模型保存在classifier中\n", + "classifier = fasttext.train_supervised('data/fasttext/fasttext_train.txt', #训练数据文件路径\n", + " label=\"__label__\", #类别前缀\n", + " dim=100, #向量维度\n", + " epoch=5, #训练轮次\n", + " lr=0.1, #学习率\n", + " wordNgrams=2, #n-gram个数\n", + " loss='softmax', #损失函数类型\n", + " thread=5, #线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出\n", + " verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'__label__0'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#使用训练好的fasttext模型进行预测\n", + "classifier.predict('买 二份 有没有 少 点')[0][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(30000, 0.5007, 0.5007)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#使用训练好的fasttext模型对测试集文件进行评估\n", + "classifier.test('data/fasttext/fasttext_test.txt')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "#保存模型\n", + "classifier.save_model('model/fasttext.ftz')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:greedyaiqa] *", + "language": "python", + "name": "conda-env-greedyaiqa-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- libgit2 0.26.0