import os
import math
from decimal import Decimal

import utility

import torch
import torch.nn.utils as utils
from tqdm import tqdm

import torch.cuda.amp as amp

from torch.utils.tensorboard import SummaryWriter
import torchvision

import numpy as np

class Trainer():
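    """Driver object that owns the training and evaluation loops.

    A minimal sketch of the driving loop, assuming an EDSR-style main script
    (the names of the surrounding objects may differ in this repository):

        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.train()
            t.test()
    """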
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
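        # GradScaler is a transparent no-op when AMP is disabled.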
        self.scaler = amp.GradScaler(enabled=args.amp)
        self.writer = None
        self.recurrence = args.recurrence
        if args.recurrence > 1:
            self.writer = SummaryWriter(f"runs/{args.save}")

    def train(self):
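        """Run one training epoch.

        Uses AMP autocast/GradScaler for the forward and backward passes,
        optionally clips gradient values, and, for recurrent models
        (recurrence > 1), logs per-step losses and sample images to TensorBoard.
        """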
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP: training currently uses only the first scale index
        self.loader_train.dataset.set_scale(0)
        total = len(self.loader_train)
        buffer = [0.0] * self.recurrence  # running per-recurrence loss sums
        # torch.autograd.set_detect_anomaly(True)
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
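            # Standard AMP training step: forward under autocast, scaled
            # backward, unscale before clipping, then scaler.step/update.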
            with amp.autocast(enabled=self.args.amp):
                sr = self.model(lr, 0)
                # Recurrent models return a list of outputs; unwrap a
                # single-element list to a plain tensor.
                if isinstance(sr, list) and len(sr) == 1:
                    sr = sr[0]
                loss = self.loss(sr, hr)
            self.scaler.scale(loss).backward()
            if self.args.gclip > 0:
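                # Gradients must be unscaled before value-based clipping under AMP.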
                self.scaler.unscale_(self.optimizer)
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
            # Accumulate the per-recurrence loss components for TensorBoard.
            for i in range(self.recurrence):
                buffer[i] += self.loss.buffer[i]

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()
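        # Log the last batch of the epoch to TensorBoard for recurrent models.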
        if self.writer:
            for i in range(self.recurrence):
                grid = torchvision.utils.make_grid(sr[i])
                self.writer.add_image(f"Output{i}", grid, epoch)
                self.writer.add_scalar(f"Loss{i}", buffer[i] / total, epoch)
            self.writer.add_image("Input", torchvision.utils.make_grid(lr), epoch)
            self.writer.add_image("Target", torchvision.utils.make_grid(hr), epoch)
        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
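        """Evaluate the model on every test set and scale.

        Computes the average PSNR per (dataset, scale), optionally saves result
        images, and checkpoints the model unless running in test-only mode.
        """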
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
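        # One PSNR accumulator per (test set, scale) pair for this epoch.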
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    with amp.autocast(enabled=self.args.amp):
                        sr = self.model(lr, idx_scale)
                    # Recurrent models return a list; evaluate only the final output.
                    if isinstance(sr, list):
                        sr = sr[-1]
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )
        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')
        # torch.cuda.empty_cache()
        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
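            # is_best is judged on the first test set and first scale only.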
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
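        """Move input tensors to the target device, casting to half precision if requested."""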
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
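        """Return True when training should stop (test-only run or epoch budget reached)."""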
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs