from optim import Adam, NoamOpt
import torch
import os
import torch.nn as nn
import torch.distributed
from dataset import PadBatchSeq
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class Trainer:
    def __init__(self, args, model, tokz, train_dataset, valid_dataset,
                 log_dir, logger, device=torch.device('cuda'), valid_writer=None, distributed=False):
        self.config = args
        self.device = device
        self.logger = logger
        self.log_dir = log_dir
        self.tokz = tokz
        self.rank = torch.distributed.get_rank() if distributed else -1
        self.train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
        if valid_writer is None:
            self.valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))
        else:
            self.valid_writer = valid_writer
        self.model = model.to(device, non_blocking=True)
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokz.pad_token_id, reduction='none').to(device)

        base_optimizer = Adam(self.model.parameters(), lr=self.config.lr, weight_decay=0.01)
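        # NoamOpt (presumably the inverse-sqrt warmup schedule from "Attention
        # Is All You Need", as defined in the local optim module) wraps the base
        # optimizer and keys the learning rate to the model's hidden size. A
        # DistributedDataParallel-wrapped model keeps its config on the inner
        # module, hence the self.model.module fallback below.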
        if hasattr(self.model, 'config'):
            self.optimizer = NoamOpt(self.model.config.hidden_size, 0.1, self.config.lr_warmup, base_optimizer)
        else:
            self.optimizer = NoamOpt(self.model.module.config.hidden_size, 0.1, self.config.lr_warmup, base_optimizer)

        if distributed:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
        else:
            self.train_sampler = torch.utils.data.RandomSampler(train_dataset)
            self.valid_sampler = None

        self.train_dataloader = DataLoader(
            train_dataset, sampler=self.train_sampler, batch_size=self.config.bs, num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))

        self.valid_dataloader = DataLoader(
            valid_dataset, sampler=self.valid_sampler, batch_size=self.config.bs, num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))
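
        # PadBatchSeq is the collate function from dataset.py; it is expected to
        # pad every sequence in a batch to the batch maximum length, filling
        # with the tokenizer's pad token id.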

    def state_dict(self):
        return self.model.state_dict()
        
    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)

    def _eval_train(self, epoch):
        self.model.train()

        intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0
        total = len(self.train_dataloader)
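        # Draw the progress bar only on the main process (rank -1 means
        # single-process training, rank 0 the main distributed worker).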
        if self.rank in [-1, 0]:
            TQDM = tqdm(enumerate(self.train_dataloader), desc='Train (epoch #{})'.format(epoch),
                        dynamic_ncols=True, total=total)
        else:
            TQDM = enumerate(self.train_dataloader)

        for i, data in TQDM:
            #######################################################
            # TODO: Complete the following function.
            #       The following code should perform the training of the model
            #       You can implement this function with the following steps:
            #       1. Pass the input to GPU by calling data.to(self.device)
            #       2. Forward the input to the model
            #       3. Compute the loss (remember to divide the loss with self.config.batch_split to enable gradient accumulation)
            #       4. Backward the loss
            #       5. Update the parameters
            #       6. Evaluate the model (by calling _eval_test) every `self.config.eval_steps` steps
            #######################################################
            # 1. Pass the input to GPU by calling data.to(self.device)
            text = data['utt'].to(self.device, non_blocking=True)
            intent_labels = data['intent'].to(self.device, non_blocking=True)
            slot_labels = data['slot'].to(self.device, non_blocking=True)
            mask = data['mask'].to(self.device, non_blocking=True)
            token_type = data['token_type'].to(self.device, non_blocking=True)

            # 2. Forward the input to the model
            intent_logits, slot_logits = self.model(input_ids=text,
                                                    attention_mask=mask,
                                                    token_type_ids=token_type)

            # 3. Compute the loss (remember to divide the loss with self.config.batch_split to enable gradient accumulation)
            batch_intent_loss = self.criterion(intent_logits, intent_labels).mean()
            # Keep the slot loss per token and average it only over non-pad
            # positions, so padding does not dilute the loss.
            batch_slot_loss = self.criterion(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
            slot_mask = (slot_labels != self.tokz.pad_token_id).float()
            batch_slot_loss = (batch_slot_loss * slot_mask.view(-1)).sum() / slot_mask.sum()

            batch_loss = batch_intent_loss + batch_slot_loss
            batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).float().mean()
            batch_slot_acc = (torch.argmax(slot_logits, dim=-1) == slot_labels).float()
            batch_slot_acc = torch.sum(batch_slot_acc * slot_mask) / torch.sum(slot_mask)

            # 4. Backward the loss
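            # Dividing by batch_split keeps the accumulated gradient equivalent
            # to that of one large batch: gradients from batch_split
            # micro-batches sum up before each optimizer step below.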
            full_loss = batch_loss / self.config.batch_split
            full_loss.backward()

            intent_loss += batch_intent_loss.item()
            slot_loss += batch_slot_loss.item()
            intent_acc += batch_intent_acc.item()
            slot_acc += batch_slot_acc.item()
            step_count += 1

            # 5. Update the parameters
            curr_step = self.optimizer.curr_step()
            if (i + 1) % self.config.batch_split == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                intent_loss /= step_count
                slot_loss /= step_count
                intent_acc /= step_count
                slot_acc /= step_count

                if self.rank in [-1, 0]:
                    self.train_writer.add_scalar('loss/intent_loss', intent_loss, curr_step)
                    self.train_writer.add_scalar('loss/slot_loss', slot_loss, curr_step)
                    self.train_writer.add_scalar('acc/intent_acc', intent_acc, curr_step)
                    self.train_writer.add_scalar('acc/slot_acc', slot_acc, curr_step)
                    TQDM.set_postfix({'intent_loss': intent_loss,
                                      'intent_acc': intent_acc,
                                      'slot_loss': slot_loss,
                                      'slot_acc': slot_acc})
                # Reset the running averages on every rank, not only rank 0;
                # otherwise the accumulators on the other ranks are divided by
                # step_count repeatedly and the statistics drift.
                intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0

                # 6. Evaluate the model (by calling _eval_test) every `self.config.eval_steps` steps
                if curr_step % self.config.eval_steps == 0:
                    self._eval_test(epoch=epoch, step=curr_step)

    def _eval_test(self, epoch, step):
        self.model.eval()
        with torch.no_grad():
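            # Accumulate sums as device tensors so they can later be combined
            # across processes with all_reduce; everything is normalized by
            # `count` at the end.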
            dev_intent_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_slot_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_intent_acc = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            dev_slot_acc = torch.tensor(0.0, dtype=torch.float32, device=self.device)
            count = torch.tensor(0.0, dtype=torch.float32, device=self.device)

            for data in self.valid_dataloader:
                text = data['utt'].to(self.device, non_blocking=True)
                intent_labels = data['intent'].to(self.device, non_blocking=True)
                slot_labels = data['slot'].to(self.device, non_blocking=True)
                mask = data['mask'].to(self.device, non_blocking=True)
                token_type = data['token_type'].to(self.device, non_blocking=True)

                intent_logits, slot_logits = self.model(input_ids=text, attention_mask=mask, token_type_ids=token_type)
                
                batch_intent_loss = self.criterion(intent_logits, intent_labels)
                batch_slot_loss = self.criterion(slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
                slot_mask = (slot_labels != self.tokz.pad_token_id).float()
                batch_slot_loss = (batch_slot_loss * slot_mask.view(-1)).view(text.shape[0], -1).sum(dim=-1) / slot_mask.sum(dim=-1)
                
                dev_intent_loss += batch_intent_loss.sum()
                dev_slot_loss += batch_slot_loss.sum()

                batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).sum()
                batch_slot_acc = (torch.argmax(slot_logits, dim=-1) == slot_labels).float()
                batch_slot_acc = torch.sum(batch_slot_acc * slot_mask, dim=-1) / torch.sum(slot_mask, dim=-1)

                dev_intent_acc += batch_intent_acc
                dev_slot_acc += batch_slot_acc.sum()
                count += text.shape[0]

            # Sum the accumulated statistics across all processes before
            # normalizing, so every rank reports dataset-wide averages.
            if self.rank != -1:
                torch.distributed.all_reduce(dev_intent_loss, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_slot_loss, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_intent_acc, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(dev_slot_acc, op=torch.distributed.ReduceOp.SUM)
                torch.distributed.all_reduce(count, op=torch.distributed.ReduceOp.SUM)

            dev_intent_loss /= count
            dev_slot_loss /= count
            dev_intent_acc /= count
            dev_slot_acc /= count

            if self.rank in [-1, 0]:
                self.valid_writer.add_scalar('loss/intent_loss', dev_intent_loss, step)
                self.valid_writer.add_scalar('loss/slot_loss', dev_slot_loss, step)
                self.valid_writer.add_scalar('acc/intent_acc', dev_intent_acc, step)
                self.valid_writer.add_scalar('acc/slot_acc', dev_slot_acc, step)
                log_str = 'epoch {:>3}, step {}'.format(epoch, step)
                log_str += ', dev_intent_loss {:>4.4f}'.format(dev_intent_loss)
                log_str += ', dev_slot_loss {:>4.4f}'.format(dev_slot_loss)
                log_str += ', dev_intent_acc {:>4.4f}'.format(dev_intent_acc)
                log_str += ', dev_slot_acc {:>4.4f}'.format(dev_slot_acc)
                self.logger.info(log_str)

        self.model.train()

    def train(self, start_epoch, epochs, after_epoch_funcs=[], after_step_funcs=[]):
        for epoch in range(start_epoch + 1, epochs + 1):
            self.logger.info('Training on epoch {}'.format(epoch))
            if hasattr(self.train_sampler, 'set_epoch'):
                self.train_sampler.set_epoch(epoch)
            self._eval_train(epoch)
            for func in after_epoch_funcs:
                func(epoch, self.device)
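

# Example wiring (a minimal sketch, not part of the project API): `args` is
# assumed to expose the fields referenced above (lr, lr_warmup, bs, n_jobs,
# batch_split, eval_steps), and the model, tokenizer, and datasets come from
# the surrounding project.
#
#   trainer = Trainer(args, model, tokz, train_dataset, valid_dataset,
#                     log_dir='runs/exp1', logger=logging.getLogger(__name__),
#                     distributed=False)
#   trainer.train(start_epoch=0, epochs=args.n_epochs)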