Commit 5aac183c by 20220418012

Upload New File

parent eba021df
import os
import torch
import random
import traceback
import model.utils as utils
import model.dataset as dataset
from model.model_multi_input import MultiInputModel
from model.trainer_multi_input import Trainer
from model.text import Vocab
import argparse
class mylog:
def info(self, text):
print(text)
parser = argparse.ArgumentParser()
parser.add_argument('--config', help='config file', default='config.json')
parser.add_argument('--gpu', help='which gpu to use', type=str, default='3')
parser.add_argument('--epoch', help='which epoch to use', type=int, default=-1)
args = parser.parse_args()
config = utils.load_config(args.config)
config_path = os.path.dirname(args.config)
train_dir = os.path.join(config_path, config['train_dir'])
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
try:
print('pytorch version: {}'.format(torch.__version__))
if args.epoch == -1:
model_path = os.path.join(train_dir, utils.get_latest_ckpt(train_dir))
else:
model_path = os.path.join(train_dir, utils.get_ckpt_filename('model', args.epoch))
if not os.path.isfile(model_path):
print('cannot find {}'.format(model_path))
exit(0)
if len(args.gpu) != 0:
device = torch.device("cuda")
else:
device = torch.device("cpu")
vocab = Vocab(config.vocab_path)
print('Building models')
model = MultiInputModel(config, vocab).to(device)
print('Loading weights from {}'.format(model_path))
state_dict = torch.load(model_path, map_location=device)['model']
for i in list(state_dict.keys()):
state_dict[i.replace('.module.', '.')] = state_dict.pop(i)
model.load_state_dict(state_dict)
model.eval()
while True:
post = input('>> ')
post = ' '.join(list(post.replace(' ', '')))
# print('post_str', post)
post = [vocab.eos_id] + vocab.string2ids(post) + [vocab.eos_id]
# print('post', post)
contexts = [torch.tensor([post], dtype=torch.long, device=device)]
# print('contexts', contexts)
prediction = model.predict(contexts)[0]
pred_str = vocab.ids2string(prediction)
print('>> {}'.format(pred_str))
except:
print(traceback.format_exc())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment