import hmac
import os
from io import BytesIO

import gradio as gr
import numpy as np
import replicate
import requests
from PIL import Image

# Model reference handed to replicate.run(); written as the model name
# followed by ":" and the pinned version hash.
model_id = (
    "stability-ai/sdxl"
    ":2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2"
)
# Value sent as "num_inference_steps" with every generation request.
default_steps = 25
def generate(prompt: str, secret_key: str):
    """Generate an image for ``prompt`` via Replicate and return it.

    The model is only invoked when ``secret_key`` matches the ``SECRET_KEY``
    environment variable.

    Args:
        prompt: Text prompt forwarded to the model.
        secret_key: Password supplied by the caller.

    Returns:
        The generated image opened with ``PIL.Image.open``, or ``None`` when
        the password does not match or ``SECRET_KEY`` is unset/empty.

    Raises:
        requests.HTTPError: If downloading the generated image returns an
            HTTP error status.
    """
    expected_key = os.environ.get("SECRET_KEY", "")
    # Fail closed: with no configured secret, an empty password must not
    # be treated as a match.
    if not expected_key:
        return None

    # Constant-time comparison. Bytes are compared because compare_digest
    # rejects str values containing non-ASCII characters.
    supplied_key = (secret_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied_key, expected_key.encode("utf-8")):
        return None

    output = replicate.run(
        model_id,
        input={
            "prompt": prompt,
            "num_inference_steps": default_steps,
        },
    )
    # The first output item is the link to the generated PNG.
    png_link = output[0]
    # Download the PNG; the timeout keeps a stalled download from blocking
    # the request forever.
    response = requests.get(png_link, timeout=60)
    # Surface HTTP failures instead of feeding an error page to PIL.
    response.raise_for_status()
    # Open the image in memory without touching the filesystem.
    return Image.open(BytesIO(response.content))


# Sample prompts offered in the UI.
# Previously tried prompts, currently disabled:
#   - "An astronaut riding a rainbow unicorn, cinematic, dramatic"
#   - "A robot painted as graffiti on a brick wall. a sidewalk is in front of
#     the wall, and grass is growing out of cracks in the concrete."
#   - "Panda mad scientist mixing sparkling chemicals, artstation."
#   - "photo of a rhino dressed suit and tie sitting at a table in a bar with
#     a bar stools, award winning photography."
#   - "a giant monster hybrid of dragon and spider, in dark dense foggy forest"
#   - "a man in a space suit playing a piano, highly detailed illustration,
#     full color illustration, very detailed illustration"
_example_prompts = (
    "station",
    "station, ghibli style",
    "Elon Musk",
    "Elon Musk playing Shogi",
)

# One single-element row per prompt.
examples = [[example_prompt] for example_prompt in _example_prompts]

# Two-column layout: inputs on the left, generated image on the right,
# with clickable example prompts in a row underneath.
with gr.Blocks(title="Stable Diffusion XL (SDXL 1.0)") as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            gr_prompt = gr.Textbox(label="プロンプト")
            # Mask the input so the password is not shown in clear text.
            gr_password = gr.Textbox(label="パスワード", type="password")
            gr_generate_button = gr.Button("生成")
            # TODO: add an "advanced settings" accordion exposing steps
            # (default: default_steps) and seed.
        with gr.Column(scale=1, min_width=600):
            gr_image = gr.Image()
        # Run generate() with the prompt and password; show the result image.
        gr_generate_button.click(
            generate,
            inputs=[gr_prompt, gr_password],
            outputs=[gr_image],
        )
    with gr.Row():
        # Clicking an example fills the prompt box.
        gr.Examples(examples, inputs=[gr_prompt], label="プロンプト例")

demo.launch()