from torch.utils.data import Dataset
import torch


class NLUDataset(Dataset):
    def __init__(self, paths, tokz, cls_vocab, slot_vocab, logger, max_lengths=2048):
        self.logger = logger
        self.data = NLUDataset.make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths)

    @staticmethod
    def make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths):
        logger.info('reading data from {}'.format(paths))
        dataset = []
        #######################################################
        # The output of this function is a list. Each element of this list
        # is a tuple consisting of
        # (intent_id, utterance_token_id_list, token_type_id, slot_id_list)
        # intent_id: id for the intent
        # utterance_token_id_list: id list for each token in the utterance
        # token_type_id: token type ids used for BERT
        # slot_id_list: id list for all the slots
        #######################################################
        for path in paths:
            with open(path, "r", encoding="utf8") as fp:
                lines = [line.strip().lower() for line in fp.readlines() if len(line.strip()) > 0]
            line_split = [line.split('\t') for line in lines]
            for label, utt, slots in line_split:
                intent_id = int(cls_vocab[label])
                # Tokenize character by character; truncate the utterance and
                # the slot tags together so the alignment assertion holds.
                utt = tokz.convert_tokens_to_ids(list(utt)[:max_lengths])
                slots = [slot_vocab[i] for i in slots.split()][:max_lengths]
                assert len(utt) == len(slots)
                # create_token_type_ids_from_sequences already accounts for the
                # [CLS] and [SEP] wrapped around the utterance, so the utterance,
                # token-type, and slot sequences all share the same length.
                dataset.append((intent_id,
                                [tokz.cls_token_id] + utt + [tokz.sep_token_id],
                                tokz.create_token_type_ids_from_sequences(token_ids_0=utt),
                                [tokz.pad_token_id] + slots + [tokz.pad_token_id]))
        logger.info('{} data records loaded'.format(len(dataset)))
        return dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        intent, utt, token_type, slot = self.data[idx]
        return {"intent": intent, "utt": utt, "token_type": token_type, "slot": slot}
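
# A minimal sketch of the TSV layout make_dataset expects, inferred from the
# split('\t') call and the per-character tokenization above; the intent and
# slot names in this example are hypothetical:
#
#   music.play<TAB>播放周杰伦的歌<TAB>O O B-singer I-singer I-singer O O
#
# The utterance is tokenized character by character via list(utt), so every
# character must align one-to-one with a whitespace-separated slot tag.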

class PinnedBatch:
    def __init__(self, data):
        self.data = data

    def __getitem__(self, k):
        return self.data[k]

    def pin_memory(self):
        for k in self.data.keys():
            self.data[k] = self.data[k].pin_memory()
        return self


class PadBatchSeq:
    def __init__(self, pad_id):
        self.pad_id = pad_id

    def __call__(self, batch):
        res = dict()
        #######################################################
        # Pad a batch of samples into tensors.
        # The result is a dict with the following keys:
        # "intent": A 1d tensor of intent ids. Shape: [bs]
        # "utt": A 2d tensor of token ids. Shape: [bs, max_seq_len]
        # "mask": A 2d attention-mask tensor; each element is either
        #         1 (non-PAD token) or 0 (PAD token). Shape: [bs, max_seq_len]
        # "token_type": A 2d tensor of token types. Shape: [bs, max_seq_len]
        # "slot": A 2d tensor of slot ids. Shape: [bs, max_seq_len]
        #######################################################
        # torch.LongTensor needs a list, not a generator.
        res['intent'] = torch.LongTensor([i['intent'] for i in batch])
        max_len = max([len(i['utt']) for i in batch])
        res['utt'] = torch.LongTensor([i['utt'] + [self.pad_id] * (max_len - len(i['utt'])) for i in batch])
        res['mask'] = torch.LongTensor([[1] * len(i['utt']) + [0] * (max_len - len(i['utt'])) for i in batch])
        res['token_type'] = torch.LongTensor([i['token_type'] + [self.pad_id] * (max_len - len(i['token_type'])) for i in batch])
        res['slot'] = torch.LongTensor([i['slot'] + [self.pad_id] * (max_len - len(i['slot'])) for i in batch])
        return PinnedBatch(res)


if __name__ == '__main__':
    from transformers import BertTokenizer

    bert_path = '/home/data/tmp/bert-base-chinese'
    data_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/train.tsv'
    cls_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/cls_vocab'
    slot_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/slot_vocab'

    with open(cls_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    cls_vocab = dict(zip(res, range(len(res))))
    with open(slot_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    slot_vocab = dict(zip(res, range(len(res))))

    class Logger:
        def info(self, s):
            print(s)

    logger = Logger()
    tokz = BertTokenizer.from_pretrained(bert_path)
    dataset = NLUDataset([data_file], tokz, cls_vocab, slot_vocab, logger)
    pad = PadBatchSeq(tokz.pad_token_id)
    print(pad([dataset[i] for i in range(5)]))
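
    # A hedged sketch of downstream use (not part of the original script):
    # a DataLoader applies PadBatchSeq as its collate_fn, and when
    # pin_memory=True PyTorch calls PinnedBatch.pin_memory() on every batch,
    # which is what the PinnedBatch wrapper exists for. batch_size here is
    # illustrative; pinning is gated on CUDA being available.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=pad,
                        pin_memory=torch.cuda.is_available())
    batch = next(iter(loader))
    print(batch['utt'].shape, batch['mask'].shape, batch['slot'].shape)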