import gradio as gr
from models import models
from PIL import Image
import requests
import uuid
import io 
import base64
import cv2
import numpy
from transforms import RGBTransform
import random

#import torch
#from diffusers import AutoPipelineForImage2Image
#from diffusers.utils import make_image_grid, load_image
import uuid

# URL prefix that local file paths are appended to when building file URLs
# (see its uses in run_dif_color and run_dif).
base_url = 'https://omnibus-top-20-img-img-tint.hf.space/file='

# One loaded model per entry of `models` that could be loaded. A failed load
# is printed and skipped.
# NOTE(review): skipping shifts later entries, so an index into `models`
# (as the dropdown supplies) may not address the same model here - confirm.
loaded_model = []
for model in models:
    try:
        loaded_model.append(gr.load(f'models/{model}'))
    except Exception as e:
        print(e)
print(loaded_model)

#pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None, variant="fp16", use_safetensors=True).to("cpu")
#pipeline.unet = torch.compile(pipeline.unet)

# Number of tile columns: load_im cuts every row of the source image into
# this many tiles, and the tile gallery is created with the same column count.
grid_wide=10



def get_concat_h_cut(in1, in2):
    """Join two image files side by side, cropping to the shorter height.

    Args:
        in1: Path of the image placed on the left (anything Image.open accepts).
        in2: Path of the image placed on the right.

    Returns:
        A new RGB PIL image whose width is the sum of both widths and whose
        height is the smaller of the two heights.
    """
    print(in1)
    print(in2)
    # Image.open is lazy and holds the file open; the context managers release
    # both handles once the pixels have been copied onto the new canvas.
    with Image.open(in1) as im1, Image.open(in2) as im2:
        dst = Image.new('RGB', (im1.width + im2.width,
                                min(im1.height, im2.height)))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (im1.width, 0))
    return dst


def get_concat_v_cut(in1, in2):
    """Stack two image files vertically, cropping to the narrower width.

    Args:
        in1: Path of the image placed on top (anything Image.open accepts).
        in2: Path of the image placed underneath.

    Returns:
        A new RGB PIL image whose height is the sum of both heights and whose
        width is the smaller of the two widths.
    """
    print(in1)
    print(in2)
    # Image.open is lazy and holds the file open; the context managers release
    # both handles once the pixels have been copied onto the new canvas.
    with Image.open(in1) as im1, Image.open(in2) as im2:
        dst = Image.new(
            'RGB', (min(im1.width, im2.width), im1.height + im2.height))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (0, im1.height))
    return dst






def load_model(model_drop):
    """Instantiate the Stable Diffusion v1.5 image-to-image pipeline.

    The pipeline is bound to a local name only and nothing is returned;
    `model_drop` is not used.

    NOTE(review): in the visible file the imports of `torch` and
    `AutoPipelineForImage2Image` are commented out, so calling this raises
    NameError until they are restored.
    """
    pipeline = AutoPipelineForImage2Image.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        use_safetensors=True,
    )




def _average_rgb(image_path):
    """Return the mean colour of an image file as an (r, g, b) tuple of ints."""
    pixels = cv2.imread(image_path)
    if pixels is None:
        # cv2.imread reports an unreadable file by returning None, not raising.
        raise ValueError(f'could not read tile image: {image_path}')
    per_column = numpy.average(pixels, axis=0)
    # cv2.imread stores colour channels in B, G, R order, so unpack in that
    # order and hand the result back as R, G, B.
    blue, green, red = numpy.average(per_column, axis=0)
    return (int(red), int(green), int(blue))


def run_dif_color(out_prompt,im_path,model_drop,tint,im_height,im_width):
    """Generate one tinted image per gallery tile and assemble them into a mosaic.

    For every tile in the gallery payload the selected model is called with
    the prompt, the result is tinted towards the tile's average colour and
    pasted into a mosaic that follows the tile order (row-major, at most
    `grid_wide` tiles per row, matching how load_im cuts the source image).

    Args:
        out_prompt: Text prompt sent to the model.
        im_path: Gallery payload; tiles are read from `im_path.root` and each
            tile's file from `.image.path`.
        model_drop: Index into `loaded_model`.
        tint: Mix factor passed to RGBTransform.mix_with (converted to float).
        im_height: Unused; kept so existing event wiring keeps working. The
            layout is derived from the number of tiles instead.
        im_width: Unused; see im_height.

    Yields:
        `(mosaic_path, message)` after every tile. `mosaic_path` is the PNG
        written so far (None until a tile succeeds); `message` is the text of
        the most recent error, or "" if none occurred. A tile that fails is
        left blank in the mosaic.
    """
    tiles = im_path.root if im_path is not None else []
    columns = max(1, min(grid_wide, len(tiles)))
    # A single output file per run; the original created a new name per tile
    # and therefore never found the image it meant to extend.
    mosaic_path = f'comb-{uuid.uuid4()}-tmp.png'
    mosaic = None
    cell_w = cell_h = 0
    p_seed = ""
    out = None
    out_html = ""

    for index, tile in enumerate(tiles):
        row, col = divmod(index, columns)
        print(f'root::{tile}')
        try:
            tile_path = tile.image.path
            print(base_url + tile_path)
            color = _average_rgb(tile_path)
            print(color)

            # Trailing spaces keep growing, so every request uses a prompt
            # string that differs from all earlier ones.
            p_seed += " " * random.randint(1, 500)
            model = loaded_model[int(model_drop)]
            out_img = model(out_prompt + p_seed)

            with Image.open(out_img) as generated:
                raw = generated.convert('RGB')
            colorize = RGBTransform().mix_with(color, factor=float(tint)).applied_to(raw)
            tinted_path = f'tmp-{uuid.uuid4()}.png'
            colorize.save(tinted_path)

            with Image.open(tinted_path) as tinted:
                if mosaic is None:
                    # The first successful tile fixes the cell size of the grid.
                    cell_w, cell_h = tinted.size
                if tinted.size == (cell_w, cell_h):
                    cell = tinted
                else:
                    cell = tinted.resize((cell_w, cell_h))
                needed_h = (row + 1) * cell_h
                if mosaic is None or mosaic.height < needed_h:
                    # Grow the canvas by whole rows so previews stay tight.
                    grown = Image.new('RGB', (columns * cell_w, needed_h))
                    if mosaic is not None:
                        grown.paste(mosaic, (0, 0))
                    mosaic = grown
                mosaic.paste(cell, (col * cell_w, row * cell_h))

            mosaic.save(mosaic_path)
            out = mosaic_path
        except Exception as e:
            # Best effort: report the failure and carry on with the next tile.
            print(e)
            out_html = str(e)

        yield out, out_html
















def run_dif(prompt,im_path,model_drop,cnt,strength,guidance,infer,im_height,im_width):
    """Run image-to-image diffusion on gallery tiles and stack the results vertically.

    Args:
        prompt: Text prompt passed to the pipeline.
        im_path: Gallery payload; tiles are read from `im_path.root` and each
            tile's file from `.image.path`.
        model_drop: Unused.
        cnt: Unused.
        strength: Pipeline `strength` (converted to float).
        guidance: Pipeline `guidance_scale` (converted to float).
        infer: Pipeline `num_inference_steps` (converted to int).
        im_height: With im_width, determines how often each tile is processed.
        im_width: See im_height.

    Yields:
        `(path, "")` after every generated image, where `path` is the PNG
        holding all images produced so far, stacked top to bottom.

    NOTE(review): `load_image` and `pipeline` are not defined in the visible
    file (their import and initialisation are commented out), so this raises
    NameError until they are restored.
    """
    uid=uuid.uuid4()
    comb_path=f'comb-{uid}-tmp.png'
    print(f'im_path:: {im_path}')
    im_height=int(im_height)
    im_width=int(im_width)
    # NOTE(review): every tile is processed this many times, as in the
    # original nested loops; this looks unintended - confirm.
    repeats=int(im_height/grid_wide)*int(im_width/grid_wide)
    have_output=False

    for tile in im_path.root:
        for _ in range(repeats):
            print(f'root::{tile}')
            url = base_url+tile.image.path
            print(url)
            init_image=load_image(url)

            # pass prompt and image to pipeline
            image = pipeline(prompt, image=init_image, strength=float(strength),guidance_scale=float(guidance),num_inference_steps=int(infer)).images[0]

            if not have_output:
                image.save(comb_path)
                have_output=True
            else:
                # get_concat_v_cut works on files, so the new image is written
                # out first and then appended below the existing strip.
                tile_path=f'tmp-{uuid.uuid4()}.png'
                image.save(tile_path)
                get_concat_v_cut(comb_path,tile_path).save(comb_path)

            yield comb_path,""




def run_dif_old(out_prompt,model_drop,cnt):
    """Call the selected model `cnt` times, yielding the growing result list.

    Each call pads the prompt with one more trailing space than the call
    before it. A failing call is printed and its message becomes the status
    text; the loop then continues.

    Args:
        out_prompt: Text prompt sent to the model.
        model_drop: Index into `loaded_model`.
        cnt: Number of calls to make (converted to int).

    Yields:
        `(results, status)` after every attempt: the list of model outputs
        collected so far and the text of the most recent error ("" if none).
    """
    results = []
    status = ""
    for attempt in range(1, int(cnt) + 1):
        padding = " " * attempt
        try:
            chosen = loaded_model[int(model_drop)]
            produced = chosen(out_prompt + padding)
            print(produced)
            results.append(produced)
        except Exception as err:
            print(err)
            status = str(err)
        yield results, status

def run_dif_og(out_prompt,model_drop,cnt):
    """Call the selected model `cnt` times and build an image list plus HTML grid.

    Each model output is fetched back over HTTP, embedded into the HTML as a
    base64 image labelled with the selected model's name, and appended to the
    image list as a PIL image.

    Args:
        out_prompt: Text prompt sent to the model.
        model_drop: Index of the selected model in `loaded_model` / `models`.
        cnt: Number of calls to make (converted to int, so the float a
            numeric input field delivers is accepted).

    Yields:
        `(images, html)` after every attempt. Errors and non-200 responses are
        appended to the HTML instead of being raised.
    """
    out_box=[]
    out_html=""
    for _ in range(int(cnt)):
        try:
            model_index=int(model_drop)
            model=loaded_model[model_index]
            out_img=model(out_prompt)
            print(out_img)
            url=f'https://omnibus-top-20.hf.space/file={out_img}'
            print(url)
            # Bounded wait: without a timeout a stalled request blocks forever.
            r = requests.get(url, timeout=60)

            if r.status_code == 200:
                str_equivalent_image = base64.b64encode(r.content).decode()
                img_tag = "<img src='data:image/png;base64," + str_equivalent_image + "'/>"
                # Label with the model that was actually called, not with the
                # loop counter.
                model_name = models[model_index]
                out_html+=f"<div  class='img_class'><a href='https://huggingface.co/models/{model_name}'>{model_name}</a><br>"+img_tag+"</div>"
                out_box.append(Image.open(io.BytesIO(r.content)))
            else:
                out_html+=f"<div  class='img_class'>HTTP {r.status_code}</div>"
        except Exception as e:
            out_html+=str(e)
        yield out_box,"<div class='grid_class'>"+out_html+"</div>"

def thread_dif(out_prompt,mod):
    """Call one model once and return its image plus a status HTML snippet.

    Args:
        out_prompt: Text prompt sent to the model.
        mod: Index into `loaded_model`.

    Returns:
        `(images, html)`: a list holding the fetched PIL image (empty on
        failure) and a `grid_class` div containing the HTTP status code of a
        non-200 response or the text of an exception, otherwise empty.
    """
    out_box=[]
    out_html=""
    try:
        model=loaded_model[int(mod)]
        out_img=model(out_prompt)
        print(out_img)
        url=f'https://omnibus-top-20.hf.space/file={out_img}'
        print(url)
        # Bounded wait: without a timeout a stalled request blocks forever.
        r = requests.get(url, timeout=60)

        if r.status_code == 200:
            out_box.append(Image.open(io.BytesIO(r.content)))
        else:
            # Must be text: it is concatenated into the HTML below.
            out_html=str(r.status_code)
    except Exception as e:
        out_html=str(e)
    html_out = "<div class='grid_class'>"+out_html+"</div>"
    return out_box,html_out


# Stylesheet handed to gr.Blocks. `grid_class` and `img_class` are the class
# names used in the HTML strings built by run_dif_og and thread_dif.
css="""
.grid_class{
display:flex;
height:100%;
}
.img_class{
min-width:200px;
}

"""

def load_im(img):
    """Cut an image file into a grid of square tiles.

    The tile edge is one `grid_wide`-th of the image width (at least 1 px).
    Tiles are produced row by row, `grid_wide` per row, for as many full rows
    as fit into the image height.

    Args:
        img: Path of the source image (anything Image.open accepts).

    Yields:
        Once: `(tiles, tiles, height, width)` - the same list of cropped PIL
        images twice, followed by the source image's height and width in
        pixels.
    """
    with Image.open(img) as im:
        width, height = im.size
        # Clamp to 1 so an image narrower than grid_wide cannot cause a
        # division by zero when the row count is computed.
        edge = max(1, width // grid_wide)
        im_box = [
            # crop() leaves the source untouched and returns a new image.
            im.crop((col * edge, row * edge, (col + 1) * edge, (row + 1) * edge))
            for row in range(height // edge)
            for col in range(grid_wide)
        ]
    yield im_box,im_box,height,width
# Build the UI. Components appear in the order they are created: prompt and
# sliders next to the source image, then the model picker, then the outputs.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column():
            inp=gr.Textbox(label="Prompt")
            # NOTE(review): strength, guidance and infer are not passed to any
            # handler wired below; only the tint slider is.
            strength=gr.Slider(label="Strength",minimum=0,maximum=1,step=0.1,value=0.2)
            guidance=gr.Slider(label="Guidance",minimum=0,maximum=10,step=0.1,value=8.0)
            infer=gr.Slider(label="Inference Steps",minimum=0,maximum=50,step=1,value=10)
            tint = gr.Slider(label="Tint Strength", minimum=0, maximum=1, step=0.01, value=0.30)

            with gr.Row():
                btn=gr.Button()
                stop_btn=gr.Button("Stop")
        with gr.Column():
            # type='filepath' hands load_im a path rather than pixel data.
            inp_im=gr.Image(type='filepath')
            im_btn=gr.Button("Image Grid")
    with gr.Row():
        # type='index' makes handlers receive the position of the chosen
        # entry in `models` instead of its name.
        model_drop=gr.Dropdown(label="Models", choices=models, type='index', value=models[0])
        # NOTE(review): cnt is not passed to any handler wired below.
        cnt = gr.Number(value=1)
        
    out_html=gr.HTML()
    # Tile gallery: filled by load_im and read back as input by run_dif_color.
    outp=gr.Gallery(columns=grid_wide)
    #fingal=gr.Gallery(columns=grid_wide)
    fin=gr.Image()
    # Receive the source image's height and width from load_im.
    im_height=gr.Number()
    im_width=gr.Number()
    
    # Hidden field that receives load_im's second output (the tile list again).
    im_list=gr.Textbox(visible=False)    
    im_btn.click(load_im,inp_im,[outp,im_list,im_height,im_width])

    # Main action: stream (image, status html) pairs from run_dif_color.
    go_btn=btn.click(run_dif_color,[inp,outp,model_drop,tint,im_height,im_width],[fin,out_html])
    
    #go_btn = btn.click(run_dif_color,[inp,outp,model_drop,cnt,strength,guidance,infer,im_height,im_width],[fin,out_html])
    # "Stop" has no handler of its own; it only cancels the running go_btn event.
    stop_btn.click(None,None,None,cancels=[go_btn])
# Queueing is enabled before launch.
app.queue().launch()