import spaces
import gradio as gr
import torch
from PIL import Image
from pathlib import Path
import gc
import subprocess
from env import num_cns, model_trigger


import os
import sys

# Best-effort install of flash-attn at import time; FLASH_ATTENTION_SKIP_CUDA_BUILD
# tells its setup to skip compiling the CUDA extension.
# The variable is merged into the current environment instead of replacing it, so
# PATH, HOME and any CUDA-related variables stay visible to pip. pip is run through
# the current interpreter so the package is installed where this process imports
# from. Failures are not raised (no check=True), matching the original best effort.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
subprocess.run([sys.executable, "-m", "pip", "cache", "purge"])
device = "cuda" if torch.cuda.is_available() else "cpu"
# Disable autograd tracking globally for everything that runs after import.
torch.set_grad_enabled(False)


# Per-slot ControlNet state, one entry per control slot (num_cns slots).
# Mutated by index in set_control_union_mode / set_control_union_image and read by
# get_control_params, which skips a slot whose mode is -1 or whose image is None.
control_images = [None for _ in range(num_cns)]
control_modes = [-1 for _ in range(num_cns)]
control_scales = [0 for _ in range(num_cns)]


def is_repo_name(s):
    """Check whether ``s`` has the shape ``owner/name`` of a Hub repo id.

    Both halves must be non-empty and free of slashes, commas, whitespace and
    quote characters.

    Returns:
        The ``re.Match`` for the whole string, or ``None`` when it does not match.
    """
    import re
    repo_pattern = r'^[^/,\s\"\']+/[^/,\s\"\']+$'
    match = re.fullmatch(repo_pattern, s)
    return match


def is_repo_exists(repo_id):
    """Ask the Hugging Face Hub whether ``repo_id`` exists.

    If the query itself fails, the error is printed and ``True`` is returned so
    that callers treat the repo as present (fail-open).
    """
    from huggingface_hub import HfApi
    hub = HfApi()
    try:
        exists = hub.repo_exists(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}.")
        print(e)
        return True  # for safe
    return bool(exists)


from translatepy import Translator
translator = Translator()


def translate_to_en(input: str):
    """Translate ``input`` to English with the module-level translatepy translator.

    Best effort: if translation raises, the error is printed and ``input`` is
    returned unchanged.
    """
    try:
        translated = translator.translate(input, 'English')
        return str(translated)
    except Exception as e:
        print(e)
        return input


def clear_cache():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    Raises:
        Exception: wrapping any error from the cleanup calls
            (message prefixed "Cache clearing error:").
    """
    try:
        # Collect first: tensors kept alive only by reference cycles are freed by
        # gc, and only after that can empty_cache() release their cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as e:
        print(e)
        raise Exception(f"Cache clearing error: {e}") from e


def get_repo_safetensors(repo_id: str):
    """Build a dropdown update listing the ``.safetensors`` files of a Hub repo.

    Args:
        repo_id: candidate ``owner/name`` repo id.

    Returns:
        ``gr.update`` selecting the first ``.safetensors`` file, with all of them as
        choices. When the id is malformed, the repo does not exist, the listing
        fails, or there are no such files, the value and choices are cleared.
        A failed listing also shows a ``gr.Warning``.
    """
    from huggingface_hub import HfApi
    try:
        if not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = HfApi().list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        gr.Warning(f"Error: Failed to get {repo_id}'s info.")
        # Clear the value too, so a filename from a previous repo is not kept
        # selected with an empty choice list.
        return gr.update(value="", choices=[])
    safetensors = [f for f in files if f.endswith(".safetensors")]
    if not safetensors:
        return gr.update(value="", choices=[])
    return gr.update(value=safetensors[0], choices=safetensors)


def expand2square(pil_img: Image.Image, background_color: tuple=(0, 0, 0)):
    """Pad ``pil_img`` onto a centered square canvas.

    The canvas side equals the longer image edge and is filled with
    ``background_color``. An image that is already square is returned as-is
    (the same object, not a copy).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Along the longer edge the offset is 0; the shorter edge is centered.
    canvas.paste(pil_img, ((side - w) // 2, (side - h) // 2))
    return canvas


# https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
def resize_image(image, target_width, target_height, crop=True):
    """Resize a PIL image to ``target_width`` x ``target_height``.

    With ``crop=True`` the image is first cropped to a square by ``c_crop``, scaled
    (LANCZOS) so that it covers the whole target box without distortion, and then
    center-cropped to the exact target size. With ``crop=False`` it is simply
    resized (stretched) to the target size.

    Args:
        image: source PIL image.
        target_width: output width in pixels.
        target_height: output height in pixels.
        crop: keep the aspect ratio by cropping instead of stretching.

    Returns:
        The resized PIL image.
    """
    if not crop:
        return image.resize((target_width, target_height), Image.LANCZOS)

    # Imported lazily: only the cropping path needs this project helper.
    from image_datasets.canny_dataset import c_crop
    image = c_crop(image)  # Crop the image to square
    original_width, original_height = image.size

    # Resize to cover the target size without stretching.
    scale = max(target_width / original_width, target_height / original_height)
    # scale * size is mathematically >= target, but float rounding plus int()
    # truncation can land one pixel short, which would push the crop box past the
    # image edge. Clamp to the target.
    resized_width = max(target_width, int(scale * original_width))
    resized_height = max(target_height, int(scale * original_height))
    image = image.resize((resized_width, resized_height), Image.LANCZOS)

    # Center crop to match the target dimensions.
    left = (resized_width - target_width) // 2
    top = (resized_height - target_height) // 2
    return image.crop((left, top, left + target_width, top + target_height))


# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
# UI label -> ControlNet-Union control-mode index. get_control_params collects
# these indices; a value of -1 ("None") marks a slot it skips. Key order is the
# order returned by get_control_union_mode.
controlnet_union_modes = {
    "None": -1,
    #"scribble_hed": 0,  # disabled in the UI list
    "canny": 0, # supported
    "mlsd": 0, #supported  NOTE(review): shares index 0 with canny — confirm against the model card
    "tile": 1, #supported
    "depth_midas": 2, # supported
    "blur": 3, # supported
    "openpose": 4,  # supported
    "gray": 5,  # supported
    "low_quality": 6,  # supported
}


# https://github.com/pytorch/pytorch/issues/123834
def get_control_params():
    """Gather the settings of every active ControlNet slot.

    A slot is active when its mode is not -1 and it holds a control image.

    Returns:
        Three parallel lists ``(modes, images, scales)``. Each image has been
        passed through diffusers' ``load_image``.
    """
    from diffusers.utils import load_image
    modes, images, scales = [], [], []
    for slot_mode, slot_image, slot_scale in zip(control_modes, control_images, control_scales):
        if slot_mode == -1 or slot_image is None:
            continue
        modes.append(slot_mode)
        images.append(load_image(slot_image))
        scales.append(slot_scale)
    return modes, images, scales


from preprocessor import Preprocessor
# Detector-based modes: UI mode -> (Preprocessor model name, extra call kwargs).
_DETECTOR_MODES = {
    "depth_midas": ("Midas", {}),
    "openpose": ("Openpose", {"hand_and_face": True}),
    "canny": ("Canny", {}),
    "mlsd": ("MLSD", {}),
    "scribble_hed": ("HED", {}),
}
# Modes whose control image is the square-padded, resized input image itself.
_PASSTHROUGH_MODES = frozenset({"low_quality", "gray", "blur", "tile"})


def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int,
                     preprocess_resolution: int):
    """Build a ControlNet control image from ``image`` for ``control_mode``.

    The input is converted to RGB, padded to a square (expand2square) and resized
    to ``max(width, height)`` on each side. Detector modes run the matching
    ``Preprocessor`` model on it; pass-through modes use it unchanged. The result
    is finally resized (stretched) to ``width`` x ``height``.

    Args:
        image: source image.
        control_mode: a UI mode label; "None" returns ``image`` untouched.
        height: target height in pixels.
        width: target width in pixels.
        preprocess_resolution: ``detect_resolution`` passed to the detector.

    Returns:
        The control image, or ``image`` itself when ``control_mode`` is "None".

    Raises:
        ValueError: if ``control_mode`` is not a supported mode.
    """
    if control_mode == "None": return image
    if control_mode not in _PASSTHROUGH_MODES and control_mode not in _DETECTOR_MODES:
        # Previously this fell through to an UnboundLocalError on control_image.
        raise ValueError(f"Unsupported control mode: {control_mode}")

    image_resolution = max(width, height)
    image_before = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
    print("start to generate control image")
    if control_mode in _PASSTHROUGH_MODES:
        control_image = image_before
    else:
        model_name, extra_kwargs = _DETECTOR_MODES[control_mode]
        # Only detector modes need a Preprocessor instance.
        preprocessor = Preprocessor()
        preprocessor.load(model_name)
        control_image = preprocessor(
            image=image_before,
            image_resolution=image_resolution,
            detect_resolution=preprocess_resolution,
            **extra_kwargs,
        )

    # Log the real control-image size (was hardcoded to 768x768 for pass-through modes).
    image_width, image_height = control_image.size
    image_after = resize_image(control_image, width, height, False)
    ref_width, ref_height = image.size
    print(f"generate control image success: {ref_width}x{ref_height} => {image_width}x{image_height}")
    return image_after


def get_control_union_mode():
    """Return the ControlNet-Union mode labels in the mapping's insertion order."""
    return [*controlnet_union_modes]


def set_control_union_mode(i: int, mode: str, scale: str):
    """Record the control mode and scale for ControlNet slot ``i``.

    The index of ``mode`` in ``controlnet_union_modes`` is stored in
    ``control_modes[i]``. ``scale`` is stored unchanged in ``control_scales[i]``.

    Returns:
        ``True`` for any mode other than "None", otherwise
        ``gr.update(visible=True)``.
        NOTE(review): both branches read as "make visible"; confirm the intended outputs.
    """
    # Lists are mutated by index, so no `global` declaration is needed.
    # NOTE(review): unrecognised labels fall back to index 0 (same as "canny"), not -1 (off).
    control_modes[i] = controlnet_union_modes.get(mode, 0)
    control_scales[i] = scale
    if mode == "None":
        return gr.update(visible=True)
    return True


def set_control_union_image(i: int, mode: str, image: Image.Image | None, height: int, width: int, preprocess_resolution: int):
    """Preprocess ``image`` for ``mode`` and store it as slot ``i``'s control image.

    Returns:
        The stored control image, or ``None`` when ``image`` is None. In that case
        the slot is cleared as well.
    """
    if image is None:
        # Clear the slot. Otherwise the previous control image would still be
        # picked up by get_control_params after the user removes the input.
        control_images[i] = None
        return None
    control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
    return control_images[i]


def preprocess_i2i_image(image_path: str, is_preprocess: bool, height: int, width: int):
    """Optionally square-pad and resize the image file at ``image_path`` in place.

    When ``is_preprocess`` is false the path is returned untouched. Otherwise the
    image is converted to RGB, padded to a square (expand2square), resized to
    ``max(width, height)`` on each side, and written back over the same file.

    Returns:
        ``image_path``.

    Raises:
        gr.Error: if opening, processing or saving the image fails.
    """
    if not is_preprocess: return image_path
    try:
        image_resolution = max(width, height)
        # Close the source file before overwriting it with the processed image.
        with Image.open(image_path) as image:
            image_resized = resize_image(expand2square(image.convert("RGB")), image_resolution, image_resolution, False)
        image_resized.save(image_path)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    return image_path


def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
    """Fill entry ``i`` of ``lorajson`` in place and return the same list.

    The entry gets ``name`` ("" when the UI sentinel "None" is passed), ``scale``
    as a float, and ``filename``/``trigger`` as strings.
    """
    entry = lorajson[i]
    entry["name"] = "" if name == "None" else str(name)
    entry["scale"] = float(scale)
    entry["filename"] = str(filename)
    entry["trigger"] = str(trigger)
    return lorajson


def is_valid_lora(lorajson: list[dict]):
    """Return True if any entry names a LoRA (a non-empty "name" other than "None")."""
    # Build the full list first so every entry is inspected, as in a plain loop.
    selected = [
        bool("name" in entry and entry["name"] and entry["name"] != "None")
        for entry in lorajson
    ]
    return any(selected)


def get_trigger_word(lorajson: list[dict]):
    """Concatenate ", <trigger>" for each selected LoRA that has a trigger word.

    An entry counts when it has a non-empty "name" other than "None" and a
    non-empty "trigger". Returns "" when no entry qualifies.
    """
    return "".join(
        ", " + entry["trigger"]
        for entry in lorajson
        if "name" in entry and entry["name"] and entry["name"] != "None" and entry["trigger"]
    )


def get_model_trigger(model_name: str):
    """Return ", <trigger>" for ``model_name`` from ``model_trigger``, or "" if it has none."""
    if model_name not in model_trigger:
        return ""
    return ", " + model_trigger[model_name]


# https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
# https://github.com/huggingface/diffusers/issues/4919
def _lora_adapter_name(source: str, used: set[str]) -> str:
    """Derive an adapter name from a LoRA repo id/path that is unique within ``used``."""
    # Adapter names become module-dict keys, which may not contain "."; the
    # original stem is kept whenever it was already valid.
    base = Path(source).stem.replace(".", "_") or "lora"
    name = base
    suffix = 1
    # Two repos/paths can share a stem (e.g. "a/lora" and "b/lora").
    while name in used:
        name = f"{base}_{suffix}"
        suffix += 1
    used.add(name)
    return name


def fuse_loras(pipe, lorajson: list[dict]):
    """Load every selected LoRA in ``lorajson`` into ``pipe`` as a named adapter.

    Entries that are empty, not dicts, or lack a usable "name" are skipped. A name
    that looks like an existing Hub repo is loaded with its "filename" as
    ``weight_name``. Otherwise the name is treated as a local file path, and
    missing paths are reported and skipped. Adapters are only loaded here, not
    activated or fused.

    Returns:
        ``(pipe, adapter_names, adapter_weights)``. The two lists are parallel
        and empty when nothing was loaded.

    Raises:
        Exception: wrapping any loading error (message prefixed "External LoRA Error:").
    """
    try:
        if not lorajson or not isinstance(lorajson, list): return pipe, [], []
        a_list = []
        w_list = []
        used_names = set()
        for d in lorajson:
            if not d or not isinstance(d, dict): continue
            k = d.get("name")
            if not k or k == "None": continue
            if is_repo_name(k) and is_repo_exists(k):
                a_name = _lora_adapter_name(k, used_names)
                pipe.load_lora_weights(k, weight_name=d["filename"], adapter_name=a_name, low_cpu_mem_usage=True)
            elif not Path(k).exists():
                print(f"LoRA not found: {k}")
                continue
            else:
                a_name = _lora_adapter_name(k, used_names)
                pipe.load_lora_weights(k, weight_name=Path(k).name, adapter_name=a_name, low_cpu_mem_usage=True)
            a_list.append(a_name)
            w_list.append(d["scale"])
        return pipe, a_list, w_list
    except Exception as e:
        print(f"External LoRA Error: {e}")
        raise Exception(f"External LoRA Error: {e}") from e


def description_ui():
    """Render a Markdown credits note listing the Spaces this app is derived from."""
    credits_md = """
- Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
 [multimodalart/flux-lora-lab](https://huggingface.co/spaces/multimodalart/flux-lora-lab),
 [jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union),
 [DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),
 [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
"""
    gr.Markdown(credits_md)


from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
def load_prompt_enhancer():
    """Build the Flux prompt-enhancer ``text2text-generation`` pipeline.

    The seq2seq model is put in eval mode and moved to ``device``. If anything
    fails, the error is printed and ``None`` is returned.
    """
    checkpoint = "gokaygokay/Flux-Prompt-Enhance"
    try:
        tok = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).eval().to(device=device)
        return pipeline('text2text-generation', model=seq2seq, tokenizer=tok, repetition_penalty=1.5, device=device)
    except Exception as e:
        print(e)
        return None


enhancer_flux = load_prompt_enhancer()


@spaces.GPU(duration=30)
def enhance_prompt(input_prompt):
    """Translate ``input_prompt`` to English and expand it with the prompt enhancer.

    Translation is best effort (see translate_to_en). Generation is capped at
    ``max_length=256``.

    Raises:
        gr.Error: if the enhancer pipeline failed to load at import time.
    """
    if enhancer_flux is None:
        # load_prompt_enhancer returns None on failure; calling it would raise an
        # opaque "'NoneType' object is not callable".
        raise gr.Error("Prompt enhancer is not available.")
    result = enhancer_flux("enhance prompt: " + translate_to_en(input_prompt), max_length = 256)
    return result[0]['generated_text']


def save_image(image, savefile, modelname, prompt, height, width, steps, cfg, seed):
    """Save ``image`` as PNG with generation settings embedded as JSON metadata.

    The JSON is stored in the PNG text chunk "metadata". If ``savefile`` is None,
    a name ``<model>_<uuid4>.png`` is generated in the current directory.

    Returns:
        The absolute path of the written file, as a string.

    Raises:
        Exception: wrapping any save error (message includes the cause).
    """
    import uuid
    from PIL import PngImagePlugin
    import json
    try:
        model_short_name = modelname.split("/")[-1]
        if savefile is None: savefile = f"{model_short_name}_{uuid.uuid4()}.png"
        metadata = {
            "prompt": prompt,
            "Model": {"Model": model_short_name},
            "num_inference_steps": steps,
            "guidance_scale": cfg,
            "seed": seed,
            "resolution": f"{width} x {height}",
        }
        info = PngImagePlugin.PngInfo()
        info.add_text("metadata", json.dumps(metadata))
        image.save(savefile, "PNG", pnginfo=info)
        return str(Path(savefile).resolve())
    except Exception as e:
        print(f"Failed to save image file: {e}")
        # The original message ended in a bare colon and dropped the cause.
        raise Exception(f"Failed to save image file: {e}") from e


# Tag these helpers with a `zerogpu = True` attribute.
# NOTE(review): presumably read by the ZeroGPU / `spaces` tooling — confirm.
for _zerogpu_fn in (load_prompt_enhancer, fuse_loras, preprocess_image, get_control_params, clear_cache):
    _zerogpu_fn.zerogpu = True
del _zerogpu_fn