#!/usr/bin/env python3
# coding=utf-8

from utility.loading_bar import LoadingBar
import time
import torch


class Log:
    """Console and (optionally) Weights & Biases logger for a training run.

    Typical use: call :meth:`train` (or :meth:`eval`) once at the start of each
    pass over the data, then call the instance itself once per batch.
    :meth:`train` or :meth:`eval` must be called before the first per-batch
    call, because they set up the state the step methods read.  Per-batch loss
    values are summed per key; the training pass reports every ``log_each``
    steps, the evaluation pass reports once when it is flushed.
    """

    def __init__(self, dataset, model, optimizer, args, directory, log_each: int, initial_epoch=-1, log_wandb=True):
        """Store the run objects and initialise the logging state.

        Args:
            dataset: object exposing ``state_dict()``; stored in checkpoints.
            model: object exposing ``state_dict()``; stored in checkpoints.
            optimizer: object exposing ``state_dict()``; stored in checkpoints.
            args: run configuration; this class reads ``args.epochs``,
                ``args.save_checkpoints`` and calls ``args.state_dict()``.
            directory: output directory used to build result/checkpoint paths.
            log_each: number of training steps between progress reports.
            initial_epoch: epoch counter value before the first :meth:`train`
                call (which increments it).
            log_wandb: when true, ``wandb`` is imported and used for logging.
        """
        self.dataset = dataset
        self.model = model
        self.args = args
        self.optimizer = optimizer

        self.loading_bar = LoadingBar(length=27)
        self.best_f1_score = 0.0
        self.log_each = log_each
        self.epoch = initial_epoch
        self.log_wandb = log_wandb
        if self.log_wandb:
            # Imported lazily, and bound as a module-level name, so that wandb
            # is only required when W&B logging is actually requested.
            global wandb
            import wandb

        self.directory = directory
        # The doubled braces leave literal "{0}_{1}" placeholders in the paths.
        self.evaluation_results = f"{directory}/results_{{0}}_{{1}}.json"
        self.full_evaluation_results = f"{directory}/full_results_{{0}}_{{1}}.json"
        self.best_full_evaluation_results = f"{directory}/best_full_results_{{0}}_{{1}}.json"
        self.result_history = {epoch: {} for epoch in range(args.epochs)}

        self.best_checkpoint_filename = f"{self.directory}/best_checkpoint.h5"
        self.last_checkpoint_filename = f"{self.directory}/last_checkpoint.h5"

        self.step = 0
        self.total_batch_size = 0
        self.flushed = True

    def train(self, len_dataset: int) -> None:
        """Start a training pass: flush pending output and advance the epoch.

        The table header is printed when the epoch counter reaches 0.
        """
        self.flush()

        self.epoch += 1
        if self.epoch == 0:
            self._print_header()

        self.is_train = True
        self._reset(len_dataset)

    def eval(self, len_dataset: int) -> None:
        """Start an evaluation pass: flush pending output and reset counters."""
        self.flush()
        self.is_train = False
        self._reset(len_dataset)

    def __call__(self, batch_size, losses, grad_norm: float = None, learning_rates: float = None,) -> None:
        """Record one batch.

        Args:
            batch_size: number of examples in the batch (used for the training
                progress bar).
            losses: mapping from loss name to this batch's value; it is not
                modified.
            grad_norm: logged to W&B as ``"gradient norm"`` during training.
            learning_rates: indexable collection; elements ``[0]``, ``[-2]``
                and ``[-1]`` are logged to W&B during training.  Ignored in
                evaluation mode.
        """
        if self.is_train:
            self._train_step(batch_size, losses, grad_norm, learning_rates)
        else:
            self._eval_step(batch_size, losses)

        self.flushed = False

    def flush(self) -> None:
        """Finish the current pass; a no-op when nothing was recorded since the last flush.

        After a training pass this rewrites the console line with the epoch
        and elapsed time.  After an evaluation pass it logs the per-step mean
        of every accumulated loss to W&B (if enabled) and clears the sums.
        """
        if self.flushed:
            return
        self.flushed = True

        if self.is_train:
            print(f"\r┃{self.epoch:12d}  ┃{self._time():>12}  │", end="", flush=True)
        else:
            if self.losses is not None and self.log_wandb:
                dictionary = {f"validation/{key}": value / self.step for key, value in self.losses.items()}
                dictionary["epoch"] = self.epoch
                wandb.log(dictionary)

            self.losses = None
            # NOTE: the "last" checkpoint is not written here; that save is
            # currently disabled.

    def log_evaluation(self, scores, mode, epoch):
        """Log evaluation scores and keep track of the best validation F1.

        Args:
            scores: mapping of metric name to value; must contain the key
                ``"sentiment_tuple/f1"``.
            mode: name of the evaluated split; used as the W&B key prefix.
                Only ``"validation"`` can update the best score.
            epoch: epoch number attached to the W&B log entry.
        """
        f1_score = scores["sentiment_tuple/f1"]
        if self.log_wandb:
            scores = {f"{mode}/{k}": v for k, v in scores.items()}
            wandb.log({
                "epoch": epoch,
                **scores
            })

        if mode == "validation" and f1_score > self.best_f1_score:
            if self.log_wandb:
                wandb.run.summary["best sentiment tuple f1 score"] = f1_score
            # Tracking the best score and saving the best checkpoint must not
            # depend on whether W&B logging is enabled.
            self.best_f1_score = f1_score
            self._save_model(save_as_best=True, f1_score=f1_score)

    def _save_model(self, save_as_best: bool, f1_score: float):
        """Write a checkpoint, unless ``args.save_checkpoints`` is falsy.

        The checkpoint goes to the "best" or the "last" checkpoint path
        depending on ``save_as_best`` and is also handed to ``wandb.save``
        when W&B logging is enabled.
        """
        if not self.args.save_checkpoints:
            return

        state = {
            "epoch": self.epoch,
            "dataset": self.dataset.state_dict(),
            "f1_score": f1_score,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "args": self.args.state_dict(),
        }

        filename = self.best_checkpoint_filename if save_as_best else self.last_checkpoint_filename

        torch.save(state, filename)
        if self.log_wandb:
            wandb.save(filename)

    def _accumulate(self, losses) -> None:
        """Add one batch's loss values to the running per-key sums."""
        if self.losses is None:
            # Copy, so the running sums never alias the caller's dict.
            self.losses = dict(losses)
            return

        for key, value in losses.items():
            if key in self.losses:
                # Plain addition (not +=) so a mutable value object that came
                # from an earlier batch is never modified in place.
                self.losses[key] = self.losses[key] + value
            else:
                self.losses[key] = value

    def _train_step(self, batch_size, losses, grad_norm: float, learning_rates) -> None:
        """Accumulate a training batch and report every ``log_each`` steps.

        A report redraws the console progress line, logs the mean of the
        accumulated losses to W&B (if enabled) and then clears the sums.
        """
        self.total_batch_size += batch_size
        self.step += 1
        self._accumulate(losses)

        if self.step % self.log_each == 0:
            progress = self.total_batch_size / self.len_dataset
            print(f"\r┃{self.epoch:12d}  │{self._time():>12}  {self.loading_bar(progress)}", end="", flush=True)

            if self.log_wandb:
                # The sums are cleared after every report, so they always span
                # exactly `log_each` steps here.  Keys starting with "weight/"
                # are logged without the "train/" prefix.
                dictionary = {f"train/{key}" if not key.startswith("weight/") else key: value / self.log_each for key, value in self.losses.items()}
                dictionary["epoch"] = self.epoch
                if learning_rates is not None:
                    # `learning_rates` defaults to None in __call__; skip the
                    # entries instead of failing on the index lookups.
                    dictionary["learning_rate/encoder"] = learning_rates[0]
                    dictionary["learning_rate/decoder"] = learning_rates[-2]
                    dictionary["learning_rate/grad_norm"] = learning_rates[-1]
                dictionary["gradient norm"] = grad_norm

                wandb.log(dictionary)

            self.losses = None

    def _eval_step(self, batch_size, losses) -> None:
        """Accumulate an evaluation batch; reporting happens in :meth:`flush`."""
        self.step += 1
        self._accumulate(losses)

    def _reset(self, len_dataset: int) -> None:
        """Restart the timer and clear the per-pass counters and loss sums."""
        self.start_time = time.time()
        self.step = 0
        self.total_batch_size = 0
        self.len_dataset = len_dataset
        self.losses = None

    def _time(self) -> str:
        """Return the time since the last reset formatted as ``"MM:SS min"``."""
        time_seconds = int(time.time() - self.start_time)
        return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"

    def _print_header(self) -> None:
        """Print the header rows of the console progress table."""
        print("┏━━━━━━━━━━━━━━┳━━━╸S╺╸E╺╸M╺╸A╺╸N╺╸T╺╸I╺╸S╺╸K╺━━━━━━━━━━━━━━┓")
        print("┃              ┃              ╷                             ┃")
        print("┃       epoch  ┃     elapsed  │               progress bar  ┃")
        print("┠──────────────╂──────────────┼─────────────────────────────┨")