# Kandinsky-API / app.py
# Last commit: 721a723 "Update app.py" by ehristoforu (verified)
# Size: 3.09 kB
import json
import time
import requests
import base64
from io import BytesIO
from PIL import Image
import gradio as gr
import os
# FusionBrain credentials, read once at import time from the process
# environment. Each name is None when its variable is not set.
api_key = os.environ.get("api_key")
secret_key = os.environ.get("secret_key")
class Text2ImageAPI:
    """Minimal client for the FusionBrain text-to-image HTTP API.

    Every request carries the ``X-Key`` / ``X-Secret`` headers built from the
    credentials given to the constructor. Endpoint paths are appended directly
    to the base URL.
    """

    def __init__(self, url, api_key, secret_key, timeout=30):
        """Store the base URL, auth headers and per-request timeout.

        Args:
            url: Base URL of the service. Endpoint paths are concatenated onto
                it, so it should end with ``/``.
            api_key: API key, sent as ``X-Key: Key <api_key>``.
            secret_key: Secret key, sent as ``X-Secret: Secret <secret_key>``.
            timeout: Seconds ``requests`` waits for each HTTP call before
                raising a timeout error. Without a timeout, a stalled
                connection would block the caller indefinitely.
        """
        self.URL = url
        self.AUTH_HEADERS = {
            'X-Key': f'Key {api_key}',
            'X-Secret': f'Secret {secret_key}',
        }
        self.timeout = timeout

    def _get_json(self, path):
        """GET ``self.URL + path`` with auth headers and return the decoded JSON.

        Raises:
            requests.HTTPError: if the service answers with a 4xx/5xx status,
                so failures surface here instead of as a confusing KeyError
                while reading the body.
        """
        response = requests.get(self.URL + path, headers=self.AUTH_HEADERS, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def get_model(self):
        """Return the ``id`` of the first model listed by the service."""
        models = self._get_json('key/api/v1/models')
        return models[0]['id']

    def generate(self, prompt, width, height, model):
        """Submit a single-image generation job.

        Args:
            prompt: Text query. It is converted to ``str`` before sending.
            width: Image width, passed to the API unchanged.
            height: Image height, passed to the API unchanged.
            model: Model id, e.g. as returned by :meth:`get_model`.

        Returns:
            The ``uuid`` of the submitted job, for use with
            :meth:`check_generation`.

        Raises:
            requests.HTTPError: if the submission is rejected with a 4xx/5xx status.
        """
        params = {
            "type": "GENERATE",
            "numImages": 1,
            "width": width,
            "height": height,
            "censored": True,
            "generateParams": {
                "query": f"{prompt}"
            }
        }
        # Sent as multipart/form-data. A ``None`` filename makes requests emit
        # plain form fields rather than file uploads.
        form_fields = {
            'model_id': (None, model),
            'params': (None, json.dumps(params), 'application/json')
        }
        response = requests.post(
            self.URL + 'key/api/v1/text2image/run',
            headers=self.AUTH_HEADERS,
            files=form_fields,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()['uuid']

    def check_generation(self, request_id, attempts=10, delay=10):
        """Poll a job's status until it is done, it fails, or attempts run out.

        Args:
            request_id: Job ``uuid`` returned by :meth:`generate`.
            attempts: Maximum number of status requests to make.
            delay: Seconds to sleep between consecutive status requests.

        Returns:
            The job's ``images`` value once its status is ``'DONE'``, or
            ``None`` if every attempt was used without reaching that status.

        Raises:
            RuntimeError: if the service reports the job status as ``'FAIL'``.
                A failed job will never become ``'DONE'``, so further polling
                would only waste time.
            requests.HTTPError: if a status request gets a 4xx/5xx response.
        """
        while attempts > 0:
            data = self._get_json('key/api/v1/text2image/status/' + request_id)
            if data['status'] == 'DONE':
                return data['images']
            if data['status'] == 'FAIL':
                reason = data.get('errorDescription', 'no error description provided')
                raise RuntimeError(f"Generation {request_id} failed: {reason}")
            attempts -= 1
            # Don't sleep after the final attempt; the result is already known.
            if attempts > 0:
                time.sleep(delay)
        return None
def api_gradio(prompt, width, height):
    """Gradio callback: generate one image for *prompt* via the FusionBrain API.

    Args:
        prompt: Text prompt from the UI textbox.
        width: Requested image width from the UI slider.
        height: Requested image height from the UI slider.

    Returns:
        A one-element list holding the decoded ``PIL.Image``, suitable as a
        ``gr.Gallery`` value.

    Raises:
        gr.Error: if the credentials are not configured, or if the job yields
            no image (for example, polling ran out of attempts). Gradio shows
            this message to the user.
    """
    # Without this check, missing env vars would be sent as the literal
    # strings "Key None" / "Secret None".
    if not api_key or not secret_key:
        raise gr.Error("API credentials are not configured: set the 'api_key' and 'secret_key' environment variables.")
    api = Text2ImageAPI('https://api-key.fusionbrain.ai/', api_key, secret_key)
    model_id = api.get_model()
    request_id = api.generate(prompt, width, height, model_id)
    images = api.check_generation(request_id)
    # check_generation returns None when polling gives up. Indexing None (or
    # an empty list) would raise an opaque TypeError/IndexError instead.
    if not images:
        raise gr.Error("Image generation returned no image (timed out or empty result); please try again.")
    # The API returns images as base64-encoded strings.
    decoded_data = base64.b64decode(images[0])
    image = Image.open(BytesIO(decoded_data))
    return [image]
# Custom stylesheet passed to gr.Blocks(css=...). Selectors starting with "#"
# match components through the HTML id Gradio sets from a component's elem_id.
# "footer" hides <footer> elements on the page.
# NOTE(review): no component in the visible layout sets elem_id "save_button"
# or "settings_header", so those two rules appear unused; confirm before removing.
css = """
footer {
visibility: hidden
}
#generate_button {
color: white;
border-color: #007bff;
background: #2563eb;
}
#save_button {
color: white;
border-color: #028b40;
background: #01b97c;
width: 200px;
}
#settings_header {
background: rgb(245, 105, 105);
}
"""
# UI layout: prompt + Generate button, collapsible size controls, and a
# single-image gallery for the result.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Kandinsky ```API DEMO```")
    with gr.Row():
        prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=3, lines=1, interactive=True, scale=20)
        # elem_id ties this button to the "#generate_button" rule in `css`;
        # without it that rule matches no element and the styling never applies.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            width = gr.Slider(label="Width", minimum=512, maximum=1024, step=8, value=768, interactive=True)
            height = gr.Slider(label="Height", minimum=512, maximum=1024, step=8, value=768, interactive=True)
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
    # Event listeners must be attached inside the Blocks context.
    button.click(api_gradio, inputs=[prompt, width, height], outputs=gallery)

# queue() enables Gradio's request queue; show_api=False hides the API docs link.
demo.queue().launch(show_api=False)