from contextlib import nullcontext
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import urllib, urllib.request
import os
from xml.etree import ElementTree
import random
import re
from typing import List


# All selectable Pokémon types, in the order they appear in the type radio.
pokemon_types = [
    "Normal", "Water", "Fire", "Ice", "Psychic", "Rock",
    "Dark", "Electric", "Grass", "Fighting", "Poison", "Ground",
    "Flying", "Bug", "Ghost", "Dragon", "Steel", "Fairy",
]

# Radio options: "None" (no type in the prompt) and "Random" (one type drawn
# when the prompt is built), followed by every concrete type.
type_choices = ["None", "Random", *pokemon_types]

# Title of the most recently resolved paper; read and written by the
# prompt-update callbacks through `global paper_name`.
paper_name = None

# Decide once at import time how the pipeline runs: fp16 under CUDA autocast
# on a GPU, fp32 with a no-op context manager on CPU.
if torch.cuda.is_available():
    device, context, dtype = "cuda", autocast, torch.float16
else:
    device, context, dtype = "cpu", nullcontext, torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    "lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype
).to(device)


# The NSFW checker sometimes misclassifies the Pokémon images; it can be
# disabled here at your own risk.
disable_safety = True

if disable_safety:
    def null_safety(images, **kwargs):
        """Stand-in safety checker that flags no image as NSFW.

        Args:
            images: The batch the pipeline would have screened; returned as-is.
            **kwargs: Any other arguments the pipeline passes (ignored).

        Returns:
            ``(images, flags)`` where ``flags`` holds one ``False`` per image.
        """
        # One flag per image instead of a single bool.
        # NOTE(review): newer diffusers releases iterate over this value per
        # image, and a bare ``False`` breaks there. Confirm against the pinned
        # diffusers version.
        return images, [False] * len(images)

    pipe.safety_checker = null_safety


def infer(prompt, n_samples, steps, scale):
    """Generate images for ``prompt`` with the Pokémon diffusion pipeline.

    Args:
        prompt: Text prompt; it is repeated once per requested image.
        n_samples: Number of images to generate.
        steps: Number of inference (denoising) steps.
        scale: Guidance scale passed to the pipeline.

    Returns:
        The pipeline output's ``.images`` (displayed in the gallery).
    """
    # Slider values may arrive as floats (e.g. 2.0). List repetition needs an
    # int, and the step count is a count, so normalise both.
    n_samples = int(n_samples)
    steps = int(steps)
    # `context` is autocast on CUDA and a no-op nullcontext on CPU.
    with context(device):
        images = pipe(n_samples * [prompt], guidance_scale=scale, num_inference_steps=steps).images

    return images

def get_paper_name(url: str) -> str:
    """Look up the title of an arXiv paper through the arXiv export API.

    Accepts an abstract link (``https://arxiv.org/abs/<id>``), a PDF link
    (``https://arxiv.org/pdf/<id>.pdf``) or a bare paper ID.

    Args:
        url: Link to the paper, or its arXiv ID.

    Returns:
        The paper title with line breaks and repeated whitespace collapsed
        to single spaces.

    Raises:
        ValueError: If no paper ID can be extracted from ``url`` or the API
            response contains no entry title.
        urllib.error.URLError: If the arXiv API cannot be reached.
    """
    import urllib.parse

    # Only the path identifies the paper. Dropping any query string or
    # fragment and a trailing slash keeps links like ".../abs/1706.03762?x=y"
    # and ".../abs/1706.03762/" working.
    path = urllib.parse.urlparse(url.strip()).path.rstrip("/")
    paper_id = os.path.basename(path).split(".pdf")[0]
    if not paper_id:
        raise ValueError(f"Could not extract an arXiv ID from {url!r}")

    # URL-encode the ID so user input cannot inject extra query parameters.
    query_url = "http://export.arxiv.org/api/query?" + urllib.parse.urlencode(
        {"id_list": paper_id}
    )
    req = urllib.request.Request(query_url, headers={"Accept": "application/atom+xml"})
    # The context manager closes the HTTP response even if reading fails.
    with urllib.request.urlopen(req) as response:
        tree = ElementTree.fromstring(response.read().decode("utf-8"))

    atom = "{http://www.w3.org/2005/Atom}"
    title = tree.find(f"{atom}entry/{atom}title")
    if title is None or not title.text:
        raise ValueError(f"arXiv returned no title for ID {paper_id!r}")
    # Long titles come back wrapped with newlines and indentation; collapse
    # every whitespace run into a single space.
    return " ".join(title.text.split())
    


# Root Gradio Blocks container. The UI is built inside `with block:` below
# and served by `block.launch()` at the end of the file.
block = gr.Blocks()

# Example rows for gr.Examples, one per paper link. Each row holds
# (paper link, number of images, guidance scale), and every example uses
# the same image count and guidance scale.
examples = [
    [link, 2, 7.5]
    for link in (
        "https://arxiv.org/abs/1706.03762",
        "https://arxiv.org/abs/1404.5997v2",
        "https://arxiv.org/abs/2010.11929",
        "https://arxiv.org/abs/1810.04805v2",
    )
]

with block:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 650px; margin: 50px auto;">
              <div>
                <h1 style="font-weight: 900; font-size: 3rem;">
                  Paper to Pokémon
                </h1>
              </div>
              <p style="margin-bottom: 10px; margin-top: 30px; font-size: 94%">
              Generate new Pokémon from an arXiv link. Just paste the link to the overview, the pdf or just give the ID of the paper. 
              
              It will create a prompt with the paper title, which you can then modify as you like or submit as it is.
              
              For better overall quality, increase the number of steps. (This will also increase the processing time)
              </p>
            </div>
        """
    )
    with gr.Group():
        with gr.Box():
            # The paper link input and the generate button share one row.
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Link or ID for paper",
                    show_label=False,
                    max_lines=1,
                    placeholder="Give arXiv link or ID for the paper",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
        
        prompt_ideas = gr.CheckboxGroup(choices=["as a bird", 
                                                 "with four legs", 
                                                 "with wings", 
                                                 "as a koala", 
                                                 "with a beak", 
                                                 "looking like a llama"],
                                        label="Additional prompt ideas")
        
        # Editable prompt: filled from the paper title by the callbacks and
        # sent to `infer` when the button is clicked.
        prompt_box = gr.Textbox(placeholder="Your prompt appears here", interactive=True, label="Prompt")

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")


        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )


        # The examples only pre-fill (link, images, guidance scale); nothing is
        # cached or run. `infer` takes (prompt, n_samples, steps, scale), which
        # does not match these three inputs, so no fn/outputs are wired here.
        ex = gr.Examples(examples=examples, inputs=[text, samples, scale], cache_examples=False)
        # Replace the example table's column headers with a single blank one.
        ex.dataset.headers = [""]

        def resolve_poke_type(pok_type: str):
            """Translate the radio selection into the type name for the prompt.

            "None" maps to an empty string (no type in the prompt). "Random"
            draws one entry of ``pokemon_types`` uniformly. Any other value is
            already a concrete type and is returned unchanged.
            """
            if pok_type == "Random":
                # Same single RNG draw as randint(0, len - 1) followed by indexing.
                return random.choice(pokemon_types)
            return "" if pok_type == "None" else pok_type
        
        def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: List[str]):
            """Rebuild the prompt after the paper link changes.

            Looks up the title for ``new_link``, stores it in the module-level
            ``paper_name`` and combines it with the selected type and ideas.

            Args:
                new_link: Content of the link textbox (URL or arXiv ID).
                pok_type: Radio selection ("None", "Random" or a type name).
                prompt_ideas: Selected extra prompt phrases.

            Returns:
                The new prompt text. If the link is empty or cannot be resolved
                (e.g. it is still being typed, since this handler is bound to
                ``text.change``), ``gr.update()`` leaves the prompt unchanged.
            """
            global paper_name
            if not (new_link or "").strip():
                return gr.update()
            try:
                paper_name = get_paper_name(new_link)
            except Exception as err:  # UI event boundary: log, keep the current prompt
                print(f"Could not look up paper for {new_link!r}: {err}")
                return gr.update()

            return build_prompt_text(paper_name, resolve_poke_type(pok_type), prompt_ideas)

        def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: List[str]):
            """Rebuild the prompt after the type or the extra ideas change.

            Reuses the cached ``paper_name`` and only queries arXiv when no
            title has been resolved yet.

            Args:
                paper_link: Content of the link textbox (URL or arXiv ID).
                pok_type: Radio selection ("None", "Random" or a type name).
                prompt_ideas: Selected extra prompt phrases.

            Returns:
                The new prompt text. If no title is available (no link entered
                yet, or the lookup fails), ``gr.update()`` leaves the prompt
                unchanged.
            """
            # NOTE(review): `paper_name` is a module global, so it is shared by
            # every user served by this process. Consider gr.State if that matters.
            global paper_name
            if paper_name is None:
                if not (paper_link or "").strip():
                    return gr.update()
                try:
                    paper_name = get_paper_name(paper_link)
                except Exception as err:  # UI event boundary: log, keep the current prompt
                    print(f"Could not look up paper for {paper_link!r}: {err}")
                    return gr.update()

            return build_prompt_text(paper_name, resolve_poke_type(pok_type), prompt_ideas)
        
        def build_prompt_text(paper_name, pok_type, add_ideas):
            """Compose the diffusion prompt from its parts.

            Args:
                paper_name: Paper title used as the subject of the prompt.
                pok_type: Resolved type name; an empty string means no type.
                add_ideas: Extra phrases appended in order. ``None`` is treated
                    like an empty selection.

            Returns:
                E.g. ``"<title> as Fire type with wings"``. There is no trailing
                space when no extra phrases are selected.
            """
            base = f"{paper_name} as {pok_type} type" if pok_type else f"{paper_name}"
            # Joining the parts avoids the dangling space the former
            # f"{base} {' '.join(ideas)}" produced for an empty selection.
            return " ".join([base, *(add_ideas or [])])
        
        # Every prompt-refresh event reads the same three inputs and writes the
        # prompt box. A new link goes through update_prompt_link; a type or
        # idea change goes through update_prompt_type. Registration order
        # matches the trigger list below.
        for trigger, handler in (
            (text.change, update_prompt_link),
            (text.submit, update_prompt_link),
            (poke_type.change, update_prompt_type),
            (prompt_ideas.change, update_prompt_type),
        ):
            trigger(handler, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)

        # Image generation only runs on an explicit button click.
        btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
        gr.HTML(
            """
                <div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">
                    <p>Inspired by and cloned from the great <a href="https://huggingface.co/spaces/lambdalabs/text-to-pokemon">
                    Text-to-Pokémon</a> space by Lambda labs</p>
                    <p> Gradio Demo by johko</p>
               </div>
           """
        )

# Start the Gradio server for the UI built in `with block:` above.
block.launch()