Commit 10897314 by 20220418012

Upload New File

parent 6f779620
from torch.utils.data import Dataset
import torch


class NLUDataset(Dataset):
    def __init__(self, paths, tokz, cls_vocab, slot_vocab, logger, max_lengths=2048):
        self.logger = logger
        self.data = NLUDataset.make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths)

    @staticmethod
    def make_dataset(paths, tokz, cls_vocab, slot_vocab, logger, max_lengths):
        logger.info('reading data from {}'.format(paths))
        dataset = []
        #######################################################
        # TODO: Complete the following function.
        # The output of this function is a list. Each element of this list
        # is a tuple consisting of
        # (intent_id, utterance_token_id_list, token_type_id, slot_id_list)
        # intent_id: id for the intent
        # utterance_token_id_list: id list for each token in the utterance
        # token_type_id: token type ids used for BERT
        # slot_id_list: id list for all the slots
        #######################################################
        for path in paths:
            with open(path, "r", encoding="utf8") as fp:
                lines = [line.strip().lower() for line in fp.readlines() if len(line.strip()) > 0]
            # Each line is "intent_label<TAB>utterance<TAB>slot_tag_sequence".
            line_split = [line.split('\t') for line in lines]
            for label, utt, slots in line_split:
                intent_id = int(cls_vocab[label])
                utt = tokz.convert_tokens_to_ids(list(utt)[:max_lengths])
                slots = [slot_vocab[i] for i in slots.split()]
                assert len(utt) == len(slots)
                # Slot ids are padded at both ends so they stay aligned with [CLS] ... [SEP].
                dataset.append([intent_id,
                                [tokz.cls_token_id] + utt + [tokz.sep_token_id],
                                tokz.create_token_type_ids_from_sequences(token_ids_0=utt),
                                [tokz.pad_token_id] + slots + [tokz.pad_token_id]])
        logger.info('{} data records loaded'.format(len(dataset)))
        return dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        intent, utt, token_type, slot = self.data[idx]
        return {"intent": intent, "utt": utt, "token_type": token_type, "slot": slot}


class PinnedBatch:
    def __init__(self, data):
        self.data = data

    def __getitem__(self, k):
        return self.data[k]

    def pin_memory(self):
        for k in self.data.keys():
            self.data[k] = self.data[k].pin_memory()
        return self
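
# Note (added for clarity, behavior assumed rather than stated in the original file):
# when a DataLoader is created with pin_memory=True, PyTorch's pin-memory step calls
# .pin_memory() on collate outputs that expose such a method, so wrapping the batch
# dict in PinnedBatch lets every tensor be moved into pinned memory for faster
# host-to-GPU transfer.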


class PadBatchSeq:
    def __init__(self, pad_id):
        self.pad_id = pad_id

    def __call__(self, batch):
        res = dict()
        #######################################################
        # TODO: Complete the following function.
        # Pad a batch of samples into Tensors.
        # The result should be a dict with the following keys:
        # "intent": A 1d tensor of intent ids. Shape: [bs]
        # "utt": A 2d tensor of token ids. Shape: [bs, max_seq_len]
        # "mask": A 2d tensor of attention mask values; each element is either 1 (non-PAD token) or 0 (PAD token). Shape: [bs, max_seq_len]
        # "token_type": A 2d tensor of token type ids. Shape: [bs, max_seq_len]
        # "slot": A 2d tensor of slot ids. Shape: [bs, max_seq_len]
        #######################################################
        res['intent'] = torch.LongTensor([i['intent'] for i in batch])  # LongTensor needs a list, not a generator
        max_len = max([len(i['utt']) for i in batch])
        res['utt'] = torch.LongTensor([i['utt'] + [self.pad_id] * (max_len - len(i['utt'])) for i in batch])
        res['mask'] = torch.LongTensor([[1] * len(i['utt']) + [0] * (max_len - len(i['utt'])) for i in batch])
        res['token_type'] = torch.LongTensor([i['token_type'] + [self.pad_id] * (max_len - len(i['token_type'])) for i in batch])
        res['slot'] = torch.LongTensor([i['slot'] + [self.pad_id] * (max_len - len(i['slot'])) for i in batch])
        return PinnedBatch(res)
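
# Worked example (illustrative, not from the original commit): for two samples whose
# "utt" fields have lengths 6 and 8 (tokens plus [CLS]/[SEP]), max_seq_len is 8, so the
# shorter row is right-padded with pad_id, its mask row becomes [1, 1, 1, 1, 1, 1, 0, 0],
# and every tensor in the returned batch has shape [2, 8] except "intent", which is [2].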


if __name__ == '__main__':
    from transformers import BertTokenizer

    bert_path = '/home/data/tmp/bert-base-chinese'
    data_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/train.tsv'
    cls_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/cls_vocab'
    slot_vocab_file = '/home/data/tmp/NLP_Course/Joint_NLU/data/slot_vocab'

    with open(cls_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    cls_vocab = dict(zip(res, range(len(res))))
    with open(slot_vocab_file) as f:
        res = [i.strip() for i in f.readlines() if len(i.strip()) != 0]
    slot_vocab = dict(zip(res, range(len(res))))

    class Logger:
        def info(self, s):
            print(s)

    logger = Logger()
    tokz = BertTokenizer.from_pretrained(bert_path)
    dataset = NLUDataset([data_file], tokz, cls_vocab, slot_vocab, logger)
    pad = PadBatchSeq(tokz.pad_token_id)
    print(pad([dataset[i] for i in range(5)]))
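
    # Hedged sketch (not part of the original commit): the collator is normally plugged
    # into a DataLoader; batch_size=8 and shuffle=True are arbitrary illustration values.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=pad, pin_memory=torch.cuda.is_available())
    batch = next(iter(loader))
    # Each field is a padded LongTensor, e.g. batch['utt'] has shape [bs, max_seq_len].
    print(batch['utt'].shape, batch['mask'].shape, batch['slot'].shape)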