diff --git a/prefix_tuning_xl.py b/prefix_tuning_cmrc.py
similarity index 51%
rename from prefix_tuning_xl.py
rename to prefix_tuning_cmrc.py
index a69e86b..93c63ae 100644
--- a/prefix_tuning_xl.py
+++ b/prefix_tuning_cmrc.py
@@ -13,8 +13,20 @@ import os
 import time
 import torch.nn.functional as F
 import torch.nn as nn
+import json
 
 
+def process_data(input_path, output_path):
+    with open(input_path, "r") as file:
+        data = json.load(file)
+    with open(output_path, "w") as file:
+        for data_obj in data["data"]:
+            obj = data_obj["paragraphs"][0]
+            for qa in obj["qas"]:
+                for answer in qa["answers"]:
+                    file.write(json.dumps({"context":obj["context"], "question":qa["question"], 
+                                           "answer": answer["text"], "qid": qa["id"]}, ensure_ascii=False)+"\n")
+    return
 
 class PrefixTuning(nn.Module):
     def __init__(self, model, num_layer, embd_dim, device, preseqlen=100, prefix_dropout=0.0):
@@ -45,7 +57,46 @@ class PrefixTuning(nn.Module):
     def forward(self,  input_ids, position_ids, attention_mask, *mems):
         if not mems:
             mems = self.get_prompt(bsz=input_ids.shape[0]) # num_layer * bsz * seq_len * hidden_dim
-        return self.model(input_ids, position_ids, attention_mask, *mems)    
+        return self.model(input_ids, position_ids, attention_mask, *mems)
+
+
+def prepare_tokenizer(args):
+    tokenizer_args = {
+        'tokenizer_type': args.tokenizer_type,
+        'corpus': None,
+        'model_path': args.tokenizer_path,
+        'vocab_size': args.vocab_size,
+        'model_type': args.tokenizer_model_type,
+        'cache_dir': args.cache_dir,
+        'add_eop': args.hierarchical}
+    tokenizer = make_tokenizer(**tokenizer_args)
+    cmd_tokens = {token.name: token for token in tokenizer._command_tokens}
+    # add special command tokens  
+    for token, token_text in [('boc', '<|beginofcontext|>'), ('boq', '<|beginofquestion|>'), ('boa', '<|beginofanswer|>')]:
+        tokenizer.num_tokens +=1  
+        tokenizer.num_command_tokens +=1
+        tokenizer.num_text_tokens +=1
+        cmd_tkn = CommandToken(token, token_text, tokenizer.num_tokens)
+        tokenizer._command_tokens.append(cmd_tkn)
+        cmd_tokens[token] = cmd_tkn
+
+    num_tokens = tokenizer.num_tokens
+    before = num_tokens
+    after = before
+    multiple = args.make_vocab_size_divisible_by
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0('> padded vocab (size: {}) with {} dummy '
+                 'tokens (new size: {})'.format(
+        before, after - before, after))
+
+    args.tokenizer_num_tokens = after
+    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
+    args.eod_token = tokenizer.get_command('eos').Id
+    args.vocab_size = after
+    print_rank_0("prepare tokenizer done")
+
+    return tokenizer, cmd_tokens
 
 
 def setup_model(args):
@@ -73,7 +124,7 @@ def setup_model(args):
                          device=args.local_rank, preseqlen=args.preseqlen, prefix_dropout=0.1)
     model.cuda(torch.cuda.current_device())
     if args.load_sd:
-        print(f"loading state_dict {args.load_sd}")
+        print(f"loading state dict {args.load_sd}")
         sd = torch.load(args.load_sd, map_location='cpu')
         model.load_state_dict(sd["module"])
     if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
@@ -92,7 +143,6 @@ def setup_model(args):
             dist_init_required=False
         )
     return model
-
     
 def prepare_envs():
     """Prepare model for generation."""
@@ -113,36 +163,37 @@ def prepare_envs():
     set_random_seed(args.seed)
 
     # get the tokenizer
-    tokenizer = prepare_tokenizer(args)
+    tokenizer, cmd_tokens = prepare_tokenizer(args)
 
     # Model, optimizer, and learning rate.
     model = setup_model(args)
 
-    return model, tokenizer, torch.cuda.current_device(), args
-
+    return model, tokenizer, cmd_tokens, torch.cuda.current_device(), args
 
-def get_processed_data(data_dir, max_seq_len=512, min_seq_len=256):
-    data = []
-    with os.scandir(data_dir) as it:
-        for idx, entry in enumerate(it):
-            if not entry.name.startswith('.') and entry.is_file():
-                with open(entry.path) as file:
-                    lines = file.readlines() 
-                    lines = [line.strip() for line in lines]
-                    doc_data = "".join(lines)
-                    if len(doc_data) <= max_seq_len and len(doc_data)>= min_seq_len:
-                        data.append(doc_data)
-    print(len(data))
-    return data
 
-def get_processed_dataset(input_data, tokenizer):
+def get_processed_dataset(input_path, eval_dataset=False):
     dataset = []
-    for doc_data in input_data:
-        doc_tkns = tokenizer.EncodeAsIds(doc_data) 
-        # important, we need to explicitly tell the model when to stop !
-        final_tkns = doc_tkns.append(tokenizer.get_command('eos'))
-        dataset.append(final_tkns.tokenization)
-    return dataset
+    with open(input_path, "r") as file:
+        while True:
+            line = file.readline()
+            if not line: break
+            text_json = json.loads(line)
+            ctx_tkns, question_tkns, answer_tkns = tokenizer.EncodeAsIds(text_json["context"]), \
+                                                   tokenizer.EncodeAsIds(text_json["question"]), \
+                                                   tokenizer.EncodeAsIds(text_json["answer"])
+            ctx_tkns.insert(0, cmd_tokens["boc"])
+            question_tkns.insert(0, cmd_tokens["boq"])
+            answer_tkns.insert(0, cmd_tokens["boa"])
+            final_tkns = ctx_tkns.append(question_tkns)
+            if not eval_dataset:
+                final_tkns = final_tkns.append(answer_tkns)
+                # important! so that eos behavior is learned
+                final_tkns.append(tokenizer.get_command('eos'))
+            else:
+                final_tkns = final_tkns.append(cmd_tokens["boa"])
+            dataset.append((final_tkns, text_json["qid"]))
+    return dataset    
+
 
 class CMRCDataset(torch.utils.data.dataset.Dataset):
     def __init__(self, dataset, device):
@@ -150,28 +201,28 @@ class CMRCDataset(torch.utils.data.dataset.Dataset):
         self.device = device
         
     def __getitem__(self, index):
-        item = self.dataset[index]
+        item, qid = self.dataset[index]
         x = torch.tensor(item, device=self.device)
-        return {"context":x}
+        return {"context":x, "qid": qid}
     
     def __len__(self):
         return len(self.dataset)
 
-
 class MyCollator(object):
-    def __init__(self, tokenizer, device, mem_length):
+    def __init__(self, tokenizer, device, mem_length, cmd_tokens):
         self.tokenizer = tokenizer
         self.device = device
         self.mem_length = mem_length
+        self.cmd_tokens = cmd_tokens
 
     def __call__(self, batch):
         # pad with specific pad id 
         data = [item["context"] for item in batch]
-        data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.tokenizer.get_command('eos').Id)
+        data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.cmd_tokens["pad"].Id)
         seq_length = data.shape[1]
         # get loss mask, mask padding values
         loss_mask = torch.ones(data.shape, dtype=torch.float, device=self.device)
-        # loss_mask[data == tokenizer.get_command('eos').Id] = 0.0  # comment out so that eos behavior can be learned
+        # loss_mask[data == cmd_tokens["pad"].Id] = 0.0 # important! so that eos behavior is learned
         # get position ids 
         position_ids = torch.arange(seq_length, dtype=torch.long, device=self.device)
         position_ids = position_ids.unsqueeze(0).expand_as(data)
@@ -179,28 +230,17 @@ class MyCollator(object):
         attention_mask = torch.ones((1, seq_length, seq_length + self.mem_length), device=self.device)
         attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + self.mem_length), self.mem_length)
         attention_mask = attention_mask.unsqueeze(1)
-        return data, attention_mask, loss_mask, position_ids
+        qids = [item["qid"] for item in batch]
+        return data, attention_mask, loss_mask, position_ids, qids
 
 
-def init_model_with_no_grad_invoke():
-    with torch.no_grad():
-        with open("./checkpoint_model/tokens.pkl", "rb") as file:
-            tokens = pickle.load(file)
-        with open("./checkpoint_model/attn_mask.pkl", "rb") as file:
-            attention_mask = pickle.load(file)
-        with open("./checkpoint_model/pos_ids.pkl", "rb") as file:
-            position_ids = pickle.load(file)
-        mems = []
-    return model(tokens, position_ids, attention_mask, *mems)
-
 def train(model, train_loader, args):
-    writer = SummaryWriter(log_dir=f"./prefix_tuning_nomlp_runs_{args.preseqlen}")
+    writer = SummaryWriter(log_dir="./cmrc_prefix_runs")
     model.train()
-
-    for epoch in range(1):
-        batch_step = 0
+    batch_step = 0
+    for _ in range(3):
         prev = time.time()
-        for data, attention_mask, loss_mask, position_ids in train_loader:
+        for data, attention_mask, loss_mask, position_ids, qids  in train_loader:
             batch_step += 1
             mems = [] # useless in here 
             logits, *mems = model(data, position_ids, attention_mask, *mems)
@@ -218,21 +258,89 @@ def train(model, train_loader, args):
                 now = time.time()
                 print(f"Batch step : {batch_step}, elapsed time: {now-prev}")
                 prev = now
-                if batch_step % 1000 == 0:
-                    ckpt_id = f"transformer-xl-prefix{args.preseqlen}-tuning-nomlp-step-{batch_step}"
+                if batch_step % 2000 == 0:
+                    ckpt_id = f"transformer-xl-cmrc-prefix-step-{batch_step}"
                     model.save_checkpoint(args.save_dir, ckpt_id)
 
 
+def cmrc_decode_func(tokenizer, cmd_id_map, tok_ids):
+    if isinstance(tok_ids, Tokenization):
+        tok_ids = tok_ids.tokenization
+    res = []
+    for idx in tok_ids:
+        idx = idx.item()
+        if idx in cmd_id_map:
+            res.append(cmd_id_map[idx].token)
+        else:
+            res.append(tokenizer.text_tokenizer.decode([idx]))
+    full_toks = ''.join(res)
+    try:
+        boa_tkn = "<|beginofanswer|>"
+        ans_toks = full_toks[full_toks.index(boa_tkn)+len(boa_tkn):]
+    except:
+        ans_toks = ''
+    return full_toks, ans_toks
+
+
+def test(model, eval_loader, cmd_tokens, args):
+    batch_step = 0
+    args.cmrc_answer_max_len = 50
+    prev_time = time.time()
+    cmd_id_map = {tok.Id: tok for name,tok in cmd_tokens.items()}
+    eval_res = {}
+    model.eval()
+
+    for batch_data, attention_mask, loss_mask, position_ids, qids in eval_loader:
+        batch_step += 1
+        counter = 0
+        org_context_length = batch_data.shape[-1]
+        while counter < args.cmrc_answer_max_len:
+            if counter == 0:
+                mems = [] # useless in here 
+                logits, *mems = model(batch_data, position_ids, attention_mask, *mems)
+            else:
+                index = org_context_length + counter
+                logits, *mems = model(batch_data[:, index - 1: index], batch_data.new_ones((1, 1)) * (index - 1),
+                                    batch_data.new_ones(1, 1, 1, args.mem_length + 1, device=device,
+                                                    dtype=torch.float), *mems)
+            logits = logits[:, -1]
+            logits /= args.temperature
+            logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)[0]
+            is_end = prev == args.eod_token
+            if is_end:
+                break
+            batch_data = torch.cat((batch_data, prev.view(1, 1)), dim=1)
+            counter += 1
+        for data, qid in  zip(batch_data, qids):
+            full_toks, ans_toks = cmrc_decode_func(tokenizer, cmd_id_map, data)
+            eval_res[qid] = ans_toks
+        if batch_step != 0 and batch_step % 10 == 0:
+            curr_time = time.time()
+            print(f"Batch step : {batch_step}, elapsed time: {curr_time-prev_time}")
+            with open("/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/eval_res.json", "w") as file:
+                json.dump(eval_res, file, indent=2, ensure_ascii=False)
+            prev_time = curr_time
+
+
 
 if __name__ == "__main__":
-    model, tokenizer, device, args = prepare_envs()
-    sport_data_dir = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/THUCNews/体育"
-    sport_data = get_processed_data(sport_data_dir)
-    train_dataset = get_processed_dataset(sport_data, tokenizer)
-    my_collator = MyCollator(tokenizer, device, args.mem_length)
+    train_input_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/cmrc2018_train.json"
+    train_output_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/train_processed.json"
+    eval_input_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/cmrc2018_trial.json"
+    eval_output_path = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/cmrc2018/squad-style-data/evaluate_processed.json"
+    process_data(train_input_path, train_output_path)
+    process_data(eval_input_path, eval_output_path)
+    model, tokenizer, cmd_tokens, device, args = prepare_envs()
+    train_dataset = get_processed_dataset(train_output_path)
+    eval_dataset = get_processed_dataset(eval_output_path, eval_dataset=True)   
+    my_collator = MyCollator(tokenizer, device, args.mem_length, cmd_tokens)
+    print(f"batch size : {args.batch_size}")
     train_loader = torch.utils.data.DataLoader(dataset=CMRCDataset(train_dataset, device),
-                                               batch_size=args.batch_size, shuffle=True, collate_fn=my_collator)
-    # not sure if still needed # this is necessary, cuz if not init with no grad context, a reference error will appear
-    #res = init_model_with_no_grad_invoke() 
+                                                batch_size=args.batch_size, shuffle=True, collate_fn=my_collator)
     train(model, train_loader, args)
-        
\ No newline at end of file
+    # eval_loader = torch.utils.data.DataLoader(dataset=CMRCDataset(eval_dataset, device),
+    #     batch_size=args.batch_size, shuffle=False, collate_fn=my_collator)
+    # test(model, eval_loader, cmd_tokens, args)
+