diff --git a/prefix_tuning_xl.py b/prefix_tuning_cmrc.py
similarity index 51%
rename from prefix_tuning_xl.py
rename to prefix_tuning_cmrc.py
index a69e86b..93c63ae 100644
--- a/prefix_tuning_xl.py
+++ b/prefix_tuning_cmrc.py
@@ -13,8 +13,20 @@
 import os
 import time
 import torch.nn.functional as F
 import torch.nn as nn
+import json
+# NOTE: import path assumed from the Megatron-style data_utils layout used in
+# this repo; CommandToken and Tokenization are needed by the code added below.
+from data_utils.tokenization import CommandToken, Tokenization
+
+
+def process_data(input_path, output_path):
+    """Flatten squad-style CMRC 2018 data into one JSON object per line."""
+    with open(input_path, "r") as file:
+        data = json.load(file)
+    with open(output_path, "w") as file:
+        for data_obj in data["data"]:
+            # assumes one paragraph per article, as in the CMRC 2018 squad-style files
+            obj = data_obj["paragraphs"][0]
+            for qa in obj["qas"]:
+                for answer in qa["answers"]:
+                    file.write(json.dumps({"context": obj["context"], "question": qa["question"],
+                                           "answer": answer["text"], "qid": qa["id"]},
+                                          ensure_ascii=False) + "\n")
 
 
 class PrefixTuning(nn.Module):
     def __init__(self, model, num_layer, embd_dim, device, preseqlen=100, prefix_dropout=0.0):
@@ -45,7 +57,46 @@ class PrefixTuning(nn.Module):
     def forward(self, input_ids, position_ids, attention_mask, *mems):
         if not mems:
             mems = self.get_prompt(bsz=input_ids.shape[0])  # num_layer * bsz * seq_len * hidden_dim
-        return self.model(input_ids, position_ids, attention_mask, *mems)
+        return self.model(input_ids, position_ids, attention_mask, *mems)
+
+
+def prepare_tokenizer(args):
+    tokenizer_args = {
+        'tokenizer_type': args.tokenizer_type,
+        'corpus': None,
+        'model_path': args.tokenizer_path,
+        'vocab_size': args.vocab_size,
+        'model_type': args.tokenizer_model_type,
+        'cache_dir': args.cache_dir,
+        'add_eop': args.hierarchical}
+    tokenizer = make_tokenizer(**tokenizer_args)
+    cmd_tokens = {token.name: token for token in tokenizer._command_tokens}
+    # add special command tokens that delimit context, question and answer
+    for token, token_text in [('boc', '<|beginofcontext|>'), ('boq', '<|beginofquestion|>'), ('boa', '<|beginofanswer|>')]:
+        tokenizer.num_tokens += 1
+        tokenizer.num_command_tokens += 1
+        tokenizer.num_text_tokens += 1
+        cmd_tkn = CommandToken(token, token_text, tokenizer.num_tokens)
+        tokenizer._command_tokens.append(cmd_tkn)
+        cmd_tokens[token] = cmd_tkn
+
+    # pad the vocabulary up to the next allowed multiple
+    before = tokenizer.num_tokens
+    after = before
+    multiple = args.make_vocab_size_divisible_by
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0('> padded vocab (size: {}) with {} dummy '
+                 'tokens (new size: {})'.format(before, after - before, after))
+
+    args.tokenizer_num_tokens = after
+    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
+    args.eod_token = tokenizer.get_command('eos').Id
+    args.vocab_size = after
+    print_rank_0("prepare tokenizer done")
+
+    return tokenizer, cmd_tokens
 
 
 def setup_model(args):
@@ -73,7 +124,7 @@ def setup_model(args):
                          device=args.local_rank, preseqlen=args.preseqlen, prefix_dropout=0.1)
     model.cuda(torch.cuda.current_device())
     if args.load_sd:
-        print(f"loading state_dict {args.load_sd}")
+        print(f"loading state dict {args.load_sd}")
         sd = torch.load(args.load_sd, map_location='cpu')
         model.load_state_dict(sd["module"])
     if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
@@ -92,7 +143,6 @@ def setup_model(args):
             dist_init_required=False
         )
     return model
-
 
 def prepare_envs():
     """Prepare model for generation."""
@@ -113,36 +163,37 @@ def prepare_envs():
     set_random_seed(args.seed)
 
     # get the tokenizer
-    tokenizer = prepare_tokenizer(args)
+    tokenizer, cmd_tokens = prepare_tokenizer(args)
 
     # Model, optimizer, and learning rate.
     model = setup_model(args)
-    return model, tokenizer, torch.cuda.current_device(), args
-
+    return model, tokenizer, cmd_tokens, torch.cuda.current_device(), args
 
-def get_processed_data(data_dir, max_seq_len=512, min_seq_len=256):
-    data = []
-    with os.scandir(data_dir) as it:
-        for idx, entry in enumerate(it):
-            if not entry.name.startswith('.') and entry.is_file():
-                with open(entry.path) as file:
-                    lines = file.readlines()
-                    lines = [line.strip() for line in lines]
-                    doc_data = "".join(lines)
-                    if len(doc_data) <= max_seq_len and len(doc_data)>= min_seq_len:
-                        data.append(doc_data)
-    print(len(data))
-    return data
 
-def get_processed_dataset(input_data, tokenizer):
+def get_processed_dataset(input_path, eval_dataset=False):
+    # NOTE: relies on the module-level `tokenizer` and `cmd_tokens` set in __main__
     dataset = []
-    for doc_data in input_data:
-        doc_tkns = tokenizer.EncodeAsIds(doc_data)
-        # important, we need to explicitly tell the model when to stop !
-        final_tkns = doc_tkns.append(tokenizer.get_command('eos'))
-        dataset.append(final_tkns.tokenization)
-    return dataset
+    with open(input_path, "r") as file:
+        for line in file:
+            text_json = json.loads(line)
+            ctx_tkns, question_tkns, answer_tkns = tokenizer.EncodeAsIds(text_json["context"]), \
+                                                   tokenizer.EncodeAsIds(text_json["question"]), \
+                                                   tokenizer.EncodeAsIds(text_json["answer"])
+            ctx_tkns.insert(0, cmd_tokens["boc"])
+            question_tkns.insert(0, cmd_tokens["boq"])
+            answer_tkns.insert(0, cmd_tokens["boa"])
+            # Tokenization.append extends in place and returns self
+            final_tkns = ctx_tkns.append(question_tkns)
+            if not eval_dataset:
+                final_tkns = final_tkns.append(answer_tkns)
+                # important! append eos explicitly so the model learns when to stop
+                final_tkns.append(tokenizer.get_command('eos'))
+            else:
+                final_tkns = final_tkns.append(cmd_tokens["boa"])
+            dataset.append((final_tkns, text_json["qid"]))
+    return dataset
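+
+# Illustrative layout of the samples built above (token strings rather than
+# ids; the curly braces are placeholders, not part of the sequence):
+#   train: <|beginofcontext|>{context}<|beginofquestion|>{question}<|beginofanswer|>{answer}<eos>
+#   eval:  <|beginofcontext|>{context}<|beginofquestion|>{question}<|beginofanswer|>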
+
 
 class CMRCDataset(torch.utils.data.dataset.Dataset):
     def __init__(self, dataset, device):
@@ -150,28 +201,28 @@ class CMRCDataset(torch.utils.data.dataset.Dataset):
         self.device = device
 
     def __getitem__(self, index):
-        item = self.dataset[index]
+        item, qid = self.dataset[index]
         x = torch.tensor(item, device=self.device)
-        return {"context":x}
+        return {"context": x, "qid": qid}
 
     def __len__(self):
         return len(self.dataset)
 
-
 class MyCollator(object):
-    def __init__(self, tokenizer, device, mem_length):
+    def __init__(self, tokenizer, device, mem_length, cmd_tokens):
         self.tokenizer = tokenizer
         self.device = device
         self.mem_length = mem_length
+        self.cmd_tokens = cmd_tokens
 
     def __call__(self, batch):
         # pad with specific pad id
         data = [item["context"] for item in batch]
-        data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.tokenizer.get_command('eos').Id)
+        data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.cmd_tokens["pad"].Id)
         seq_length = data.shape[1]
         # get loss mask, mask padding values
         loss_mask = torch.ones(data.shape, dtype=torch.float, device=self.device)
-        # loss_mask[data == tokenizer.get_command('eos').Id] = 0.0 # comment out so that eos behavior can be learned
+        # loss_mask[data == self.cmd_tokens["pad"].Id] = 0.0  # intentionally not applied: pad positions keep a loss signal so that eos behavior is learned
         # get position ids
         position_ids = torch.arange(seq_length, dtype=torch.long, device=self.device)
         position_ids = position_ids.unsqueeze(0).expand_as(data)
@@ -179,28 +230,17 @@ class MyCollator(object):
         attention_mask = torch.ones((1, seq_length, seq_length + self.mem_length), device=self.device)
         attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + self.mem_length), self.mem_length)
         attention_mask = attention_mask.unsqueeze(1)
-        return data, attention_mask, loss_mask, position_ids
+        qids = [item["qid"] for item in batch]
+        return data, attention_mask, loss_mask, position_ids, qids
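+
+# Worked example of the banded mask built in __call__ (values chosen for
+# illustration): with seq_length=3 and mem_length=2, the triu/tril pair keeps
+# entries with i <= j <= i + mem_length, so each query attends to its own
+# position plus the previous mem_length key positions (memory included):
+#   [[1, 1, 1, 0, 0],
+#    [0, 1, 1, 1, 0],
+#    [0, 0, 1, 1, 1]]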
 
 
-def init_model_with_no_grad_invoke():
-    with torch.no_grad():
-        with open("./checkpoint_model/tokens.pkl", "rb") as file:
-            tokens = pickle.load(file)
-        with open("./checkpoint_model/attn_mask.pkl", "rb") as file:
-            attention_mask = pickle.load(file)
-        with open("./checkpoint_model/pos_ids.pkl", "rb") as file:
-            position_ids = pickle.load(file)
-        mems = []
-        return model(tokens, position_ids, attention_mask, *mems)
-
 def train(model, train_loader, args):
-    writer = SummaryWriter(log_dir=f"./prefix_tuning_nomlp_runs_{args.preseqlen}")
+    writer = SummaryWriter(log_dir="./cmrc_prefix_runs")
     model.train()
-
-    for epoch in range(1):
-        batch_step = 0
+    batch_step = 0
+    for _ in range(3):
         prev = time.time()
-        for data, attention_mask, loss_mask, position_ids in train_loader:
+        for data, attention_mask, loss_mask, position_ids, qids in train_loader:
             batch_step += 1
             mems = [] # useless in here
             logits, *mems = model(data, position_ids, attention_mask, *mems)
@@ -218,21 +258,89 @@
                 now = time.time()
                 print(f"Batch step : {batch_step}, elapsed time: {now-prev}")
                 prev = now
-            if batch_step % 1000 == 0:
-                ckpt_id = f"transformer-xl-prefix{args.preseqlen}-tuning-nomlp-step-{batch_step}"
+            if batch_step % 2000 == 0:
+                ckpt_id = f"transformer-xl-cmrc-prefix-step-{batch_step}"
                 model.save_checkpoint(args.save_dir, ckpt_id)
 
 
+def cmrc_decode_func(tokenizer, cmd_id_map, tok_ids):
+    if isinstance(tok_ids, Tokenization):
+        tok_ids = tok_ids.tokenization
+    res = []
+    for idx in tok_ids:
+        # entries may be tensors (generated ids) or plain ints (Tokenization)
+        idx = idx.item() if torch.is_tensor(idx) else idx
+        if idx in cmd_id_map:
+            res.append(cmd_id_map[idx].token)
+        else:
+            res.append(tokenizer.text_tokenizer.decode([idx]))
+    full_toks = ''.join(res)
+    try:
+        boa_tkn = "<|beginofanswer|>"
+        ans_toks = full_toks[full_toks.index(boa_tkn) + len(boa_tkn):]
+    except ValueError:
+        # no <|beginofanswer|> marker in the decoded text
+        ans_toks = ''
+    return full_toks, ans_toks
+
+
+def test(model, eval_loader, cmd_tokens, args):
+    batch_step = 0
+    args.cmrc_answer_max_len = 50  # max number of generated answer tokens
+    prev_time = time.time()
+    cmd_id_map = {tok.Id: tok for tok in cmd_tokens.values()}
+    eval_res = {}
+    model.eval()
+
+    # NOTE: the incremental decoding below assumes batch_size == 1
+    for batch_data, attention_mask, loss_mask, position_ids, qids in eval_loader:
+        batch_step += 1
+        counter = 0
+        org_context_length = batch_data.shape[-1]
+        while counter < args.cmrc_answer_max_len:
+            if counter == 0:
+                mems = []  # no cached memories on the first forward pass
+                logits, *mems = model(batch_data, position_ids, attention_mask, *mems)
+            else:
+                index = org_context_length + counter
+                logits, *mems = model(batch_data[:, index - 1: index], batch_data.new_ones((1, 1)) * (index - 1),
+                                      batch_data.new_ones(1, 1, 1, args.mem_length + 1, device=device,
+                                                          dtype=torch.float), *mems)
+            logits = logits[:, -1]
+            logits /= args.temperature
+            logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+            probs = F.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)[0]
+            if next_token == args.eod_token:
+                break
+            batch_data = torch.cat((batch_data, next_token.view(1, 1)), dim=1)
+            counter += 1
+        for data, qid in zip(batch_data, qids):
+            full_toks, ans_toks = cmrc_decode_func(tokenizer, cmd_id_map, data)
+            eval_res[qid] = ans_toks
+        if batch_step != 0 and batch_step % 10 == 0:
+            curr_time = time.time()
+            print(f"Batch step : {batch_step}, elapsed time: {curr_time-prev_time}")
+            with open("/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/eval_res.json", "w") as file:
+                json.dump(eval_res, file, indent=2, ensure_ascii=False)
+            prev_time = curr_time
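+
+# The dumped eval_res JSON maps each question id to its decoded answer, e.g.
+# (placeholder values): {"TRIAL_0_QUERY_0": "<decoded answer text>"}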
+
+
 if __name__ == "__main__":
-    model, tokenizer, device, args = prepare_envs()
-    sport_data_dir = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/THUCNews/体育"
-    sport_data = get_processed_data(sport_data_dir)
-    train_dataset = get_processed_dataset(sport_data, tokenizer)
-    my_collator = MyCollator(tokenizer, device, args.mem_length)
+    train_input_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/cmrc2018_train.json"
+    train_output_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/train_processed.json"
+    eval_input_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/cmrc2018_trial.json"
+    eval_output_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/evaluate_processed.json"
+    process_data(train_input_path, train_output_path)
+    process_data(eval_input_path, eval_output_path)
+    model, tokenizer, cmd_tokens, device, args = prepare_envs()
+    train_dataset = get_processed_dataset(train_output_path)
+    eval_dataset = get_processed_dataset(eval_output_path, eval_dataset=True)
+    my_collator = MyCollator(tokenizer, device, args.mem_length, cmd_tokens)
+    print(f"batch size : {args.batch_size}")
     train_loader = torch.utils.data.DataLoader(dataset=CMRCDataset(train_dataset, device),
-                                               batch_size=args.batch_size, shuffle=True, collate_fn=my_collator)
-    # not sure if still needed # this is necessary, cuz if not init with no grad context, a reference error will appear
-    #res = init_model_with_no_grad_invoke()
+                                               batch_size=args.batch_size, shuffle=True, collate_fn=my_collator)
     train(model, train_loader, args)
-
\ No newline at end of file
+    # eval_loader = torch.utils.data.DataLoader(dataset=CMRCDataset(eval_dataset, device),
+    #                                           batch_size=args.batch_size, shuffle=False, collate_fn=my_collator)
+    # test(model, eval_loader, cmd_tokens, args)
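+    # Assumed workflow: train first, then reload the tuned checkpoint via
+    # args.load_sd and uncomment the eval_loader / test() lines above.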