import math
import os
from pathlib import Path
from typing import Union

import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer

import internals.util.image as ImageUtil
from internals.util.commons import download_image
from internals.util.config import get_root_dir
from models.ultrasharp.model import Ultrasharp


class Upscaler:
    """Upscale images with Real-ESRGAN / 4x-UltraSharp, optionally restoring faces with GFPGAN.

    Model weight files are located in (or downloaded into) ``~/.cache/realesrgan``
    by :meth:`load`. The public ``upscale*`` methods call :meth:`load` themselves,
    so an explicit call beforehand is optional (it is a no-op once loaded).
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    __model_gfpgan_url = (
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    )
    __model_4x_ultrasharp_url = (
        "https://comic-assets.s3.ap-south-1.amazonaws.com/models/4x-UltraSharp.pth"
    )

    # Flipped to True on the instance once every weight file path is resolved.
    __loaded = False

    def load(self):
        """Resolve local paths for all model weights, downloading any that are missing.

        Files live in ``~/.cache/realesrgan`` (created if needed). Calling this
        again after a successful load returns immediately.
        """
        if self.__loaded:
            return

        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )
        self.__model_path_gfpgan = self.__preload_model(
            self.__model_gfpgan_url, download_dir
        )
        self.__model_path_4x_ultrasharp = self.__preload_model(
            self.__model_4x_ultrasharp_url, download_dir
        )
        self.__loaded = True

    def upscale(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: int,
    ) -> bytes:
        """Upscale a general (photographic) image.

        Args:
            image: A PIL image, or a string handed to ``download_image``
                (presumably a URL — confirm against callers).
            width: Currently unused by the implementation.
            height: Currently unused by the implementation.
            face_enhance: If True, restore faces with GFPGAN, using the upsampler
                for the background.
            resize_dimension: Drives the scale factor:
                ``max(floor(resize_dimension / shorter_side), 2)``.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        # Model paths only exist after load(); make it safe to skip an explicit call.
        self.load()
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path,
            model,
        )

    def upscale_anime(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: int,
    ) -> bytes:
        """Upscale an anime-style image with the RealESRGAN x4plus anime 6B model.

        Arguments and return value are the same as :meth:`upscale`.
        """
        # Model paths only exist after load(); make it safe to skip an explicit call.
        self.load()
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path_anime,
            model,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of ``url``'s file in ``download_dir``, downloading it if absent."""
        local_path = download_dir / url.split("/")[-1]
        if local_path.exists():
            return str(local_path)
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    def __internal_upscale(
        self,
        image,
        resize_dimension: int,
        face_enhance: bool,
        width: int,
        height: int,
        model_path: str,
        model,
    ) -> bytes:
        """Run the shared upscaling pipeline and return PNG bytes.

        ``width`` and ``height`` are accepted for interface symmetry but unused.
        A computed scale of exactly 4 selects the 4x-UltraSharp model instead of
        ``model``/``model_path``.

        Raises:
            RuntimeError: If the intermediate image cannot be read back or the
                result cannot be PNG-encoded.
        """
        if isinstance(image, str):
            image = download_image(image)

        # Cap the input's longest side at 1024 px before upscaling.
        w, h = image.size
        if max(w, h) > 1024:
            image = ImageUtil.resize_image(image, dimension=1024)

        cache_dir = Path.home() / ".cache"
        in_path = str(cache_dir / "input_upscale.png")
        # Round-trip through a PNG file so OpenCV loads the pixels in its own
        # channel order (IMREAD_UNCHANGED keeps any alpha channel).
        image.save(in_path)
        input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
        if input_image is None:
            raise RuntimeError(f"Failed to read intermediate image {in_path}")
        dimension = min(input_image.shape[0], input_image.shape[1])
        scale = max(math.floor(resize_dimension / dimension), 2)

        # The upsamplers run with ~/.cache as the working directory (presumably so
        # relative weight paths used by the libraries resolve there). Always switch
        # back, even if model construction or inference raises; otherwise the whole
        # process would be left in the wrong directory.
        os.chdir(str(cache_dir))
        try:
            if scale == 4:
                print("Using 4x-Ultrasharp")
                upsampler = Ultrasharp(self.__model_path_4x_ultrasharp)
            else:
                print("Using RealESRGANer")
                upsampler = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    half=False,
                    gpu_id="0",
                    tile=320,
                    tile_pad=10,
                    pre_pad=0,
                )

            if face_enhance:
                # Build GFPGAN only when it is actually used; constructing it
                # presumably loads its weights, which is wasted work otherwise.
                face_enhancer = GFPGANer(
                    model_path=self.__model_path_gfpgan,
                    upscale=scale,
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=upsampler,
                )
                _, _, output = face_enhancer.enhance(
                    input_image,
                    has_aligned=False,
                    only_center_face=False,
                    paste_back=True,
                )
            else:
                output, _ = upsampler.enhance(input_image, outscale=scale)
        finally:
            os.chdir(get_root_dir())

        # NOTE(review): debug copy written into the root dir on every call; kept for
        # backward compatibility, but concurrent calls overwrite each other's file.
        cv2.imwrite("out.png", output)
        ok, encoded = cv2.imencode(".png", output)
        if not ok:
            raise RuntimeError("Failed to encode upscaled image as PNG")
        return encoded.tobytes()