ehristoforu committed on
Commit 8242674
1 Parent(s): a61e16b

Create app.py

Files changed (1)
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
+ import numpy as np
+ import gradio as gr
+ import requests
+ import time
+ import json
+ import base64
+ import os
+ from PIL import Image
+ from io import BytesIO
+
+ class Prodia:
+     def __init__(self, api_key, base=None):
+         self.base = base or "https://api.prodia.com/v1"
+         self.headers = {
+             "X-Prodia-Key": api_key
+         }
+
+     def generate(self, params):
+         response = self._post(f"{self.base}/sd/generate", params)
+         return response.json()
+
+     def transform(self, params):
+         response = self._post(f"{self.base}/sd/transform", params)
+         return response.json()
+
+     def controlnet(self, params):
+         response = self._post(f"{self.base}/sd/controlnet", params)
+         return response.json()
+
+     def get_job(self, job_id):
+         response = self._get(f"{self.base}/job/{job_id}")
+         return response.json()
+
+     def wait(self, job):
+         job_result = job
+
+         while job_result['status'] not in ['succeeded', 'failed']:
+             time.sleep(0.25)
+             job_result = self.get_job(job['job'])
+
+         return job_result
+
+     def list_models(self):
+         response = self._get(f"{self.base}/models/list")
+         return response.json()
+
+     def _post(self, url, params):
+         headers = {
+             **self.headers,
+             "Content-Type": "application/json"
+         }
+         response = requests.post(url, headers=headers, data=json.dumps(params))
+
+         if response.status_code != 200:
+             raise Exception(f"Bad Prodia Response: {response.status_code}")
+
+         return response
+
+     def _get(self, url):
+         response = requests.get(url, headers=self.headers)
+
+         if response.status_code != 200:
+             raise Exception(f"Bad Prodia Response: {response.status_code}")
+
+         return response
+
+
+ def image_to_base64(image_path):
+     # Open the image with PIL
+     with Image.open(image_path) as image:
+         # Convert the image to bytes
+         buffered = BytesIO()
+         image.save(buffered, format="PNG")  # Change the format here if needed
+
+         # Encode the bytes to base64
+         img_str = base64.b64encode(buffered.getvalue())
+
+     return img_str.decode('utf-8')  # Convert bytes to string
+
+
+ prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
+
+ def flip_text(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
+     result = prodia_client.generate({
+         "prompt": prompt,
+         "negative_prompt": negative_prompt,
+         "model": model,
+         "steps": steps,
+         "sampler": sampler,
+         "cfg_scale": cfg_scale,
+         "width": width,
+         "height": height,
+         "seed": seed
+     })
+
+     job = prodia_client.wait(result)
+
+     return job["imageUrl"]
+
+ css = """
+ #generate {
+     height: 100%;
+ }
+ """
+
+ theme = "Base"
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(scale=1):
+         gr.Markdown(elem_id="powered-by-prodia", value="AUTOMATIC1111 Stable Diffusion Web UI.<br>Powered by [Prodia](https://prodia.com).")
+
+     with gr.Row():
+         with gr.Column(scale=6, min_width=600):
+             prompt = gr.Textbox("puppies in a cloud, 4k", placeholder="Prompt", show_label=False, lines=3)
+             negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3)
+         with gr.Column():
+             text_button = gr.Button("Generate", variant='primary', elem_id="generate")
+
+     with gr.Row():
+         with gr.Column(scale=6):
+             model = gr.Dropdown(interactive=True, value="v1-5-pruned-emaonly.safetensors [d7049739]", show_label=True, label="Model", choices=prodia_client.list_models())
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             with gr.Tab("Generation"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         sampler = gr.Dropdown(value="Euler a", show_label=True, label="Sampling Method", choices=[
+                             "Euler",
+                             "Euler a",
+                             "LMS",
+                             "Heun",
+                             "DPM2",
+                             "DPM2 a",
+                             "DPM++ 2S a",
+                             "DPM++ 2M",
+                             "DPM++ SDE",
+                             "DPM fast",
+                             "DPM adaptive",
+                             "LMS Karras",
+                             "DPM2 Karras",
+                             "DPM2 a Karras",
+                             "DPM++ 2S a Karras",
+                             "DPM++ 2M Karras",
+                             "DPM++ SDE Karras",
+                             "DDIM",
+                             "PLMS",
+                         ])
+
+                     with gr.Column(scale=1):
+                         steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=50, value=25, step=1)
+
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         width = gr.Slider(label="Width", maximum=1024, value=512, step=8)
+                         height = gr.Slider(label="Height", maximum=1024, value=512, step=8)
+
+                     with gr.Column(scale=1):
+                         batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
+                         batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
+
+                 cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=1)
+                 seed = gr.Number(label="Seed", value=-1)
+
+         with gr.Column(scale=2):
+             image_output = gr.Image()
+
+     text_button.click(flip_text, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output)
+
+ demo.queue(concurrency_count=10)
+ demo.launch()