import os
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator"
import gradio as gr
import numpy as np
from models import make_inpainting
import utils

from PIL import Image
from transformers import pipeline
import torch
import random
import io
import base64
import json
from diffusers import (LDMSuperResolutionPipeline,
                       StableDiffusionLatentUpscalePipeline,
                       StableDiffusionUpscalePipeline)
import cv2
import onnxruntime
from split_image import split

def removeFurniture(input_img1,
            input_img2,
            positive_prompt,
            negative_prompt,
            num_of_images,
            resolution
            ):
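    """Inpaint the region of input_img1 marked by the drawing in input_img2.

    Returns a list padded with None up to 10 entries, matching the 10 gr.Image
    outputs wired to this handler in the UI below.
    """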

    print("removeFurniture")

    input_img1 = input_img1.resize((resolution, resolution))
    input_img2 = input_img2.resize((resolution, resolution))

    canvas_mask = np.array(input_img2)
    mask = utils.get_mask(canvas_mask)

    print(input_img1, mask, positive_prompt, negative_prompt)

    retList = make_inpainting(positive_prompt=positive_prompt,
                               image=input_img1,
                               mask_image=mask,
                               negative_prompt=negative_prompt,
                               num_of_images=num_of_images,
                               resolution=resolution
                               )
    # pad with None up to the 10 gr.Image outputs expected by the UI
    while len(retList) < 10:
        retList.append(None)

    return retList

def imageToString(img):
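    """Encode a PIL image as PNG and return the raw bytes (despite the name, not a str)."""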

    output = io.BytesIO()
    img.save(output, format="png")
    return output.getvalue()

def segmentation(img):
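    """Run semantic segmentation; returns a JSON string with each mask base64-encoded."""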
    print("segmentation")

    # semantic_segmentation = pipeline("image-segmentation", "nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
    pipe = pipeline("image-segmentation", "facebook/maskformer-swin-large-ade")    
    results = pipe(img)
    for p in results:
        p['mask'] = utils.image_to_byte_array(p['mask'])
        p['mask'] = base64.b64encode(p['mask']).decode("utf-8")
    #print(results)
    return json.dumps(results)
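
# Hypothetical helper (not wired into the app): a sketch of how a client could
# decode the JSON produced by segmentation() back into PIL mask images. The
# "label"/"mask" keys follow the transformers image-segmentation pipeline output.
def decode_segmentation_result(results_json):
    masks = {}
    for entry in json.loads(results_json):
        mask_bytes = base64.b64decode(entry["mask"])
        masks[entry["label"]] = Image.open(io.BytesIO(mask_bytes))
    return masks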
    




def upscale1(image, prompt):
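    """Upscale 4x with stabilityai/stable-diffusion-x4-upscaler, guided by the prompt."""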
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("upscale1", device, image, prompt)    
    
    # image.thumbnail((512, 512))
    # print("resize",image)

    torch.backends.cuda.matmul.allow_tf32 = True

    pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler",                                                           
                                                          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                                                          use_safetensors=True)
    # pipe = StableDiffusionLatentUpscalePipeline.from_pretrained("stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16)
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    try:
        pipe.enable_xformers_memory_efficient_attention()
        # Workaround: the VAE does not accept the attention shape used by Flash Attention
        pipe.vae.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers is optional; fall back to the default attention implementation
        pass

    ret = pipe(prompt=prompt, 
                   image=image,
                   num_inference_steps=10,
                   guidance_scale=0)
    print("ret",ret)
    upscaled_image = ret.images[0]
    print("up",upscaled_image)

    return upscaled_image

def upscale2(image, prompt):
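    """Upscale 4x with CompVis/ldm-super-resolution-4x-openimages (the prompt is unused)."""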
    print("upscale2",image,prompt)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("device",device)

    pipe = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages",
                                                      torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    try:
        # attention_op is omitted here: xformers is never imported in this module,
        # so referencing xformers.ops would raise a NameError
        pipe.enable_xformers_memory_efficient_attention()
        # Workaround: the VAE does not accept the attention shape used by Flash Attention
        pipe.vae.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    upscaled_image = pipe(image, num_inference_steps=10, eta=1).images[0]
    return upscaled_image

def convert_pil_to_cv2(image):
    # pil_image = image.convert("RGB")
    open_cv_image = np.array(image)
    # RGB to BGR
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image
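
# Hypothetical inverse of convert_pil_to_cv2 (not used by the app): converts a
# BGR NumPy array, such as the one returned by upscale3(), back to a PIL image.
def convert_cv2_to_pil(open_cv_image):
    # BGR to RGB (copy to get a contiguous array for PIL)
    return Image.fromarray(open_cv_image[:, :, ::-1].copy())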

def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
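    """Run a single-input ONNX model over img_array and return its first output."""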
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(model_path, options)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    ort_outs = ort_session.run(None, ort_inputs)

    return ort_outs[0]

def post_process(img: np.ndarray) -> np.ndarray:
    # 1, C, H, W -> C, H, W
    img = np.squeeze(img)
    # C, H, W -> H, W, C
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img

def pre_process(img: np.ndarray) -> np.ndarray:
    # H, W, C -> C, H, W
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img

def upscale3(image):
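    """Upscale 4x with the local ONNX model at up_models/modelx4.ort; returns a BGR array."""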
    print("upscale3",image)

    model_path = f"up_models/modelx4.ort"
    img = convert_pil_to_cv2(image)
    
    # if img.ndim == 2:
    #     print("upscale3","img.ndim == 2")
    #     img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # if img.shape[2] == 4:
    #     print("upscale3","img.shape[2] == 4")
    #     alpha = img[:, :, 3]  # GRAY
    #     alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
    #     alpha_output = post_process(inference(model_path, pre_process(alpha)))  # BGR
    #     alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY

    #     img = img[:, :, 0:3]  # BGR
    #     image_output = post_process(inference(model_path, pre_process(img)))  # BGR
    #     image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
    #     image_output[:, :, 3] = alpha_output

    # print("upscale3","img.shape[2] == 3")
    image_output = post_process(inference(model_path, pre_process(img)))  # BGR

    return image_output



def split_image(im, rows, cols, should_square, should_quiet=False):
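    """Split im into rows*cols tiles, optionally padding the image to a square
    on an auto-detected background color first. Tiles are returned in row-major
    order as PIL images; nothing is written to disk.
    """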
    im_width, im_height = im.size
    row_width = int(im_width / cols)
    row_height = int(im_height / rows)
    name = "image"
    ext = ".png"
    name = os.path.basename(name)
    images = []
    if should_square:
        min_dimension = min(im_width, im_height)
        max_dimension = max(im_width, im_height)
        if not should_quiet:
            print("Resizing image to a square...")
            print("Determining background color...")
        bg_color = split.determine_bg_color(im)
        if not should_quiet:
            print("Background color is... " + str(bg_color))
        im_r = Image.new("RGBA" if ext == "png" else "RGB",
                         (max_dimension, max_dimension), bg_color)
        offset = int((max_dimension - min_dimension) / 2)
        if im_width > im_height:
            im_r.paste(im, (0, offset))
        else:
            im_r.paste(im, (offset, 0))
        im = im_r
        row_width = int(max_dimension / cols)
        row_height = int(max_dimension / rows)
    n = 0
    for i in range(0, rows):
        for j in range(0, cols):
            box = (j * row_width, i * row_height, j * row_width +
                   row_width, i * row_height + row_height)
            outp = im.crop(box)
            if not should_quiet:
                print("Cropping image tile: " + name + "_" + str(n) + ext)
            images.append(outp)
            n += 1
    return images

def upscale_image(img, rows, up_factor, cols, seed, prompt, negative_prompt, xformers, cpu_offload, attention_slicing, enable_custom_sliders=False, guidance=7, iterations=50):
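    """Tile-based upscaling: pad the image to a square, split it into rows*cols
    tiles, upscale each tile with a Stable Diffusion upscaler (2x latent or 4x),
    stitch the tiles back together, then center-crop to the original aspect
    ratio scaled by the upscale factor.
    """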
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    if up_factor == 2:
        model_id = "stabilityai/sd-x2-latent-upscaler"
        try:
            pipeline = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=dtype)
        except Exception:
            pipeline = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=dtype, local_files_only=True)
    elif up_factor == 4:
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        try:
            pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=dtype)
        except Exception:
            pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=dtype, local_files_only=True)
    else:
        raise ValueError("up_factor must be 2 or 4")

    pipeline = pipeline.to(device)
    if xformers:
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; keep the default attention if it is unavailable
            pass
    else:
        pipeline.disable_xformers_memory_efficient_attention()
    if cpu_offload:
        try:
            pipeline.enable_sequential_cpu_offload()
        except Exception:
            pass
    if attention_slicing:
        pipeline.enable_attention_slicing()
    else:
        pipeline.disable_attention_slicing()
    img = Image.fromarray(img)
    # load model and scheduler
    if seed==-1:
        generator = torch.manual_seed(random.randint(0, 9999999))
    else:
        generator = torch.manual_seed(seed)
    
    original_width, original_height = img.size
    max_dimension = max(original_width, original_height)
    tiles = split_image(img, rows, cols, True, False)
    ups_tiles = []
    for tile in tiles:
        if enable_custom_sliders:
            ups_tile = pipeline(prompt=prompt, negative_prompt=negative_prompt, guidance_scale=guidance,
                                num_inference_steps=iterations, image=tile.convert("RGB"), generator=generator).images[0]
        else:
            ups_tile = pipeline(prompt=prompt, negative_prompt=negative_prompt,
                                image=tile.convert("RGB"), generator=generator).images[0]
        ups_tiles.append(ups_tile)
        
    # Determine the size of the merged upscaled image from the first tile
    side = ups_tiles[0].width
    tsize = tiles[0].width
    ups_times = side / tsize  # effective per-tile upscale factor
    new_size = (max_dimension * ups_times, max_dimension * ups_times)
    total_width = cols * side
    total_height = rows * side

    # Create a blank image with the calculated size
    merged_image = Image.new("RGB", (total_width, total_height))

    # Paste each upscaled tile into the blank image
    current_width = 0
    current_height = 0
    maximum_width = cols*side
    for ups_tile in ups_tiles:
        merged_image.paste(ups_tile, (current_width, current_height))
        current_width += ups_tile.width
        if current_width>=maximum_width:
            current_width = 0
            current_height = current_height+side

    # Using the center of the image as pivot, crop back to the original
    # aspect ratio scaled by the upscale factor
    crop_left = int((new_size[0] - original_width * ups_times) // 2)
    crop_upper = int((new_size[1] - original_height * ups_times) // 2)
    crop_right = int(crop_left + original_width * ups_times)
    crop_lower = int(crop_upper + original_height * ups_times)
    final_img = merged_image.crop((crop_left, crop_upper, crop_right, crop_lower))

    # The result keeps the original aspect ratio, with no content lost at the edges
    return final_img

    
def upscale(image, prompt, negative_prompt, rows, up_factor, guidance, iterations, xformers_input, cpu_offload_input, attention_slicing_input):
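    """Gradio handler: delegates to upscale_image() with a square tile grid and a random seed."""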
    print("upscale", prompt, negative_prompt, rows, up_factor, guidance, iterations, xformers_input, cpu_offload_input, attention_slicing_input)   
    return upscale_image(img=image, 
                         rows=rows,cols=rows,
                         up_factor=up_factor,
                         seed=-1,
                         prompt=prompt,
                         negative_prompt=negative_prompt, 
                         enable_custom_sliders=True,
                         xformers=xformers_input, 
                         cpu_offload=cpu_offload_input, 
                         attention_slicing=attention_slicing_input,
                         guidance=guidance,
                         iterations=iterations)




with gr.Blocks() as app:    
    gr.HTML(
        f"""
            <div>Running on <b>{"GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"}</b></div>
        """
    )

    with gr.Row():

        with gr.Column():
            gr.Button("FurnituRemove").click(removeFurniture, 
                                        inputs=[gr.Image(label="img", type="pil"),
                                                gr.Image(label="mask", type="pil"),
                                                gr.Textbox(label="positive_prompt",value="empty room"),
                                                gr.Textbox(label="negative_prompt",value=""),
                                                gr.Number(label="num_of_images",value=2),
                                                gr.Number(label="resolution",value=512)
                                                ], 
                                        outputs=[gr.Image() for _ in range(10)])
        
        with gr.Column():  
            gr.Button("Segmentation").click(segmentation, inputs=gr.Image(type="pil"), outputs=gr.JSON())

        with gr.Column():
            gr.Button("Upscale").click(
                upscale, 
                inputs=[
                    gr.Image(label="Source Image to upscale"),
                    gr.Textbox(label="prompt",value="empty room"),
                    gr.Textbox(label="negative prompt",value="jpeg artifacts, lowres, bad quality, watermark, text"),
                    gr.Number(value=2, label="Tile grid dimension amount (number of rows and columns) - X by X "),
                    gr.Slider(2, 4, 2, step=2, label='Upscale 2 or 4'),
                    gr.Slider(2, 15, 7, step=1, label='Guidance Scale: How much the AI influences the Upscaling.'),
                    gr.Slider(2, 100, 10, step=1, label='Number of Iterations'),
                    gr.Checkbox(value=True,label="Enable Xformers memory efficient attention"),                    
                    gr.Checkbox(value=True,label="Enable sequential CPU offload"),
                    gr.Checkbox(value=True,label="Enable attention slicing")
                    ], 
                outputs=gr.Image())


# app.queue()
app.launch()
