Commit 77d2abdb by 20220418012

Upload New File

parent 6b0ec419
import argparse
import torch
import os
parser = argparse.ArgumentParser()
parser.add_argument('--bert_path', help='path to the pretrained BERT model', default='/home/data/tmp/bert-base-chinese')
parser.add_argument('--save_path', help='path to save checkpoints', default='/home/data/tmp/NLP_Course/Joint_NLU/train')
parser.add_argument('--train_file', help='training data', default='/home/data/tmp/NLP_Course/Joint_NLU/data/train.tsv')
parser.add_argument('--valid_file', help='validation data', default='/home/data/tmp/NLP_Course/Joint_NLU/data/test.tsv')
parser.add_argument('--intent_label_vocab', help='intent label vocabulary', default='/home/data/tmp/NLP_Course/Joint_NLU/data/cls_vocab')
parser.add_argument('--slot_label_vocab', help='slot label vocabulary', default='/home/data/tmp/NLP_Course/Joint_NLU/data/slot_vocab')
parser.add_argument("--local_rank", help='used for distributed training', type=int, default=-1)
parser.add_argument('--lr', type=float, default=8e-6)
parser.add_argument('--lr_warmup', type=float, default=200)
parser.add_argument('--bs', type=int, default=30)
parser.add_argument('--batch_split', type=int, default=1)
parser.add_argument('--eval_steps', type=int, default=40)
parser.add_argument('--n_epochs', type=int, default=30)
parser.add_argument('--max_length', type=int, default=90)
parser.add_argument('--seed', type=int, default=123)
parser.add_argument('--n_jobs', type=int, default=1, help='number of workers to process data')
parser.add_argument('--gpu', help='which gpu to use', type=str, default='3')
args = parser.parse_args()
# Restrict the visible GPUs before any CUDA context is created; the
# CUDA-touching imports below are deliberately deferred until after this line.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
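# Example invocations (a sketch; the script file name `train.py` is an
# assumption, this commit does not name the file):
#
#   single GPU:  python train.py --gpu 0
#   two GPUs:    python -m torch.distributed.launch --nproc_per_node=2 \
#                    train.py --gpu 0,1
#
# torch.distributed.launch passes a distinct --local_rank to every worker, and
# init_process_group(init_method='env://') below reads MASTER_ADDR,
# MASTER_PORT, RANK and WORLD_SIZE from the environment it prepares.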
import traceback

from torch.nn.parallel import DistributedDataParallel
from transformers import BertConfig, BertTokenizer, AdamW

# Project-local modules: the NLU model, dataset reader, helpers, and trainer.
from NLU_model import NLUModule
import dataset
import utils
from trainer import Trainer

train_path = os.path.join(args.save_path, 'train')
log_path = os.path.join(args.save_path, 'log')
def save_func(epoch, device):
    # Persist the full trainer state after each epoch; `trainer` is the global
    # created below, and `device` is unused but kept so the function matches
    # the after_epoch_funcs callback signature.
    filename = utils.get_ckpt_filename('model', epoch)
    torch.save(trainer.state_dict(), os.path.join(train_path, filename))
try:
    # Rank 0 (or the single process when local_rank is -1) creates the save
    # folder; other ranks spin until it exists so every process can log there.
    if args.local_rank in [-1, 0]:
        if not os.path.isdir(args.save_path):
            os.makedirs(args.save_path)
    while not os.path.isdir(args.save_path):
        pass
    logger = utils.get_logger(os.path.join(args.save_path, 'train.log'))

    # Set up the checkpoint and log folders and record the full config.
    if args.local_rank in [-1, 0]:
        for path in [train_path, log_path]:
            if not os.path.isdir(path):
                logger.info('cannot find {}, creating it'.format(path))
                os.makedirs(path)
        for i in vars(args):
            logger.info('{}: {}'.format(i, getattr(args, i)))

    # Seed before either code path so single-GPU runs are reproducible too.
    torch.manual_seed(args.seed)
    distributed = (args.local_rank != -1)
    if distributed:
        # Set up distributed training: one process per GPU, NCCL backend,
        # rendezvous information taken from the environment ('env://').
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        device = torch.device('cuda', 0)

    tokz = BertTokenizer.from_pretrained(args.bert_path)
    _, intent2index, _ = utils.load_vocab(args.intent_label_vocab)
    _, slot2index, _ = utils.load_vocab(args.slot_label_vocab)
    train_dataset = dataset.NLUDataset([args.train_file], tokz, intent2index, slot2index, logger, max_lengths=args.max_length)
    valid_dataset = dataset.NLUDataset([args.valid_file], tokz, intent2index, slot2index, logger, max_lengths=args.max_length)

    logger.info('Building models, rank {}'.format(args.local_rank))
    # Extend the BERT config with the two head sizes so NLUModule can build
    # its intent classifier and per-token slot classifier.
    bert_config = BertConfig.from_pretrained(args.bert_path)
    bert_config.num_intent_labels = len(intent2index)
    bert_config.num_slot_labels = len(slot2index)
    model = NLUModule.from_pretrained(args.bert_path, config=bert_config).to(device)
    if distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    trainer = Trainer(args, model, tokz, train_dataset, valid_dataset, log_path, logger, device, distributed=distributed)
    start_epoch = 0
    # Only rank 0 (or the single process) registers the checkpoint callback.
    if args.local_rank in [-1, 0]:
        trainer.train(start_epoch, args.n_epochs, after_epoch_funcs=[save_func])
    else:
        trainer.train(start_epoch, args.n_epochs)
except:
    # `logger` does not exist yet if the failure happened during early setup.
    if 'logger' in globals():
        logger.error(traceback.format_exc())
    else:
        print(traceback.format_exc())
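# Reloading a checkpoint written by save_func just reverses the torch.save
# call above. A minimal sketch, assuming Trainer exposes a load_state_dict
# matching the state saved here (that method is not shown in this file):
#
#   ckpt = torch.load(os.path.join(train_path, utils.get_ckpt_filename('model', epoch)),
#                     map_location=device)
#   trainer.load_state_dict(ckpt)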