|
import gradio as gr |
|
import requests |
|
import time |
|
from PIL import Image |
|
import os, io, json |
|
import base64 |
|
|
|
# Base URL of the Stable Diffusion HTTP API and the bearer token sent with
# every request. Both environment variables are required: a missing one fails
# fast at import time with a KeyError.
# A trailing "/" is stripped from the base URL so that URLs built as
# f"{sd_api_base}/txt2img/..." never contain a double slash.
sd_api_base = os.environ["SD_API_BASE"].rstrip("/")

sd_api_key = os.environ["SD_API_KEY"]
|
|
|
|
|
def send_post_request(input_json_string, timeout=30):
    """Submit a txt2img job to the SD API.

    Args:
        input_json_string: Request payload as a JSON string. It is parsed and
            sent as the JSON body of ``POST {sd_api_base}/txt2img/run/``.
        timeout: Seconds to wait for the HTTP request, passed to
            ``requests.post``. Without a timeout a stalled server would block
            the calling (Gradio) worker indefinitely.

    Returns:
        The decoded JSON response on HTTP 200. If ``input_json_string`` is not
        valid JSON, an error-message *string* is returned instead of raising,
        so callers must check for ``str`` before using the result.

    Raises:
        Exception: if the API answers with a non-200 status code.
        requests.exceptions.RequestException: on connection errors or timeout.
    """
    try:
        data = json.loads(input_json_string)
    except json.JSONDecodeError as e:
        # Message text: "The input string is not valid JSON: <error>".
        return f"输入的字符串不是有效的JSON格式: {e}"

    url = f"{sd_api_base}/txt2img/run/"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {sd_api_key}',
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    raise Exception(f"Error in POST request: {response.text}")
|
|
|
|
|
def poll_status(id, poll_interval=1.0, max_wait=600.0, timeout=30):
    """Poll the txt2img job status endpoint until the job completes.

    Args:
        id: Job identifier returned by the submit (``run``) call.
        poll_interval: Seconds to sleep between status checks.
        max_wait: Overall seconds to keep polling before giving up. ``None``
            polls with no deadline.
        timeout: Per-request HTTP timeout in seconds, passed to
            ``requests.get``.

    Returns:
        The decoded status JSON once its ``"status"`` field is
        ``"COMPLETED"``.

    Raises:
        Exception: on a non-200 response, or when the job reports a terminal
            failure status.
        TimeoutError: if ``max_wait`` seconds elapse before completion.
        KeyError: if a 200 response has no ``"status"`` field.
    """
    url = f"{sd_api_base}/txt2img/status/{id}"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {sd_api_key}',
    }
    # NOTE(review): these terminal failure states are assumed to follow the
    # RunPod-style status vocabulary (the API already uses "COMPLETED");
    # confirm against the API in use. Without this check a failed job would
    # be polled forever.
    failed_states = {'FAILED', 'CANCELLED', 'TIMED_OUT'}
    # monotonic() is immune to wall-clock adjustments while we wait.
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while True:
        response = requests.get(url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"Error in GET request: {response.text}")

        result = response.json()
        status = result['status']
        if status == 'COMPLETED':
            return result
        if status in failed_states:
            raise Exception(f"Job {id} ended with status {status}: {result}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Job {id} did not complete within {max_wait} seconds "
                f"(last status: {status})"
            )
        time.sleep(poll_interval)
|
|
|
|
|
def display_images(output_json):
    """Decode the base64-encoded images in a completed job result.

    ``output_json["output"]["images"]`` is read as a sequence of base64
    strings; each one is decoded to bytes and opened with PIL.

    Returns:
        A list of PIL images in the same order as the encoded inputs, which
        the Gradio callback hands to the ``gallery`` output.
    """
    encoded_images = output_json['output']['images']
    return [
        Image.open(io.BytesIO(base64.b64decode(encoded)))
        for encoded in encoded_images
    ]
|
|
|
|
|
def gradio_interface(input_json):
    """Gradio callback: submit a txt2img job, wait for it, return its images.

    Args:
        input_json: Text from the Gradio textbox, expected to be the JSON
            payload for the txt2img ``run`` endpoint.

    Returns:
        List of PIL images for the ``gallery`` output.

    Raises:
        gr.Error: if the input is not valid JSON or the submit response carries
            no job id. The message is then shown in the UI instead of a
            cryptic TypeError/KeyError.
    """
    post_response = send_post_request(input_json)
    print(post_response)

    # send_post_request reports invalid JSON by *returning* a message string
    # instead of raising; indexing that string with 'id' would otherwise fail
    # with "string indices must be integers".
    if isinstance(post_response, str):
        raise gr.Error(post_response)

    job_id = post_response.get('id') if isinstance(post_response, dict) else None
    if job_id is None:
        raise gr.Error(
            f"Unexpected response from txt2img run endpoint (no job id): {post_response}"
        )

    status_response = poll_status(job_id)
    return display_images(status_response)
|
|
|
|
|
# Single-textbox UI: the user pastes a JSON txt2img request and the generated
# images are shown in a gallery.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        lines=2,
        label="Request JSON",
        # The callback parses this text as JSON, so ask for JSON rather than
        # inviting a free-form prompt that would fail to parse.
        placeholder="Paste the txt2img request payload as JSON...",
    ),
    outputs="gallery",
)

iface.launch()
|
|