Commit 6b0ec419 by 20220418012

Upload New File

parent b1b0c4b3
from optim import Adam, NoamOpt
import torch
import os
import torch.nn as nn
import torch.distributed
from dataset import PadBatchSeq
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
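
# Trainer for a joint intent-classification / slot-filling model. (Assumption,
# inferred from the two heads used below: the model is a BERT-style encoder
# returning intent logits per utterance and slot logits per token.) Supports
# optional DistributedDataParallel training and gradient accumulation via
# `config.batch_split`.
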
class Trainer:
    def __init__(self, args, model, tokz, train_dataset, valid_dataset,
                 log_dir, logger, device=torch.device('cuda'), valid_writer=None,
                 distributed=False):
        self.config = args
        self.device = device
        self.logger = logger
        self.log_dir = log_dir
        self.tokz = tokz
        self.rank = torch.distributed.get_rank() if distributed else -1
        self.train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
        if valid_writer is None:
            self.valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))
        else:
            self.valid_writer = valid_writer
        self.model = model.to(device, non_blocking=True)
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokz.pad_token_id, reduction='none').to(device)
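        # NoamOpt (below) is assumed to follow the standard Noam schedule from
        # "Attention Is All You Need":
        #   lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
        # here with factor 0.1, d_model = hidden_size, warmup = config.lr_warmup.
        # The wrapper also exposes curr_step() and param_groups, both used below.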
        base_optimizer = Adam(self.model.parameters(), lr=self.config.lr, weight_decay=0.01)
        # When the model is wrapped in DistributedDataParallel, its config
        # lives on model.module rather than on the wrapper itself.
        if hasattr(self.model, 'config'):
            self.optimizer = NoamOpt(self.model.config.hidden_size, 0.1, self.config.lr_warmup, base_optimizer)
        else:
            self.optimizer = NoamOpt(self.model.module.config.hidden_size, 0.1, self.config.lr_warmup, base_optimizer)
        self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else torch.utils.data.RandomSampler(train_dataset)
        self.valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if distributed else None
        self.train_dataloader = DataLoader(
            train_dataset, sampler=self.train_sampler, batch_size=self.config.bs,
            num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))
        self.valid_dataloader = DataLoader(
            valid_dataset, sampler=self.valid_sampler, batch_size=self.config.bs,
            num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))

    def state_dict(self):
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)

    def _eval_train(self, epoch):
        self.model.train()
        intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0
        total = len(self.train_dataloader)
        if self.rank in [-1, 0]:
            TQDM = tqdm(enumerate(self.train_dataloader), desc='Train (epoch #{})'.format(epoch),
                        dynamic_ncols=True, total=total)
        else:
            TQDM = enumerate(self.train_dataloader)
        for i, data in TQDM:
            #######################################################
            # TODO: Complete the following function.
            # The following code should perform the training of the model.
            # You can implement this function with the following steps:
            # 1. Move the input to the GPU by calling .to(self.device)
            # 2. Forward the input through the model
            # 3. Compute the loss (remember to divide the loss by self.config.batch_split to enable gradient accumulation)
            # 4. Backward the loss
            # 5. Update the parameters
            # 6. Evaluate the model (by calling _eval_test) every `self.config.eval_steps` steps
            #######################################################
            # 1. Move the input to the GPU
            text = data['utt'].to(self.device, non_blocking=True)
            intent_labels = data['intent'].to(self.device, non_blocking=True)
            slot_labels = data['slot'].to(self.device, non_blocking=True)
            mask = data['mask'].to(self.device, non_blocking=True)
            token_type = data['token_type'].to(self.device, non_blocking=True)

            # 2. Forward the input through the model
            intent_logits, slot_logits = self.model(input_ids=text,
                                                    attention_mask=mask,
                                                    token_type_ids=token_type)
            # 3. Compute the loss
            batch_intent_loss = self.criterion(intent_logits, intent_labels).mean()
            # Keep the per-token slot losses unreduced (no .mean() here), so the
            # masked average below runs over non-padding positions only.
            batch_slot_loss = self.criterion(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
            slot_mask = 1 - slot_labels.eq(self.tokz.pad_token_id).float()
            batch_slot_loss = (batch_slot_loss * slot_mask.view(-1)).sum() / slot_mask.sum()
            batch_loss = batch_intent_loss + batch_slot_loss
            batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).float().mean()
            # Slot accuracy, again ignoring padding tokens.
            batch_slot_acc = (torch.argmax(slot_logits, dim=-1) == slot_labels).float()
            batch_slot_acc = torch.sum(batch_slot_acc * slot_mask) / torch.sum(slot_mask)
            # 4. Backward the loss, scaled so that gradients accumulated over
            # batch_split micro-batches average to one effective batch
            full_loss = batch_loss / self.config.batch_split
            full_loss.backward()

            intent_loss += batch_intent_loss.item()
            slot_loss += batch_slot_loss.item()
            intent_acc += batch_intent_acc.item()
            slot_acc += batch_slot_acc.item()
            step_count += 1

            # 5. Update the parameters once every batch_split micro-batches
            curr_step = self.optimizer.curr_step()
            lr = self.optimizer.param_groups[0]['lr']
            if (i + 1) % self.config.batch_split == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                # Average the metrics accumulated since the last update
                intent_loss /= step_count
                slot_loss /= step_count
                intent_acc /= step_count
                slot_acc /= step_count
                if self.rank in [-1, 0]:
                    self.train_writer.add_scalar('loss/intent_loss', intent_loss, curr_step)
                    self.train_writer.add_scalar('loss/slot_loss', slot_loss, curr_step)
                    self.train_writer.add_scalar('acc/intent_acc', intent_acc, curr_step)
                    self.train_writer.add_scalar('acc/slot_acc', slot_acc, curr_step)
                    TQDM.set_postfix({'intent_loss': intent_loss,
                                      'intent_acc': intent_acc,
                                      'slot_loss': slot_loss,
                                      'slot_acc': slot_acc})
                intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0

                # 6. Evaluate the model every `self.config.eval_steps` optimizer steps
                if curr_step % self.config.eval_steps == 0:
                    self._eval_test(epoch=epoch, step=curr_step)

    def _eval_test(self, epoch, step):
        self.model.eval()
        with torch.no_grad():
            dev_intent_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_slot_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_intent_acc = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_slot_acc = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            count = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            for data in self.valid_dataloader:
                text = data['utt'].to(self.device, non_blocking=True)
                intent_labels = data['intent'].to(self.device, non_blocking=True)
                slot_labels = data['slot'].to(self.device, non_blocking=True)
                mask = data['mask'].to(self.device, non_blocking=True)
                token_type = data['token_type'].to(self.device, non_blocking=True)
                intent_logits, slot_logits = self.model(input_ids=text, attention_mask=mask, token_type_ids=token_type)
                batch_intent_loss = self.criterion(intent_logits, intent_labels)
                batch_slot_loss = self.criterion(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
                slot_mask = 1 - slot_labels.eq(self.tokz.pad_token_id).float()
                # Per-example slot loss, averaged over that example's non-padding tokens
                batch_slot_loss = (batch_slot_loss * slot_mask.view(-1)).view(text.shape[0], -1).sum(dim=-1) / slot_mask.sum(dim=-1)
                dev_intent_loss += batch_intent_loss.sum()
                dev_slot_loss += batch_slot_loss.sum()
                batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).sum()
                batch_slot_acc = (torch.argmax(slot_logits, dim=-1) == slot_labels).float()
                batch_slot_acc = torch.sum(batch_slot_acc * slot_mask, dim=-1) / torch.sum(slot_mask, dim=-1)
                dev_intent_acc += batch_intent_acc
                dev_slot_acc += batch_slot_acc.sum()
                count += text.shape[0]
            if self.rank != -1:
                # Aggregate the sums across all distributed workers
                # (ReduceOp replaces the deprecated torch.distributed.reduce_op)
                torch.distributed.all_reduce(dev_intent_loss, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_slot_loss, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_intent_acc, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_slot_acc, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(count, op=torch.distributed.ReduceOp.SUM)
            dev_intent_loss /= count
            dev_slot_loss /= count
            dev_intent_acc /= count
            dev_slot_acc /= count
            if self.rank in [-1, 0]:
                self.valid_writer.add_scalar('loss/intent_loss', dev_intent_loss, step)
                self.valid_writer.add_scalar('loss/slot_loss', dev_slot_loss, step)
                self.valid_writer.add_scalar('acc/intent_acc', dev_intent_acc, step)
                self.valid_writer.add_scalar('acc/slot_acc', dev_slot_acc, step)
                log_str = 'epoch {:>3}, step {}'.format(epoch, step)
                log_str += ', dev_intent_loss {:>4.4f}'.format(dev_intent_loss)
                log_str += ', dev_slot_loss {:>4.4f}'.format(dev_slot_loss)
                log_str += ', dev_intent_acc {:>4.4f}'.format(dev_intent_acc)
                log_str += ', dev_slot_acc {:>4.4f}'.format(dev_slot_acc)
                self.logger.info(log_str)
        self.model.train()

    def train(self, start_epoch, epochs, after_epoch_funcs=[], after_step_funcs=[]):
        for epoch in range(start_epoch + 1, epochs):
            self.logger.info('Training on epoch {}'.format(epoch))
            # DistributedSampler needs the epoch to reshuffle across workers
            if hasattr(self.train_sampler, 'set_epoch'):
                self.train_sampler.set_epoch(epoch)
            self._eval_train(epoch)
            for func in after_epoch_funcs:
                func(epoch, self.device)
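
# Example wiring (a minimal sketch, not part of the training code: the model,
# tokenizer, and dataset objects are hypothetical stand-ins, but the config
# fields shown are exactly the ones Trainer reads above):
#
#   import logging
#   from argparse import Namespace
#
#   args = Namespace(lr=1e-4, lr_warmup=4000, bs=32, n_jobs=4,
#                    batch_split=2, eval_steps=500)
#   trainer = Trainer(args, model, tokz, train_dataset, valid_dataset,
#                     log_dir='runs/exp1', logger=logging.getLogger(__name__))
#   trainer.train(start_epoch=-1, epochs=10)   # runs epochs 0..9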