from torch.utils.data import Dataset
import torch


class NLUDataset(Dataset):
    def __init__(self, paths, tokz, cls_vocab, slot_vocab, logger, max_lengths=2048):
        self.logger = logger
        self.data = NLUDataset.make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths)

    @staticmethod
    def make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths):
        logger.info('reading data from {}'.format(paths))
        dataset = []
        #######################################################
        # The output of this function is a list. Each element of this list
        # is a tuple consisting of
        # (intent_id, utterance_token_id_list, token_type_id, slot_id_list)
        # intent_id: id for the intent
        # utterance_token_id_list: id list for each token in the utterance
        # token_type_id: token type ids used for BERT
        # slot_id_list: id list for all the slots
        #######################################################
        for path in paths:
            with open(path, "r", encoding="utf8") as fp:
                lines = [line.strip().lower() for line in fp.readlines() if len(line.strip()) > 0]
            line_split = [line.split('\t') for line in lines]
            for label, utt, slots in line_split:
                intent_id = int(cls_vocab[label])
                # Tokenize character by character; truncate the utterance and
                # the slot tags together so the alignment assertion holds.
                utt = tokz.convert_tokens_to_ids(list(utt)[:max_lengths])
                slots = [slot_vocab[i] for i in slots.split()][:max_lengths]
                assert len(utt) == len(slots)
                # create_token_type_ids_from_sequences already accounts for the
                # [CLS] and [SEP] wrapped around the utterance, so the utterance,
                # token-type, and slot sequences all share the same length.
                dataset.append((intent_id,
                                [tokz.cls_token_id] + utt + [tokz.sep_token_id],
                                tokz.create_token_type_ids_from_sequences(token_ids_0=utt),
                                [tokz.pad_token_id] + slots + [tokz.pad_token_id]))
        logger.info('{} data records loaded'.format(len(dataset)))
        return dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        intent, utt, token_type, slot = self.data[idx]
        return {"intent": intent, "utt": utt, "token_type": token_type, "slot": slot}
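
# A minimal sketch of the TSV layout make_dataset expects, inferred from the
# split('\t') call and the per-character tokenization above; the intent and
# slot names in this example are hypothetical:
#
#   music.play<TAB>播放周杰伦的歌<TAB>O O B-singer I-singer I-singer O O
#
# The utterance is tokenized character by character via list(utt), so every
# character must align one-to-one with a whitespace-separated slot tag.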

class PinnedBatch:
    def __init__(self, data):
        self.data = data

    def __getitem__(self, k):
        return self.data[k]

    def pin_memory(self):
        for k in self.data.keys():
            self.data[k] = self.data[k].pin_memory()
        return self


class PadBatchSeq:
    def __init__(self, pad_id):
        self.pad_id = pad_id

    def __call__(self, batch):
        res = dict()
        #######################################################
        # Pad a batch of samples into tensors.
        # The result is a dict with the following keys:
        # "intent": A 1d tensor of intent ids. Shape: [bs]
        # "utt": A 2d tensor of token ids. Shape: [bs, max_seq_len]
        # "mask": A 2d attention-mask tensor; each element is either
        #         1 (non-PAD token) or 0 (PAD token). Shape: [bs, max_seq_len]
        # "token_type": A 2d tensor of token types. Shape: [bs, max_seq_len]
        # "slot": A 2d tensor of slot ids. Shape: [bs, max_seq_len]
        #######################################################
        # torch.LongTensor needs a list, not a generator.
        res['intent'] = torch.LongTensor([i['intent'] for i in batch])
        max_len = max([len(i['utt']) for i in batch])
        res['utt'] = torch.LongTensor([i['utt'] + [self.pad_id] * (max_len - len(i['utt'])) for i in batch])
        res['mask'] = torch.LongTensor([[1] * len(i['utt']) + [0] * (max_len - len(i['utt'])) for i in batch])
        res['token_type'] = torch.LongTensor([i['token_type'] + [self.pad_id] * (max_len - len(i['token_type'])) for i in batch])
        res['slot'] = torch.LongTensor([i['slot'] + [self.pad_id] * (max_len - len(i['slot'])) for i in batch])
        return PinnedBatch(res)


if __name__ == '__main__':
    from transformers import BertTokenizer

    bert_path = '/home/data/tmp/bert-base-chinese'
    data_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/train.tsv'
    cls_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/cls_vocab'
    slot_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/slot_vocab'

    with open(cls_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    cls_vocab = dict(zip(res, range(len(res))))
    with open(slot_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    slot_vocab = dict(zip(res, range(len(res))))

    class Logger:
        def info(self, s):
            print(s)

    logger = Logger()
    tokz = BertTokenizer.from_pretrained(bert_path)
    dataset = NLUDataset([data_file], tokz, cls_vocab, slot_vocab, logger)
    pad = PadBatchSeq(tokz.pad_token_id)
    print(pad([dataset[i] for i in range(5)]))
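
    # A hedged sketch of downstream use (not part of the original script):
    # a DataLoader applies PadBatchSeq as its collate_fn, and when
    # pin_memory=True PyTorch calls PinnedBatch.pin_memory() on every batch,
    # which is what the PinnedBatch wrapper exists for. batch_size here is
    # illustrative; pinning is gated on CUDA being available.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=pad,
                        pin_memory=torch.cuda.is_available())
    batch = next(iter(loader))
    print(batch['utt'].shape, batch['mask'].shape, batch['slot'].shape)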