ehristoforu commited on
Commit
7bdc0ee
·
verified ·
1 Parent(s): 2d13d6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+
4
+ import requests
5
+ import base64
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import os
10
+
11
+ api_key = os.getenv("api_key")
12
+ secret_key = os.getenv("secret_key")
13
+
14
+ class Text2ImageAPI:
15
+
16
+ def __init__(self, url, api_key, secret_key):
17
+ self.URL = url
18
+ self.AUTH_HEADERS = {
19
+ 'X-Key': f'Key {api_key}',
20
+ 'X-Secret': f'Secret {secret_key}',
21
+ }
22
+
23
+ def get_model(self):
24
+ response = requests.get(self.URL + 'key/api/v1/models', headers=self.AUTH_HEADERS)
25
+ data = response.json()
26
+ return data[0]['id']
27
+
28
+ def generate(self, prompt, width, height, model):
29
+ params = {
30
+ "type": "GENERATE",
31
+ "numImages": 1,
32
+ "width": width,
33
+ "height": height,
34
+ "censored": true,
35
+ "generateParams": {
36
+ "query": f"{prompt}"
37
+ }
38
+ }
39
+
40
+ data = {
41
+ 'model_id': (None, model),
42
+ 'params': (None, json.dumps(params), 'application/json')
43
+ }
44
+ response = requests.post(self.URL + 'key/api/v1/text2image/run', headers=self.AUTH_HEADERS, files=data)
45
+ data = response.json()
46
+ return data['uuid']
47
+
48
+ def check_generation(self, request_id, attempts=10, delay=10):
49
+ while attempts > 0:
50
+ response = requests.get(self.URL + 'key/api/v1/text2image/status/' + request_id, headers=self.AUTH_HEADERS)
51
+ data = response.json()
52
+ if data['status'] == 'DONE':
53
+ return data['images']
54
+
55
+ attempts -= 1
56
+ time.sleep(delay)
57
+
58
+
59
+ def api_gradio(prompt, width, height):
60
+ api = Text2ImageAPI('https://api-key.fusionbrain.ai/', api_key, secret_key)
61
+ model_id = api.get_model()
62
+ uuid = api.generate(prompt, width, height, model_id)
63
+ images = api.check_generation(uuid)
64
+
65
+ decoded_data = base64.b64decode(images[0])
66
+ image = Image.open(BytesIO(decoded_data))
67
+
68
+ return [image]
69
+
70
+ css = """
71
+ footer {
72
+ visibility: hidden
73
+ }
74
+ #generate_button {
75
+ color: white;
76
+ border-color: #007bff;
77
+ background: #2563eb;
78
+ }
79
+ #save_button {
80
+ color: white;
81
+ border-color: #028b40;
82
+ background: #01b97c;
83
+ width: 200px;
84
+ }
85
+ #settings_header {
86
+ background: rgb(245, 105, 105);
87
+ }
88
+ """
89
+
90
+ with gr.Blocks(css=css) as demo:
91
+ gr.Markdown("# Kandinsky ```API DEMO```")
92
+ with gr.Row():
93
+ prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=3, lines=1, interactive=True, scale=20)
94
+ button = gr.Button(value="Generate", scale=1)
95
+ with gr.Accordion("Advanced options", open=False):
96
+ with gr.Row():
97
+ width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
98
+ height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
99
+ with gr.Row():
100
+ gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
101
+
102
+ button.click(api_gradio, inputs=[prompt, width, height], outputs=gallery)
103
+
104
+ demo.queue().launch(show_api=False)