import gradio as gr
import numpy as np
import PIL.Image
from PIL import Image, ImageOps
import random
#from diffusers import DiffusionPipeline
#from diffusers import StableDiffusionXLPipeline
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux import PidiNetDetector, HEDdetector
from diffusers.utils import load_image
import cv2
import torch
import spaces

def nms(x, t, s):
    """Thin an edge map into a binary line image.

    The input is converted to float32 and Gaussian-blurred with sigma ``s``.
    A pixel is kept when it equals the maximum of its 3-pixel line
    neighbourhood along at least one of four directions (horizontal,
    vertical, and the two diagonals). Kept pixels whose blurred value is
    greater than ``t`` become 255 in the result; everything else is 0.

    Args:
        x: Array with an ``astype`` method (a NumPy image array).
        t: Threshold applied to the blurred, suppressed values.
        s: Sigma passed to ``cv2.GaussianBlur``.

    Returns:
        A ``uint8`` array shaped like the blurred image, holding only 0/255.
    """
    smoothed = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # One 3x3 structuring element per line direction.
    line_kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),  # horizontal
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),  # vertical
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),  # diagonal \
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),  # diagonal /
    )

    peaks = np.zeros_like(smoothed)
    for kernel in line_kernels:
        # Dilation yields the neighbourhood maximum, so equality marks pixels
        # that are the maximum along this direction.
        is_peak = cv2.dilate(smoothed, kernel=kernel) == smoothed
        np.putmask(peaks, is_peak, smoothed)

    return np.where(peaks > t, np.uint8(255), np.uint8(0))

def HWC3(x):
    """Return ``x`` as a 3-channel ``uint8`` image in height-width-channel layout.

    * 2-D input is treated as a single-channel image.
    * A single channel is replicated three times.
    * A 3-channel image is returned unchanged (the same object).
    * A 4-channel image is alpha-composited onto a white background using
      the fourth channel as opacity.

    Asserts that ``x`` is ``uint8`` and has 1, 3 or 4 channels.
    """
    assert x.dtype == np.uint8
    img = x[:, :, None] if x.ndim == 2 else x
    assert img.ndim == 3
    channels = img.shape[2]
    assert channels in (1, 3, 4)

    if channels == 1:
        return np.repeat(img, 3, axis=2)
    if channels == 4:
        rgb = img[:, :, 0:3].astype(np.float32)
        opacity = img[:, :, 3:4].astype(np.float32) / 255.0
        # Blend over white: fully transparent pixels become 255.
        blended = rgb * opacity + 255.0 * (1.0 - opacity)
        return blended.clip(0, 255).astype(np.uint8)
    return img
        

# Target device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")


# ControlNet weights (scribble model), loaded in float16.
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-scribble-sdxl-1.0",
    #"2vXpSwA7/test_controlnet2/CN-anytest_v4-marged_am_dim256.safetensors"
    
    torch_dtype=torch.float16
)

# Separate VAE checkpoint, loaded in float16 and handed to the pipeline below.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# SDXL ControlNet pipeline built from the base checkpoint plus the ControlNet
# and VAE loaded above.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    #"sd-community/sdxl-flash",
    #"yodayo-ai/kivotos-xl-2.0", 
    "yodayo-ai/holodayo-xl-2.1", 
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    # scheduler=eulera_scheduler,
)
# Swap the pipeline's scheduler for Euler Ancestral, reusing the configuration
# of the scheduler the checkpoint shipped with.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

pipe.to(device)


# Upper bound for seeds: used by the seed slider and by random seed selection.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width/height sliders in the UI.
MAX_IMAGE_SIZE = 1216

# Disabled alternative setup: a plain StableDiffusionXLPipeline without ControlNet.
#pipe = StableDiffusionXLPipeline.from_pretrained(
#    #"yodayo-ai/kivotos-xl-2.0", 
#    "yodayo-ai/holodayo-xl-2.1", 
#    torch_dtype=torch.float16, 
#    use_safetensors=True,
#    custom_pipeline="lpw_stable_diffusion_xl",
#    add_watermarker=False,
#    variant="fp16"
#)
#pipe.to('cuda')

# NOTE(review): these two module-level strings are rebound to Gradio text
# components further down in this file (the UI section assigns `prompt` and
# `negative_prompt` again), and `infer` takes both as parameters, so nothing
# visible in this file reads these values.
prompt = "1girl, solo, upper body, v, smile, looking at viewer, outdoors, night, masterpiece, best quality, very aesthetic, absurdres"
negative_prompt = "nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"

# NOTE: this rebinds the module-level name `nms`; an earlier definition in this
# file implements the same algorithm.
def nms(x, t, s):
    """Binarise an edge map after directional non-maximum suppression.

    Steps:
      1. Blur ``x`` (as float32) with a Gaussian of sigma ``s``.
      2. Keep a pixel if it equals the maximum of its 3-pixel line
         neighbourhood in any of four directions; zero out the rest.
      3. Return a ``uint8`` array that is 255 where the result of step 2 is
         greater than ``t`` and 0 elsewhere.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    # 3x3 line-shaped structuring elements: horizontal, vertical, two diagonals.
    direction_rows = [
        [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
        [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
    ]

    # A pixel survives when it is the directional maximum for at least one kernel.
    directional_max = [
        cv2.dilate(blurred, kernel=np.array(rows, dtype=np.uint8)) == blurred
        for rows in direction_rows
    ]
    survives = np.logical_or.reduce(directional_max)
    suppressed = np.where(survives, blurred, np.float32(0))

    binary = np.zeros(suppressed.shape, dtype=np.uint8)
    binary[suppressed > t] = 255
    return binary
    
@spaces.GPU
def infer(image: dict, prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps) -> PIL.Image.Image:
    """Generate an image from the user's sketch with the ControlNet pipeline.

    Args:
        image: Value of the sketch editor component; the flattened drawing is
            read from its ``"composite"`` entry (a PIL image).
        prompt: Positive prompt; fixed quality tags are appended to it.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; replaced when ``randomize_seed``
            is truthy.
        randomize_seed: When truthy, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Accepted for compatibility with the UI wiring but not used.
        height: Accepted for compatibility with the UI wiring but not used.
            The output size is derived from the sketch instead, keeping its
            aspect ratio at roughly 1024 * 1024 pixels.
        guidance_scale: Classifier-free guidance scale for the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no sketch is available.
    """
    sketch = None if image is None else image["composite"]
    if sketch is None:
        raise gr.Error("Please draw or upload a sketch first.")

    # Rescale to about one megapixel while keeping the aspect ratio. SDXL
    # denoises on a latent grid at 1/8 of the pixel resolution, so both sides
    # are snapped down to a multiple of 8.
    src_width, src_height = sketch.size
    ratio = np.sqrt(1024. * 1024. / (src_width * src_height))
    out_width = max(8, int(src_width * ratio) // 8 * 8)
    out_height = max(8, int(src_height * ratio) // 8 * 8)
    sketch = sketch.resize((out_width, out_height))

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed requires an int.
    seed = int(seed)

    # Simulate a hand-drawn scribble: thin the lines, blur them, then
    # re-binarise with a randomly chosen threshold. A higher threshold gives
    # thinner lines.
    control = np.array(sketch)
    control = nms(control, 127, 3)
    control = cv2.GaussianBlur(control, (0, 0), 3)

    threshold = int(round(random.uniform(0.01, 0.10), 2) * 255)
    control[control > threshold] = 255
    control[control < 255] = 0
    control_image = Image.fromarray(control)

    generator = torch.Generator().manual_seed(seed)

    output_image = pipe(
        prompt=prompt + ", masterpiece, best quality, very aesthetic, absurdres",
        negative_prompt=negative_prompt,
        # The conditioning image must be passed explicitly; without it the
        # ControlNet pipeline has nothing to condition on.
        image=control_image,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=out_width,
        height=out_height,
        generator=generator,
    ).images[0]

    return output_image

# Page CSS: centres the element with id "col-container" and caps its width.
css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# UI definition. Components created inside the `with` blocks are bound to
# module-level names and wired to `infer` at the bottom.
with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        # Page heading (the f-string has no placeholders).
        gr.Markdown(f"""
        # Text-to-Image Demo
        using [Holodayo XL 2.1](https://huggingface.co/yodayo-ai/holodayo-xl-2.1)
        """)
        
        with gr.Row():
            
            # Rebinds the module-level name `prompt` to the prompt text box.
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            
            run_button = gr.Button("Run", scale=0)
            
        # Sketch input, configured for PIL values in "L" image mode; `infer`
        # reads the drawing from the value's "composite" entry.
        image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            
            # Rebinds the module-level name `negative_prompt` to this text box.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                #visible=False,
                value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
            )
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            # When checked, `infer` draws a random seed instead of using the slider.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            # NOTE(review): `infer` receives these two slider values but
            # overwrites `width`/`height` with the sketch's own size before
            # using them, so the sliders currently have no effect.
            with gr.Row():
                
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=832,
                )
                
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1216,
                )
            
            with gr.Row():
                
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=7,
                )
                
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=28,
                    step=1,
                    value=28,
                )

    # Clicking "Run" calls `infer`; the input order must match its parameter order.
    run_button.click(
        fn = infer,
        inputs = [image,prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs = [result]
    )

# Enable request queueing and start the app.
demo.queue().launch()