import os
import torch
import random
import traceback
import model.utils as utils
import model.dataset as dataset
from model.model_multi_input import MultiInputModel
from model.trainer_multi_input import Trainer
from model.text import Vocab
import argparse

class mylog:
    def info(self, text):
        print(text)

parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file', default='config.json')
parser.add_argument('--gpu', help='which gpu to use', type=str, default='3')
parser.add_argument('--epoch', help='which epoch to use', type=int, default=-1)

args = parser.parse_args()
config = utils.load_config(args.config)
config_path = os.path.dirname(args.config)
train_dir = os.path.join(config_path, config['train_dir'])

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

try:
    print('pytorch version: {}'.format(torch.__version__))
    if args.epoch == -1:
        model_path = os.path.join(train_dir, utils.get_latest_ckpt(train_dir))
    else:
        model_path = os.path.join(train_dir, utils.get_ckpt_filename('model', args.epoch))

    if not os.path.isfile(model_path):
        print('cannot find {}'.format(model_path))
        exit(0)

    if len(args.gpu) != 0:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    vocab = Vocab(config.vocab_path)

    print('Building models')
    model =  MultiInputModel(config, vocab).to(device)

    print('Loading weights from {}'.format(model_path))
    state_dict = torch.load(model_path, map_location=device)['model']
    for i in list(state_dict.keys()):
        state_dict[i.replace('.module.', '.')] = state_dict.pop(i)
    model.load_state_dict(state_dict)
    model.eval()

    while True:
        post = input('>> ')
        post = ' '.join(list(post.replace(' ', '')))
        # print('post_str', post)
        post = [vocab.eos_id] + vocab.string2ids(post) + [vocab.eos_id]
        # print('post', post)
        contexts = [torch.tensor([post], dtype=torch.long, device=device)]
        # print('contexts', contexts)
        prediction = model.predict(contexts)[0]
        pred_str = vocab.ids2string(prediction)
        print('>> {}'.format(pred_str))

except:
    print(traceback.format_exc())