{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 豆瓣评分的预测\n",
    "\n",
    "在这个项目中,我们要预测一部电影的评分,这个问题实际上就是一个分类问题。给定的输入为一段文本,输出为具体的评分。 在这个项目中,我们需要做:\n",
    "- 文本的预处理,如停用词的过滤,低频词的过滤,特殊符号的过滤等\n",
    "- 文本转化成向量,将使用三种方式,分别为tf-idf, word2vec以及BERT向量。 \n",
    "- 训练逻辑回归和朴素贝叶斯模型,并做交叉验证\n",
    "- 评估模型的准确率\n",
    "\n",
    "在具体标记为``TODO``的部分填写相应的代码。 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#导入数据处理的基础包\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "#导入用于计数的包\n",
    "from collections import Counter\n",
    "\n",
    "#导入tf-idf相关的包\n",
    "from sklearn.feature_extraction.text import TfidfTransformer    \n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "\n",
    "#导入模型评估的包\n",
    "from sklearn import metrics\n",
    "\n",
    "#导入与word2vec相关的包\n",
    "from gensim.models import KeyedVectors\n",
    "\n",
    "#导入与bert embedding相关的包,关于mxnet包下载的注意事项参考实验手册\n",
    "from bert_embedding import BertEmbedding\n",
    "import mxnet\n",
    "\n",
    "#包tqdm是用来对可迭代对象执行时生成一个进度条用以监视程序运行过程\n",
    "from tqdm import tqdm\n",
    "\n",
    "#导入其他一些功能包\n",
    "import requests\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 读取数据并做文本的处理\n",
    "你需要完成以下几步操作:\n",
    "- 去掉无用的字符如!&,可自行定义\n",
    "- 中文分词\n",
    "- 去掉低频词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>Movie_Name_EN</th>\n",
       "      <th>Movie_Name_CN</th>\n",
       "      <th>Crawl_Date</th>\n",
       "      <th>Number</th>\n",
       "      <th>Username</th>\n",
       "      <th>Date</th>\n",
       "      <th>Star</th>\n",
       "      <th>Comment</th>\n",
       "      <th>Like</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>Avengers Age of Ultron</td>\n",
       "      <td>复仇者联盟2</td>\n",
       "      <td>2017-01-22</td>\n",
       "      <td>1</td>\n",
       "      <td>然潘</td>\n",
       "      <td>2015-05-13</td>\n",
       "      <td>3</td>\n",
       "      <td>连奥创都知道整容要去韩国。</td>\n",
       "      <td>2404</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>10</td>\n",
       "      <td>Avengers Age of Ultron</td>\n",
       "      <td>复仇者联盟2</td>\n",
       "      <td>2017-01-22</td>\n",
       "      <td>11</td>\n",
       "      <td>影志</td>\n",
       "      <td>2015-04-30</td>\n",
       "      <td>4</td>\n",
       "      <td>“一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...</td>\n",
       "      <td>381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>20</td>\n",
       "      <td>Avengers Age of Ultron</td>\n",
       "      <td>复仇者联盟2</td>\n",
       "      <td>2017-01-22</td>\n",
       "      <td>21</td>\n",
       "      <td>随时流感</td>\n",
       "      <td>2015-04-28</td>\n",
       "      <td>2</td>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊!!!!!!</td>\n",
       "      <td>120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>30</td>\n",
       "      <td>Avengers Age of Ultron</td>\n",
       "      <td>复仇者联盟2</td>\n",
       "      <td>2017-01-22</td>\n",
       "      <td>31</td>\n",
       "      <td>乌鸦火堂</td>\n",
       "      <td>2015-05-08</td>\n",
       "      <td>4</td>\n",
       "      <td>与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...</td>\n",
       "      <td>30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>40</td>\n",
       "      <td>Avengers Age of Ultron</td>\n",
       "      <td>复仇者联盟2</td>\n",
       "      <td>2017-01-22</td>\n",
       "      <td>41</td>\n",
       "      <td>办公室甜心</td>\n",
       "      <td>2015-05-10</td>\n",
       "      <td>5</td>\n",
       "      <td>看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   ID           Movie_Name_EN Movie_Name_CN  Crawl_Date  Number Username  \\\n",
       "0   0  Avengers Age of Ultron        复仇者联盟2  2017-01-22       1       然潘   \n",
       "1  10  Avengers Age of Ultron        复仇者联盟2  2017-01-22      11       影志   \n",
       "2  20  Avengers Age of Ultron        复仇者联盟2  2017-01-22      21     随时流感   \n",
       "3  30  Avengers Age of Ultron        复仇者联盟2  2017-01-22      31     乌鸦火堂   \n",
       "4  40  Avengers Age of Ultron        复仇者联盟2  2017-01-22      41    办公室甜心   \n",
       "\n",
       "         Date  Star                                            Comment  Like  \n",
       "0  2015-05-13     3                                      连奥创都知道整容要去韩国。  2404  \n",
       "1  2015-04-30     4   “一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...   381  \n",
       "2  2015-04-28     2                                 奥创弱爆了弱爆了弱爆了啊!!!!!!   120  \n",
       "3  2015-05-08     4   与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...    30  \n",
       "4  2015-05-10     5   看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...    16  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#读取数据\n",
    "data = pd.read_csv('data/DMSC.csv')\n",
    "#观察数据格式\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 212506 entries, 0 to 212505\n",
      "Data columns (total 10 columns):\n",
      " #   Column         Non-Null Count   Dtype \n",
      "---  ------         --------------   ----- \n",
      " 0   ID             212506 non-null  int64 \n",
      " 1   Movie_Name_EN  212506 non-null  object\n",
      " 2   Movie_Name_CN  212506 non-null  object\n",
      " 3   Crawl_Date     212506 non-null  object\n",
      " 4   Number         212506 non-null  int64 \n",
      " 5   Username       212496 non-null  object\n",
      " 6   Date           212506 non-null  object\n",
      " 7   Star           212506 non-null  int64 \n",
      " 8   Comment        212506 non-null  object\n",
      " 9   Like           212506 non-null  int64 \n",
      "dtypes: int64(4), object(6)\n",
      "memory usage: 16.2+ MB\n"
     ]
    }
   ],
   "source": [
    "#输出数据的一些相关信息\n",
    "data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国。</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>“一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊!!!!!!</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star\n",
       "0                                      连奥创都知道整容要去韩国。     3\n",
       "1   “一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...     4\n",
       "2                                 奥创弱爆了弱爆了弱爆了啊!!!!!!     2\n",
       "3   与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...     4\n",
       "4   看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...     5"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#只保留数据中我们需要的两列:Comment列和Star列\n",
    "data = data[['Comment','Star']]\n",
    "#观察新的数据的格式\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国。</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>“一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊!!!!!!</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star\n",
       "0                                      连奥创都知道整容要去韩国。     1\n",
       "1   “一个没有黑暗面的人不值得信任。” 第二部剥去冗长的铺垫,开场即高潮、一直到结束,会有人觉...     1\n",
       "2                                 奥创弱爆了弱爆了弱爆了啊!!!!!!     0\n",
       "3   与第一集不同,承上启下,阴郁严肃,但也不会不好看啊,除非本来就不喜欢漫威电影。场面更加宏大...     1\n",
       "4   看毕,我激动地对友人说,等等奥创要来毁灭台北怎么办厚,她拍了拍我肩膀,没事,反正你买了两份...     1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 这里的star代表具体的评分。但在这个项目中,我们要预测的是正面还是负面。我们把评分为1和2的看作是负面,把评分为3,4,5的作为正面\n",
    "data['Star']=(data.Star/3).astype(int)\n",
    "data.head()"
   ]
  },
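  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A sketch of an equivalent, more explicit binarization (assumption: Star takes the\n",
    "# integer values 1-5, as in this dataset). Demonstrated on a small Series so it can run\n",
    "# safely even after the column above has already been converted to 0/1.\n",
    "raw = pd.Series([1, 2, 3, 4, 5], name='Star')\n",
    "print(np.where(raw >= 3, 1, 0))  # -> [0 0 1 1 1], matching (raw / 3).astype(int)"
   ]
  },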
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务1: 去掉一些无用的字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star\n",
       "0                                       连奥创都知道整容要去韩国     1\n",
       "1   一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...     1\n",
       "2                                       奥创弱爆了弱爆了弱爆了啊     0\n",
       "3   与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...     1\n",
       "4      看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹     1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TODO1: 去掉一些无用的字符,自行定一个字符几何,并从文本中去掉\n",
    "#    your to do \n",
    "remove_words = ['\"', ',', '.', '。', '!', ',', '、', '…', '”', '“', '(', ')','(', ')']\n",
    "for i,r in data.iterrows():\n",
    "    temp = r['Comment']\n",
    "    for word in remove_words:\n",
    "        temp = temp.replace(word, '')\n",
    "    data.loc[i,'Comment'] = temp\n",
    "data.head()"
   ]
  },
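  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A vectorized sketch of the same cleaning step: pandas' str.replace with a regex\n",
    "# character class strips all listed characters in one pass, far faster than iterrows\n",
    "# on 200k+ rows. Safe to run here because removing already-removed characters is a no-op.\n",
    "import re\n",
    "pattern = '[' + re.escape(''.join(remove_words)) + ']'\n",
    "data['Comment'] = data['Comment'].str.replace(pattern, '', regex=True)"
   ]
  },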
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.to_csv('clean_sentence.csv', encoding='utf-8', index=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务2:使用结巴分词对文本做分词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "apply:   0%|                                                                                | 0/212506 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\webberg\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 1.867 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "apply: 100%|█████████████████████████████████████████████████████████████████| 212506/212506 [01:06<00:00, 3185.28it/s]\n"
     ]
    }
   ],
   "source": [
    "data = pd.read_csv('clean_sentence.csv')\n",
    "\n",
    "def not_empty(s):\n",
    "    return s and s.strip()\n",
    "# TODO2: 导入中文分词包jieba, 并用jieba对原始文本做分词\n",
    "import jieba\n",
    "def comment_cut(content):\n",
    "    # TODO: 使用结巴完成对每一个comment的分词\n",
    "    result = jieba.lcut(content)\n",
    "    result = list(filter(not_empty, result))\n",
    "    return result\n",
    "\n",
    "# 输出进度条\n",
    "tqdm.pandas(desc='apply')\n",
    "data['comment_processed'] = data['Comment'].progress_apply(comment_cut)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "      <th>comment_processed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国</td>\n",
       "      <td>1</td>\n",
       "      <td>[连, 奥创, 都, 知道, 整容, 要, 去, 韩国]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...</td>\n",
       "      <td>1</td>\n",
       "      <td>[一个, 没有, 黑暗面, 的, 人, 不, 值得, 信任, 第二部, 剥去, 冗长, 的,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊</td>\n",
       "      <td>0</td>\n",
       "      <td>[奥创, 弱, 爆, 了, 弱, 爆, 了, 弱, 爆, 了, 啊]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...</td>\n",
       "      <td>1</td>\n",
       "      <td>[与, 第一集, 不同, 承上启下, 阴郁, 严肃, 但, 也, 不会, 不, 好看, 啊,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹</td>\n",
       "      <td>1</td>\n",
       "      <td>[看毕, 我, 激动, 地, 对, 友人, 说, 等等, 奥创, 要, 来, 毁灭, 台北,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star  \\\n",
       "0                                       连奥创都知道整容要去韩国     1   \n",
       "1   一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...     1   \n",
       "2                                       奥创弱爆了弱爆了弱爆了啊     0   \n",
       "3   与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...     1   \n",
       "4      看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹     1   \n",
       "\n",
       "                                   comment_processed  \n",
       "0                       [连, 奥创, 都, 知道, 整容, 要, 去, 韩国]  \n",
       "1  [一个, 没有, 黑暗面, 的, 人, 不, 值得, 信任, 第二部, 剥去, 冗长, 的,...  \n",
       "2                 [奥创, 弱, 爆, 了, 弱, 爆, 了, 弱, 爆, 了, 啊]  \n",
       "3  [与, 第一集, 不同, 承上启下, 阴郁, 严肃, 但, 也, 不会, 不, 好看, 啊,...  \n",
       "4  [看毕, 我, 激动, 地, 对, 友人, 说, 等等, 奥创, 要, 来, 毁灭, 台北,...  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 观察新的数据的格式\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务3:设定停用词并去掉停用词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "apply: 100%|████████████████████████████████████████████████████████████████| 212506/212506 [00:02<00:00, 93283.03it/s]\n"
     ]
    }
   ],
   "source": [
    "# TODO3: 设定停用词并从文本中去掉停用词\n",
    "\n",
    "# 下载中文停用词表至data/stopWord.json中,下载地址:https://github.com/goto456/stopwords/\n",
    "# if not os.path.exists('data/stopWord.json'):\n",
    "#     stopWord = requests.get(\"https://raw.githubusercontent.com/goto456/stopwords/master/baidu_stopwords.txt\")\n",
    "#     with open(\"data/stopWord.json\", \"wb\") as f:\n",
    "#          f.write(stopWord.content)\n",
    "\n",
    "# 读取下载的停用词表,并保存在列表中\n",
    "with open(\"data/baidu_stopwords.txt\",\"r\",encoding ='utf-8') as f:\n",
    "    stopWords = set(f.read().split(\"\\n\"))\n",
    "    \n",
    "# 去除停用词\n",
    "def rm_stop_word(wordList):\n",
    "    # your code, remove stop words\n",
    "    # TODO\n",
    "    newList = []\n",
    "    for idx in range(len(wordList)):\n",
    "        if wordList[idx] in stopWords: continue\n",
    "        newList.append(wordList[idx])\n",
    "    return newList\n",
    "\n",
    "# rm_stop_word(['“','一个', '信任','。'])\n",
    "#这行代码中.progress_apply()函数的作用等同于.apply()函数的作用,只是写成.progress_apply()函数才能被tqdm包监控从而输出进度条。\n",
    "data['comment_processed'] = data['comment_processed'].progress_apply(rm_stop_word)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "      <th>comment_processed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国</td>\n",
       "      <td>1</td>\n",
       "      <td>[奥创, 都, 整容, 去, 韩国]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...</td>\n",
       "      <td>1</td>\n",
       "      <td>[一个, 黑暗面, 人, 不, 值得, 信任, 第二部, 剥去, 冗长, 铺垫, 开场, 高...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊</td>\n",
       "      <td>0</td>\n",
       "      <td>[奥创, 弱, 爆, 弱, 爆, 弱, 爆]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...</td>\n",
       "      <td>1</td>\n",
       "      <td>[第一集, 承上启下, 阴郁, 严肃, 不, 好看, 本来, 不, 喜欢, 漫威, 电影, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹</td>\n",
       "      <td>1</td>\n",
       "      <td>[看毕, 激动, 友人, 说, 奥创, 毁灭, 台北, 厚, 拍了拍, 肩膀, 没事, 反正...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star  \\\n",
       "0                                       连奥创都知道整容要去韩国     1   \n",
       "1   一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...     1   \n",
       "2                                       奥创弱爆了弱爆了弱爆了啊     0   \n",
       "3   与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...     1   \n",
       "4      看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹     1   \n",
       "\n",
       "                                   comment_processed  \n",
       "0                                 [奥创, 都, 整容, 去, 韩国]  \n",
       "1  [一个, 黑暗面, 人, 不, 值得, 信任, 第二部, 剥去, 冗长, 铺垫, 开场, 高...  \n",
       "2                             [奥创, 弱, 爆, 弱, 爆, 弱, 爆]  \n",
       "3  [第一集, 承上启下, 阴郁, 严肃, 不, 好看, 本来, 不, 喜欢, 漫威, 电影, ...  \n",
       "4  [看毕, 激动, 友人, 说, 奥创, 毁灭, 台北, 厚, 拍了拍, 肩膀, 没事, 反正...  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 观察新的数据的格式\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务4:去掉低频词,出现次数少于10次的词去掉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "奥创\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# TODO4: 去除低频词, 去掉词频小于10的单词,并把结果存放在data['comment_processed']里\n",
    "temp = Counter()\n",
    "for i,r in data.iterrows():\n",
    "    temp += Counter(r['comment_processed'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"counter\", 'w', encoding='utf-8') as f:\n",
    "    for t in  temp:\n",
    "        f.write( \"{} {}\\n\".format(t,temp[t]) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp = Counter()\n",
    "with open(\"counter\", 'r', encoding='utf-8') as f:\n",
    "    while True:\n",
    "        line = f.readline()\n",
    "        if line:\n",
    "            values = line.split(' ')\n",
    "#             print(int(values[1]))\n",
    "            temp[values[0]] = int(values[1])\n",
    "        else:\n",
    "            break\n",
    "# print(temp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 找出低频词\n",
    "lower_words = set()\n",
    "for x in temp:\n",
    "    if temp[x] <= 10: lower_words.add(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "apply: 100%|███████████████████████████████████████████████████████████████| 212506/212506 [00:01<00:00, 128212.51it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Comment</th>\n",
       "      <th>Star</th>\n",
       "      <th>comment_processed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>连奥创都知道整容要去韩国</td>\n",
       "      <td>1</td>\n",
       "      <td>[奥创, 都, 整容, 去, 韩国]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...</td>\n",
       "      <td>1</td>\n",
       "      <td>[一个, 黑暗面, 人, 不, 值得, 信任, 第二部, 冗长, 铺垫, 开场, 高潮, 结...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>奥创弱爆了弱爆了弱爆了啊</td>\n",
       "      <td>0</td>\n",
       "      <td>[奥创, 弱, 爆, 弱, 爆, 弱, 爆]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...</td>\n",
       "      <td>1</td>\n",
       "      <td>[第一集, 承上启下, 阴郁, 严肃, 不, 好看, 本来, 不, 喜欢, 漫威, 电影, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹</td>\n",
       "      <td>1</td>\n",
       "      <td>[激动, 友人, 说, 奥创, 毁灭, 台北, 厚, 肩膀, 没事, 反正, 买, 两份, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             Comment  Star  \\\n",
       "0                                       连奥创都知道整容要去韩国     1   \n",
       "1   一个没有黑暗面的人不值得信任 第二部剥去冗长的铺垫开场即高潮一直到结束会有人觉得只剩动作特...     1   \n",
       "2                                       奥创弱爆了弱爆了弱爆了啊     0   \n",
       "3   与第一集不同承上启下阴郁严肃但也不会不好看啊除非本来就不喜欢漫威电影场面更加宏大单打与团战...     1   \n",
       "4      看毕我激动地对友人说等等奥创要来毁灭台北怎么办厚她拍了拍我肩膀没事反正你买了两份旅行保险惹     1   \n",
       "\n",
       "                                   comment_processed  \n",
       "0                                 [奥创, 都, 整容, 去, 韩国]  \n",
       "1  [一个, 黑暗面, 人, 不, 值得, 信任, 第二部, 冗长, 铺垫, 开场, 高潮, 结...  \n",
       "2                             [奥创, 弱, 爆, 弱, 爆, 弱, 爆]  \n",
       "3  [第一集, 承上启下, 阴郁, 严肃, 不, 好看, 本来, 不, 喜欢, 漫威, 电影, ...  \n",
       "4  [激动, 友人, 说, 奥创, 毁灭, 台北, 厚, 肩膀, 没事, 反正, 买, 两份, ...  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def rm_lower_words(wordList):\n",
    "    newList = []\n",
    "    for idx in range(len(wordList)):\n",
    "        if wordList[idx] not in lower_words: newList.append(wordList[idx])\n",
    "    return newList\n",
    "data['comment_processed'] = data['comment_processed'].progress_apply(rm_lower_words)\n",
    "# data['comment_processed'] = \n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 把文本分为训练集和测试集\n",
    "选择语料库中的20%作为测试数据,剩下的作为训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# TODO5: 把数据分为训练集和测试集. comments_train(list)保存用于训练的文本,comments_test(list)保存用于测试的文本。 y_train, y_test是对应的标签(0、1)\n",
    "test_ratio = 0.2\n",
    "features = data['comment_processed']\n",
    "labels = data['Star']\n",
    "comments_train, comments_test, y_train, y_test = train_test_split(features, labels, test_size = test_ratio, random_state=13)"
   ]
  },
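  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch, not part of the original notebook: a stratified split keeps the 0/1\n",
    "# class ratio identical in train and test sets, which helps when classes are imbalanced.\n",
    "# Commented out so the split used by the rest of the notebook stays unchanged.\n",
    "# comments_train, comments_test, y_train, y_test = train_test_split(\n",
    "#     features, labels, test_size=test_ratio, random_state=13, stratify=labels)"
   ]
  },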
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "89522                                         [其實, 無, 好, 睇]\n",
      "67258                     [特效, 算, 国产电影, 中, 数一数二, 星爷, 导演, 都]\n",
      "68825     [评论, 差, 特别, 出彩, 画面, 人物, 很, 周星驰, 特效, 简陋, 硬伤, 剧情...\n",
      "151990    [低于, 期待值, 全程无, 尿点, 不功, 星爷, 电影, 残忍, 黑色幽默, 独特, 作...\n",
      "125087      [韩寒, 不错, 说, 整部, 电影, 都, 段子, 人, 没, 看过, 公路, 电影, !]\n",
      "Name: comment_processed, dtype: object 89522     1\n",
      "67258     1\n",
      "68825     1\n",
      "151990    1\n",
      "125087    1\n",
      "Name: Star, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(comments_test[0:5], y_test[0:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 把文本转换成向量的形式\n",
    "\n",
    "在这个部分我们会采用三种不同的方式:\n",
    "- 使用tf-idf向量\n",
    "- 使用word2vec\n",
    "- 使用bert向量\n",
    "\n",
    "转换成向量之后,我们接着做模型的训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务6:把文本转换成tf-idf向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_corpus(words):\n",
    "    return ' '.join(words)\n",
    "\n",
    "tfidf_train_corpus = []\n",
    "for words in comments_train:\n",
    "    tfidf_train_corpus.append(generate_corpus(words))\n",
    "tfidf_test_corpus = []\n",
    "for words in comments_test:\n",
    "    tfidf_test_corpus.append(generate_corpus(words))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "170004 42502\n"
     ]
    }
   ],
   "source": [
    "print(len(tfidf_train_corpus), len(tfidf_test_corpus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(170004, 13687) (42502, 13687)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "# TODO6: 把训练文本和测试文本转换成tf-idf向量。使用sklearn的feature_extraction.text.TfidfTransformer模块\n",
    "#    请留意fit_transform和transform之间的区别。 常见的错误是在训练集和测试集上都使用 fit_transform,需要避免! \n",
    "#    另外,可以留意一下结果是否为稀疏矩阵\n",
    "vectorizer=CountVectorizer()\n",
    "transformer = TfidfTransformer()\n",
    "tfidf_train=transformer.fit_transform(vectorizer.fit_transform(tfidf_train_corpus))\n",
    "tfidf_test=transformer.transform(vectorizer.transform(tfidf_test_corpus))\n",
    "print (tfidf_train.shape, tfidf_test.shape)"
   ]
  },
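  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Equivalent one-step sketch: sklearn's TfidfVectorizer combines CountVectorizer and\n",
    "# TfidfTransformer. Commented out to avoid recomputing the matrices used below.\n",
    "# from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "# tfidf_vec = TfidfVectorizer()\n",
    "# tfidf_train = tfidf_vec.fit_transform(tfidf_train_corpus)\n",
    "# tfidf_test = tfidf_vec.transform(tfidf_test_corpus)"
   ]
  },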
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务7:把文本转换成word2vec向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 由于训练出一个高效的word2vec词向量往往需要非常大的语料库与计算资源,所以我们通常不自己训练Wordvec词向量,而直接使用网上开源的已训练好的词向量。\n",
    "# data/sgns.zhihu.word是从https://github.com/Embedding/Chinese-Word-Vectors下载到的预训练好的中文词向量文件\n",
    "# 使用KeyedVectors.load_word2vec_format()函数加载预训练好的词向量文件\n",
    "model = KeyedVectors.load_word2vec_format('data/sgns.zhihu.word')"
   ]
  },
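  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memory-saving sketch (optional): load_word2vec_format accepts a limit argument that\n",
    "# reads only the first N vectors. Pretrained files are typically sorted by word frequency,\n",
    "# so keeping only the most frequent entries usually costs little coverage.\n",
    "# model = KeyedVectors.load_word2vec_format('data/sgns.zhihu.word', limit=500000)"
   ]
  },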
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-3.51068e-01,  2.57389e-01, -1.46752e-01, -4.45400e-03,\n",
       "       -1.04235e-01,  3.72475e-01, -4.29349e-01, -2.80470e-02,\n",
       "        1.56651e-01, -1.27600e-01, -1.68833e-01, -2.91350e-02,\n",
       "        4.57850e-02, -3.53735e-01,  1.61205e-01, -1.82645e-01,\n",
       "       -1.35340e-02, -2.42591e-01, -1.33356e-01, -1.31012e-01,\n",
       "       -9.29500e-02, -1.70479e-01, -2.54004e-01, -1.20530e-01,\n",
       "       -1.33690e-01,  7.84360e-02, -1.46603e-01, -2.77378e-01,\n",
       "       -1.36723e-01,  9.29070e-02, -4.00197e-01,  2.80726e-01,\n",
       "       -1.73282e-01,  8.56630e-02,  2.37251e-01,  6.24290e-02,\n",
       "       -1.57132e-01,  2.15685e-01,  9.54770e-02,  1.09896e-01,\n",
       "       -2.05394e-01, -3.37900e-03, -2.77480e-02,  8.16580e-02,\n",
       "        9.65290e-02,  1.23188e-01,  9.55090e-02, -2.31017e-01,\n",
       "       -8.59590e-02, -2.21634e-01, -1.37885e-01, -1.84790e-01,\n",
       "       -2.40127e-01, -2.79150e-01, -4.56200e-03,  1.04099e-01,\n",
       "        3.20523e-01, -6.77270e-02,  1.95719e-01,  4.06145e-01,\n",
       "       -2.98546e-01, -1.67750e-02,  2.74917e-01, -9.02350e-02,\n",
       "       -1.06762e-01, -2.47535e-01, -4.00415e-01,  2.06635e-01,\n",
       "        2.76320e-01, -3.13900e-03,  3.04576e-01,  1.17664e-01,\n",
       "       -2.17286e-01,  7.54650e-02, -1.44985e-01,  6.36960e-02,\n",
       "        1.58869e-01, -4.71568e-01, -1.08640e-01,  4.00144e-01,\n",
       "       -1.83435e-01,  1.88286e-01,  1.32482e-01, -8.50580e-02,\n",
       "       -8.65500e-03, -2.80691e-01, -1.10871e-01,  4.72890e-02,\n",
       "       -1.47635e-01, -5.17090e-02, -4.65100e-03, -1.73998e-01,\n",
       "       -6.15050e-02,  1.14153e-01,  7.09480e-02,  9.88670e-02,\n",
       "       -7.25230e-02,  4.64800e-02, -1.83534e-01, -1.97097e-01,\n",
       "       -7.94430e-02,  2.80280e-01, -2.44620e-01, -3.95528e-01,\n",
       "       -6.10930e-02, -2.53600e-01,  1.49320e-01,  2.82553e-01,\n",
       "        4.33800e-02,  3.50895e-01, -1.42657e-01, -9.72500e-03,\n",
       "       -1.38536e-01, -1.25489e-01, -1.06447e-01, -9.92880e-02,\n",
       "        4.94210e-02,  1.19487e-01, -6.15150e-02,  1.44710e-01,\n",
       "        1.85710e-01,  7.26870e-02,  1.90587e-01,  2.89779e-01,\n",
       "        2.03630e-01, -9.82690e-02,  1.36294e-01, -1.17514e-01,\n",
       "       -3.54500e-01,  3.30250e-02,  3.01922e-01, -6.46030e-02,\n",
       "       -2.21900e-03, -1.35516e-01,  1.81371e-01,  9.43760e-02,\n",
       "        2.73173e-01, -1.90694e-01,  1.20015e-01,  1.08732e-01,\n",
       "       -3.41390e-02,  1.17405e-01,  3.11844e-01, -8.31670e-02,\n",
       "        2.78229e-01,  3.37064e-01,  6.89230e-02,  2.01023e-01,\n",
       "        3.29060e-02, -4.36554e-01, -1.64540e-02,  2.31550e-02,\n",
       "       -1.96904e-01, -1.49370e-01,  7.83610e-02,  3.27980e-02,\n",
       "        2.42316e-01, -1.67102e-01,  2.93025e-01, -7.99780e-02,\n",
       "        5.57970e-02,  4.07600e-02, -1.87006e-01,  1.90802e-01,\n",
       "        1.10987e-01, -2.66690e-02, -1.09340e-01,  2.88753e-01,\n",
       "       -2.08372e-01,  6.85860e-02, -3.21254e-01,  6.55090e-02,\n",
       "       -2.84544e-01, -2.70365e-01,  2.22242e-01, -8.31220e-02,\n",
       "       -1.01721e-01,  3.11709e-01, -1.59856e-01,  3.19859e-01,\n",
       "        5.72180e-02,  3.15010e-01, -7.65140e-02,  3.07237e-01,\n",
       "        4.14023e-01,  9.61900e-02, -8.12400e-03,  3.59550e-01,\n",
       "       -1.05667e-01, -4.35740e-02,  1.97829e-01, -1.71804e-01,\n",
       "        1.21416e-01, -6.59890e-02,  3.14697e-01, -1.31049e-01,\n",
       "       -1.27306e-01, -4.13040e-02,  3.01799e-01, -2.47272e-01,\n",
       "        8.71550e-02, -4.88150e-01, -2.20991e-01,  4.65800e-02,\n",
       "       -1.34422e-01,  1.35731e-01, -1.72283e-01,  1.16328e-01,\n",
       "        2.88320e-02,  3.31440e-02,  9.48420e-02, -3.48560e-02,\n",
       "        7.54000e-02,  3.56407e-01, -2.56189e-01, -1.32000e-04,\n",
       "        1.05849e-01,  4.28803e-01,  2.86090e-02,  7.92700e-03,\n",
       "        3.58461e-01,  2.82804e-01, -5.88800e-02,  1.73850e-02,\n",
       "        9.28060e-02, -3.90392e-01,  1.89097e-01,  2.85916e-01,\n",
       "        1.51707e-01,  2.58823e-01,  1.63509e-01,  1.26390e-01,\n",
       "        1.95748e-01, -9.80750e-02,  9.12650e-02, -8.20320e-02,\n",
       "       -1.50282e-01,  1.10330e-01,  3.82834e-01, -1.21887e-01,\n",
       "       -1.31515e-01, -4.10777e-01,  2.19966e-01, -1.48785e-01,\n",
       "        1.02161e-01,  8.31420e-02,  2.08074e-01,  3.58526e-01,\n",
       "        1.41909e-01,  2.27764e-01,  4.61127e-01, -1.61267e-01,\n",
       "       -1.22107e-01,  1.02524e-01, -6.15770e-02,  2.10200e-02,\n",
       "        1.46990e-02, -2.23617e-01,  1.71110e-02,  1.20386e-01,\n",
       "       -5.65090e-02, -2.34566e-01,  4.34660e-02,  1.97851e-01,\n",
       "        2.37255e-01, -1.44901e-01,  4.41118e-01, -3.86210e-02,\n",
       "       -2.60820e-01,  4.17700e-02, -9.47700e-02,  3.21410e-02,\n",
       "       -1.86014e-01, -1.40884e-01,  2.02842e-01, -4.83673e-01,\n",
       "        2.19995e-01,  3.59395e-01, -1.84255e-01,  1.30998e-01,\n",
       "        1.10280e-01,  1.42483e-01, -2.01510e-01, -1.34156e-01,\n",
       "       -1.25440e-01, -9.89700e-02, -1.45869e-01, -2.23137e-01,\n",
       "        4.83180e-02,  2.55901e-01, -1.25977e-01, -1.36290e-01,\n",
       "       -3.33329e-01, -2.65370e-01, -1.48834e-01,  1.28487e-01,\n",
       "       -7.88080e-02,  1.35266e-01,  2.17841e-01,  6.60870e-02],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#预训练词向量使用举例\n",
    "model['今天']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "apply: 100%|████████████████████████████████████████████████████████████████| 170004/170004 [00:16<00:00, 10107.06it/s]\n",
      "apply: 100%|██████████████████████████████████████████████████████████████████| 42502/42502 [00:04<00:00, 10534.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(170004,) (42502,)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# TODO7: 对于每个句子,生成句子的向量。具体的做法是:包含在句子中的所有单词的向量做平均。\n",
    "vocabulary = model.vocab\n",
    "\n",
    "def average(wordsVec):\n",
    "    words = []\n",
    "    for word in wordsVec:\n",
    "        if word in vocabulary: words.append(word)\n",
    "    if len(words) == 0: return None\n",
    "    vecs = model[words]\n",
    "    return (np.sum(vecs, axis=0)/len(vecs)).tolist()\n",
    "\n",
    "word2vec_train=comments_train.progress_apply(average)\n",
    "word2vec_test=comments_test.progress_apply(average)\n",
    "print (word2vec_train.shape, word2vec_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务8:把文本转换成bert向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入gpu版本的bert embedding预训练的模型。\n",
    "# 若没有gpu,则ctx可使用其默认值cpu(0)。但使用cpu会使程序运行的时间变得非常慢\n",
    "# 若之前没有下载过bert embedding预训练的模型,执行此句时会花费一些时间来下载预训练的模型\n",
    "ctx = mxnet.cpu()\n",
    "embedding = BertEmbedding(ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO8: 跟word2vec一样,计算出训练文本和测试文本的向量,仍然采用单词向量的平均。\n",
    "bert_train=embedding(tfidf_train_corpus)\n",
    "bert_test=embedding(tfidf_test_corpus)\n",
    "print (bert_train.shape, bert_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print (tfidf_train.shape, tfidf_test.shape)\n",
    "print (word2vec_train.shape, word2vec_test.shape)\n",
    "print (bert_train.shape, bert_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. 训练模型以及评估\n",
    "对如上三种不同的向量表示法,分别训练逻辑回归模型,需要做:\n",
    "- 搭建模型\n",
    "- 训练模型(并做交叉验证)\n",
    "- 输出最好的结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入逻辑回归的包\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "params = {\n",
    "    'max_iter': [50, 75, 100, 125, 150],\n",
    "    'n_jobs': [8],\n",
    "    'C': [1.0, 0.5, 0.1, 0.05, 0.01]\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务9:使用tf-idf,并结合逻辑回归训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<1x13687 sparse matrix of type '<class 'numpy.float64'>'\n",
       "\twith 5 stored elements in Compressed Sparse Row format>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tfidf_train[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TF-IDF LR test accuracy 0.8710648910639499\n",
      "TF-IDF LR test F1_score 0.7212231052383025\n"
     ]
    }
   ],
   "source": [
    "# TODO9: 使用tf-idf + 逻辑回归训练模型,需要用gridsearchCV做交叉验证,并选择最好的超参数\n",
    "logistic = LogisticRegression()\n",
    "tfidf_clf = GridSearchCV(logistic, params)\n",
    "tfidf_clf.fit(tfidf_train, y_train)\n",
    "# tfidf_model = logistic.fit(tfidf_train, y_train)\n",
    "tf_idf_y_pred = tfidf_clf.predict(tfidf_test)\n",
    "\n",
    "print('TF-IDF LR test accuracy %s' % metrics.accuracy_score(y_test, tf_idf_y_pred))\n",
    "#逻辑回归模型在测试集上的F1_Score\n",
    "print('TF-IDF LR test F1_score %s' % metrics.f1_score(y_test, tf_idf_y_pred,average=\"macro\"))"
   ]
  },
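  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The project intro also calls for a naive Bayes model. Below is a minimal sketch (not part of the original notebook): sklearn's MultinomialNB works directly on the sparse tf-idf matrix, and `best_params_` reports which grid point the search above selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the hyperparameters chosen by the grid search above\n",
    "print('best params:', tfidf_clf.best_params_)\n",
    "\n",
    "# Naive Bayes baseline on the same tf-idf features (alpha is the smoothing strength)\n",
    "from sklearn.naive_bayes import MultinomialNB\n",
    "nb_clf = GridSearchCV(MultinomialNB(), {'alpha': [0.1, 0.5, 1.0]})\n",
    "nb_clf.fit(tfidf_train, y_train)\n",
    "nb_y_pred = nb_clf.predict(tfidf_test)\n",
    "print('TF-IDF NB test accuracy %s' % metrics.accuracy_score(y_test, nb_y_pred))\n",
    "print('TF-IDF NB test F1_score %s' % metrics.f1_score(y_test, nb_y_pred, average=\"macro\"))"
   ]
  },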
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务10:使用word2vec,并结合逻辑回归训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove None data and its label\n",
    "new_train_data = pd.concat([word2vec_train, y_train], axis=1)\n",
    "new_train_data = new_train_data.dropna(axis=0, how='any')\n",
    "new_test_data = pd.concat([word2vec_test, y_test], axis=1)\n",
    "new_test_data = new_test_data.dropna(axis=0, how='any')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(167129, 300)\n",
      "(41784, 300)\n"
     ]
    }
   ],
   "source": [
    "train1 = new_train_data['comment_processed'].to_numpy().tolist()\n",
    "test1 = new_test_data['comment_processed'].to_numpy().tolist()\n",
    "v_train = []\n",
    "for item in train1:\n",
    "    v_train.append(item)\n",
    "print(np.array(v_train).shape)\n",
    "v_test = []\n",
    "for item in test1:\n",
    "    v_test.append(item)\n",
    "print(np.array(v_test).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO10: 使用word2vec + 逻辑回归训练模型,需要用gridsearchCV做交叉验证,并选择最好的超参数\n",
    "w2vparams = {\n",
    "    'max_iter': [50, 75, 100, 125, 150],\n",
    "    'n_jobs': [8],\n",
    "    'C': [1.0, 0.5, 0.1, 0.05, 0.01]\n",
    "}\n",
    "w2v_logistic = LogisticRegression()\n",
    "w2v_clf = GridSearchCV(w2v_logistic, w2vparams)\n",
    "w2v_clf.fit(np.array(v_train), new_data['Star'].to_numpy())\n",
    "word2vec_y_pred = w2v_clf.predict(np.array(v_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word2vec LR test accuracy 0.8472621098985258\n",
      "Word2vec LR test F1_score 0.643362704060181\n"
     ]
    }
   ],
   "source": [
    "print('Word2vec LR test accuracy %s' % metrics.accuracy_score(new_test_data['Star'], word2vec_y_pred))\n",
    "#逻辑回归模型在测试集上的F1_Score\n",
    "print('Word2vec LR test F1_score %s' % metrics.f1_score(new_test_data['Star'], word2vec_y_pred,average=\"macro\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务11:使用bert,并结合逻辑回归训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO11: 使用bert + 逻辑回归训练模型,需要用gridsearchCV做交叉验证,并选择最好的超参数\n",
    "bert_logistic = LogisticRegression()\n",
    "bert_clf = GridSearchCV(bert_logistic, w2vparams)\n",
    "bert_clf.fit(bert_train, y_train)\n",
    "bert_y_pred = w2v_clf.predict(bert_test)\n",
    "\n",
    "print('Bert LR test accuracy %s' % metrics.accuracy_score(y_test, bert_y_pred))\n",
    "#逻辑回归模型在测试集上的F1_Score\n",
    "print('Bert LR test F1_score %s' % metrics.f1_score(y_test, bert_y_pred,average=\"macro\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 任务12:对于以上结果请做一下简单的总结,按照1,2,3,4提取几个关键点,包括:\n",
    "- 结果说明什么问题?\n",
    "- 接下来如何提高?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 一定要有一个baseline,以便后边观察评估指标的变化\n",
    "2. tfidf+逻辑回归作为基线表现比word2vec要好一些,可能的原因是tfidf的维度高,而word2vec只有300维。也就是tfidf模型可能过拟合了\n",
    "3. 后续可以考虑改进模型来提高准确率,例如使用XGBoost等\n",
    "4. 也可以尝试其他的正则\n",
    "5.\n",
    "6."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}