from __future__ import annotations

import logging
import os
import random
import sys
import tempfile

import gradio as gr
import imageio
import numpy as np
import PIL.Image
import torch
import tqdm.auto
from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
                       DiffusionPipeline, PNDMPipeline, PNDMScheduler)

# Token used to download the model repos from the Hugging Face Hub.
HF_TOKEN = os.environ['HF_TOKEN']

formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(stream_handler)


class Model:
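    """Loads the anime face diffusion pipelines and generates samples."""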

    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        self.device = torch.device(device)
        self._download_all_models()

        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)
        self.rng = random.Random()

        # Load the Real-ESRGAN anime upscaler Space as a callable interface.
        self.real_esrgan = gr.Interface.load('spaces/hysts/Real-ESRGAN-anime')

    @staticmethod
    def _load_pipeline(model_name: str,
                       scheduler_type: str) -> DiffusionPipeline:
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'
        if scheduler_type == 'DDPM':
            pipeline = DDPMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
        elif scheduler_type == 'DDIM':
            pipeline = DDIMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
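            # Swap in the DDIM scheduler config stored in the repo.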
            pipeline.scheduler = DDIMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        elif scheduler_type == 'PNDM':
            pipeline = PNDMPipeline.from_pretrained(repo_id,
                                                    use_auth_token=HF_TOKEN)
            pipeline.scheduler = PNDMScheduler.from_config(
                repo_id, subfolder='scheduler', use_auth_token=HF_TOKEN)
        else:
            raise ValueError(f'Unknown scheduler type: {scheduler_type}')
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')

        if (model_name == self.model_name
                and scheduler_type == self.scheduler_type):
            logger.info('Skipping')
            logger.info('--- done ---')
            return
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = self._load_pipeline(model_name, scheduler_type)

        logger.info('--- done ---')

    def _download_all_models(self) -> None:
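        # Load every model once so the weights are cached locally.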
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self,
                 seed: int,
                 num_steps: int,
                 num_images: int = 1) -> list[PIL.Image.Image]:
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')

        torch.manual_seed(seed)
        if self.scheduler_type == 'DDPM':
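            # DDPM runs the full reverse process, so num_steps is ignored.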
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device)['sample']
        elif self.scheduler_type in ['DDIM', 'PNDM']:
            res = self.pipeline(batch_size=num_images,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample']
        else:
            raise ValueError(f'Unknown scheduler type: {self.scheduler_type}')

        logger.info('--- done ---')
        return res

    @staticmethod
    def postprocess(sample: torch.Tensor) -> np.ndarray:
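        # Map samples from [-1, 1] to uint8 images in NHWC layout.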
        res = (sample / 2 + 0.5).clamp(0, 1)
        res = (res * 255).to(torch.uint8)
        res = res.cpu().permute(0, 2, 3, 1).numpy()
        return res

    @torch.inference_mode()
    def generate_with_video(self, seed: int,
                            num_steps: int) -> tuple[PIL.Image.Image, str]:
        logger.info('--- generate_with_video ---')
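        # DDPM always runs the full 1000-step schedule; write the video at a
        # higher frame rate to keep it short.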
        if self.scheduler_type == 'DDPM':
            num_steps = 1000
            fps = 100
        else:
            fps = 10
        logger.info(f'{seed=}, {num_steps=}')

        model = self.pipeline.unet.to(self.device)
        scheduler = self.pipeline.scheduler
        scheduler.set_timesteps(num_inference_steps=num_steps)
        input_shape = (1, model.config.in_channels, model.config.sample_size,
                       model.config.sample_size)
        torch.manual_seed(seed)

        out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        writer = imageio.get_writer(out_file.name, fps=fps)
        sample = torch.randn(input_shape).to(self.device)
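        # Denoise step by step, recording one frame per step.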
        for t in tqdm.auto.tqdm(scheduler.timesteps):
            out = model(sample, t)['sample']
            sample = scheduler.step(out, t, sample)['prev_sample']
            res = self.postprocess(sample)[0]
            writer.append_data(res)
        writer.close()

        logger.info('--- done ---')
        return PIL.Image.fromarray(res), out_file.name

    def superresolve(self, image: PIL.Image.Image) -> PIL.Image.Image:
        logger.info('--- superresolve ---')

        # The upscaler Space expects a file path, so write the image to a
        # temporary PNG before calling it.
        with tempfile.NamedTemporaryFile(suffix='.png') as f:
            image.save(f.name)
            out_file = self.real_esrgan(f.name)

        logger.info('--- done ---')
        return PIL.Image.open(out_file)

    def run(self, model_name: str, scheduler_type: str, num_steps: int,
            randomize_seed: bool,
            seed: int) -> tuple[PIL.Image.Image, PIL.Image.Image, int, str]:
        self.set_pipeline(model_name, scheduler_type)
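        # Keep the step count in a range that works well with PNDM.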
        if scheduler_type == 'PNDM':
            num_steps = max(4, min(num_steps, 100))
        if randomize_seed:
            seed = self.rng.randint(0, 100000)
        res, filename = self.generate_with_video(seed, num_steps)
        superresolved = self.superresolve(res)
        return superresolved, res, seed, filename

    @staticmethod
    def to_grid(images: list[PIL.Image.Image],
                ncols: int = 2) -> PIL.Image.Image:
        images = [np.asarray(image) for image in images]
        nrows = (len(images) + ncols - 1) // ncols
        h, w = images[0].shape[:2]
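        # Pad with white tiles so the grid is completely filled.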
        if (d := nrows * ncols - len(images)) > 0:
            images += [np.full((h, w, 3), 255, dtype=np.uint8)] * d
        grid = np.asarray(images).reshape(nrows, ncols, h, w, 3).transpose(
            0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, 3)
        return PIL.Image.fromarray(grid)

    def run_simple(self) -> tuple[PIL.Image.Image, PIL.Image.Image]:
        self.set_pipeline(self.MODEL_NAMES[0], 'PNDM')
        # randint is inclusive on both ends, so cap at the uint32 max itself.
        seed = self.rng.randint(0, np.iinfo(np.uint32).max)
        images = self.generate(seed, num_steps=10, num_images=4)
        superresolved = [self.superresolve(image) for image in images]
        return self.to_grid(superresolved, 2), self.to_grid(images, 2)
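

if __name__ == '__main__':
    # Minimal usage sketch, not part of the app itself: it assumes HF_TOKEN
    # is set and that the Real-ESRGAN Space is reachable. Output file names
    # are illustrative.
    model = Model('cuda' if torch.cuda.is_available() else 'cpu')
    superresolved_grid, raw_grid = model.run_simple()
    raw_grid.save('samples.png')
    superresolved_grid.save('samples_superresolved.png')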