import collections
import os
import random
import torch
from torch import nn
import torchtext.vocab as Vocab
import torch.utils.data as Data
from sklearn.model_selection import train_test_split
import jieba
import time
import torch.nn.functional as F


class MyDataset:
    """Tab-separated labelled text dataset wrapped in train/test DataLoaders.

    Each line of ``data_dir`` is expected to look like ``<label>\\t<text>``, where
    ``<label>`` is one of ``self.label_list``. After construction the instance exposes:

    * ``data``       -- shuffled list of ``[text, label_index]``
    * ``train_iter`` -- shuffling ``DataLoader`` over 75 % of the samples
    * ``test_iter``  -- ``DataLoader`` over the remaining 25 %
    * ``vocab``      -- the vocabulary built from the tokenized texts
    """

    def __init__(self, data_dir, batch_size=64):
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.label_list = ['体育', '娱乐', '家居', '房产', '教育',
                           '时尚', '时政', '游戏', '科技', '财经']
        self.data = self.read_file()
        self.train_iter, self.test_iter, self.vocab = self.preprocess(self.data)

    def read_file(self):
        """Read ``self.data_dir`` and return a shuffled list of ``[text, label_index]``.

        Text is lower-cased. Blank lines are skipped. Raises ``ValueError`` if a
        non-blank line has no tab separator or its label is not in ``self.label_list``.
        """
        with open(self.data_dir, 'rb') as f:
            lines = f.readlines()
        data = []
        for raw in lines:
            # 'utf-8-sig' drops a leading BOM (only the first line can carry one);
            # for every other line it decodes exactly like 'utf-8'.
            # rstrip handles both '\r\n' and '\n' endings -- stripping only '\r\n'
            # left a trailing '\n' on Unix-style files.
            review = raw.decode('utf-8-sig').lower().rstrip('\r\n')
            if not review.strip():
                continue  # e.g. the empty line at the end of many files
            # maxsplit=1 so a tab inside the text does not truncate the text.
            label, text = review.split('\t', 1)
            data.append([text, self.label_list.index(label)])
        random.shuffle(data)
        return data

    @staticmethod
    def get_tokenized(data):
        """Tokenize each text with jieba in full mode (``cut_all=True``).

        data: list of [string, label]
        Returns a list of token lists, one per sample, in the same order.
        """

        def tokenizer(text):
            return jieba.cut(text, cut_all=True)

        return [list(tokenizer(review)) for review, _ in data]

    @staticmethod
    def get_vocab(tokenized_data, min_freq=5):
        """Build a vocabulary keeping tokens that occur at least ``min_freq`` times.

        Uses the legacy torchtext ``Vocab(counter, min_freq=...)`` constructor.
        """
        counter = collections.Counter(tk for st in tokenized_data for tk in st)
        return Vocab.Vocab(counter, min_freq=min_freq)

    def preprocess(self, data, max_l=350):
        """Tokenize, index, truncate/pad to ``max_l`` and split ``data``.

        Returns ``(train_iter, test_iter, vocab)``.
        """
        tokenized_data = self.get_tokenized(data)
        vocab = self.get_vocab(tokenized_data)
        # Pad with the vocabulary's '<pad>' index instead of a literal 0. With the
        # legacy torchtext Vocab defaults (specials=('<unk>', '<pad>')) index 0 is
        # '<unk>', so padding with 0 made padding indistinguishable from unknown
        # words. Fall back to 0 if the vocabulary has no '<pad>' entry.
        pad_idx = vocab.stoi.get('<pad>', 0)

        def pad(x):
            """Truncate or right-pad the index list ``x`` to exactly ``max_l`` entries."""
            x = x[:max_l]
            return x + [pad_idx] * (max_l - len(x))

        features = torch.tensor([pad([vocab.stoi[word] for word in words]) for words in tokenized_data])
        labels = torch.tensor([score for _, score in data])
        # Hold out 25 % for testing; random_state=0 fixes sklearn's split for a given input order.
        X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.25, random_state=0)
        train_set = Data.TensorDataset(X_train, y_train)
        test_set = Data.TensorDataset(X_test, y_test)
        train_iter = Data.DataLoader(train_set, self.batch_size, shuffle=True)
        test_iter = Data.DataLoader(test_set, self.batch_size)
        return train_iter, test_iter, vocab


class BiLSTM(nn.Module):
    """Bidirectional LSTM text classifier with self-attention pooling.

    Args:
        vocab: vocabulary; only ``len(vocab)`` is used, to size the embedding table.
        embed_size: embedding dimension.
        num_hiddens: hidden size of each LSTM direction.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output classes (default 10, the previously hard-coded value).
    """

    def __init__(self, vocab, embed_size, num_hiddens, num_layers, num_classes=10):
        super(BiLSTM, self).__init__()
        self.embedding = nn.Embedding(len(vocab), embed_size)
        # bidirectional=True turns this into a bidirectional recurrent network.
        self.encoder = nn.LSTM(input_size=embed_size,
                               hidden_size=num_hiddens,
                               num_layers=num_layers,
                               bidirectional=True)
        # Forward and backward outputs are concatenated, hence 2 * num_hiddens features.
        self.attention_layer = SelfAttention(num_hiddens * 2)
        # The attention-pooled sequence feature is the classifier's input.
        self.decoder = nn.Linear(2 * num_hiddens, num_classes)

    def forward(self, inputs):
        """Classify a batch of token-index sequences.

        Args:
            inputs: (batch_size, num_words) token indices.

        Returns:
            ``(logits, att_score)``. ``logits`` has shape (batch_size, num_classes).
            ``att_score`` is the attention weight tensor returned by the attention layer.
        """
        # nn.LSTM (batch_first=False) expects the sequence dimension first, so transpose
        # before embedding: embeddings is (num_words, batch_size, embed_size).
        embeddings = self.embedding(inputs.permute(1, 0))
        # outputs: (num_words, batch_size, 2 * num_hiddens); the final (h, c) state is unused.
        outputs, _ = self.encoder(embeddings)
        # Attention pooling works batch-first, so swap back to (batch_size, num_words, 2 * num_hiddens).
        feat, att_score = self.attention_layer(outputs.permute(1, 0, 2))
        outs = self.decoder(feat)
        return outs, att_score


class SelfAttention(nn.Module):
    """Additive self-attention pooling over dimension 1 of the input.

    Each position is scored as ``tanh(x @ W) @ v``. The scores are softmax-normalised
    over dim 1, and the score-weighted sum of ``x`` over dim 1 is returned.
    """

    def __init__(self, num_hiddens):
        super().__init__()
        # W: (num_hiddens, num_hiddens) projection; proj: (num_hiddens, 1) scoring vector.
        self.weight_W = nn.Parameter(torch.empty(num_hiddens, num_hiddens), requires_grad=True)
        self.weight_proj = nn.Parameter(torch.empty(num_hiddens, 1), requires_grad=True)
        # Custom initialisation: both weights drawn uniformly from [-0.1, 0.1] (W first).
        for param in (self.weight_W, self.weight_proj):
            nn.init.uniform_(param, -0.1, 0.1)

    def forward(self, x):
        """Return ``(pooled, weights)``: pooled features and the softmaxed attention weights."""
        hidden = torch.tanh(x @ self.weight_W)
        logits = hidden @ self.weight_proj
        weights = torch.softmax(logits, dim=1)
        pooled = (x * weights).sum(dim=1)
        return pooled, weights


def train_model(data_dir, *, batch_size=64, embed_size=100, num_hiddens=128,
                num_layers=2, lr=0.01, num_epochs=5):
    """Load the dataset at ``data_dir``, build a BiLSTM classifier and train it.

    All keyword-only hyper-parameters default to the values previously hard-coded here.
    """
    # Pin GPU 0 only if the caller has not already chosen devices; the previous
    # unconditional assignment silently overrode an existing CUDA_VISIBLE_DEVICES.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = MyDataset(data_dir=data_dir, batch_size=batch_size)
    net = BiLSTM(dataset.vocab, embed_size, num_hiddens, num_layers)
    # Only parameters with requires_grad=True go to the optimizer. A frozen
    # (e.g. pre-trained) embedding would be skipped; the embedding here is trained.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
    loss = nn.CrossEntropyLoss()
    train(dataset.train_iter, dataset.test_iter, net, loss, optimizer, device, num_epochs)


def train(train_iter, test_iter, net, loss_fn, optimizer, device, num_epochs, log_interval=10):
    """Train ``net`` for ``num_epochs`` epochs, printing progress and test accuracy.

    Args:
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches passed to ``evaluate_accuracy``.
        net: model whose call returns ``(logits, extra)``; only ``logits`` is used.
        loss_fn: loss taking ``(logits, y)`` and returning a scalar tensor.
        optimizer: optimizer over ``net``'s parameters.
        device: device the model and batches are moved to.
        num_epochs: number of passes over ``train_iter``.
        log_interval: print a progress line every this many batches (0 disables it).

    The printed loss is the mean batch loss over the current epoch so far.
    """
    net = net.to(device)
    print("training on ", device)
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        # Reset every epoch, like train_l_sum. A run-wide counter made the reported
        # loss for epoch >= 2 divide this epoch's sum by all batches seen so far.
        batch_count = 0
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat, _ = net(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_l_sum += loss.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
            if log_interval and batch_count % log_interval == 0:
                print('epoch %d, loss %.4f, train acc %.3f'
                      % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n))
        # Guard against an empty loader instead of dividing by zero.
        mean_loss = train_l_sum / batch_count if batch_count else 0.0
        train_acc = train_acc_sum / n if n else 0.0
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, mean_loss, train_acc, test_acc, time.time() - start))


def evaluate_accuracy(data_iter, net, device=None):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(X, y)`` batches.
        net: callable returning ``(logits, extra)``, normally an ``nn.Module``.
        device: device to run on. If ``None`` and ``net`` is a module, the device
            of its first parameter is used. If it is still ``None``, batches are
            not moved.

    Returns:
        Fraction of samples whose argmax logit equals the label, or 0.0 if
        ``data_iter`` yields no samples.

    A module is switched to eval mode (disabling dropout) for the whole pass.
    Its previous train/eval mode is then restored, even if an exception is raised.
    """
    is_module = isinstance(net, torch.nn.Module)
    if device is None and is_module:
        # No device given: use the one the network's parameters live on (if any).
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device
    was_training = net.training if is_module else False
    if is_module:
        net.eval()
    acc_sum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if device is not None:
                    X, y = X.to(device), y.to(device)
                outputs, _ = net(X)
                acc_sum += (outputs.argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if is_module:
            net.train(was_training)  # restore the caller's mode
    return acc_sum / n if n else 0.0


if __name__ == '__main__':
    import sys

    # The data file may be given as the first command-line argument; otherwise
    # fall back to the original hard-coded location.
    if len(sys.argv) > 1:
        data_path = sys.argv[1]
    else:
        data_path = os.path.join(r'D:\JetBrains\workspace\project3\data', '待抽取关键词文本数据.txt')
    train_model(data_path)