# dreampal / app.py
# zhuguangbin — draft (9c7472b)
import gradio as gr
import requests
import time
from PIL import Image
import os, io, json
import base64
# Stable Diffusion API base URL and bearer token. Both are mandatory: a missing
# environment variable raises KeyError as soon as this module is imported.
sd_api_base, sd_api_key = (
    os.environ[var_name] for var_name in ("SD_API_BASE", "SD_API_KEY")
)
# Submit a txt2img job to the Stable Diffusion API.
def send_post_request(input_json_string, *, timeout=30):
    """Parse *input_json_string* and POST it to ``{sd_api_base}/txt2img/run/``.

    Args:
        input_json_string: The request body as a JSON string. It is parsed
            with ``json.loads`` and re-sent as the JSON payload.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The decoded JSON response when the server answers HTTP 200.
        If the input is not valid JSON, an error message string is returned
        instead of a response (callers must check for this case).

    Raises:
        RuntimeError: If the server answers with any status other than 200.
        requests.exceptions.RequestException: On connection errors/timeouts.
    """
    try:
        data = json.loads(input_json_string)
    except json.JSONDecodeError as e:
        # Reported by return value (not raised) to keep the existing contract.
        return f"输入的字符串不是有效的JSON格式: {e}"

    url = f"{sd_api_base}/txt2img/run/"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {sd_api_key}',
    }
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    # RuntimeError subclasses Exception, so existing `except Exception` still works.
    raise RuntimeError(f"Error in POST request: {response.text}")
# Poll the job-status endpoint until the asynchronous txt2img job finishes.
def poll_status(
    id,
    *,
    interval=1,
    request_timeout=30,
    max_wait=None,
    failed_statuses=("FAILED", "CANCELLED", "TIMED_OUT"),
):
    """Block until job *id* reports ``status == 'COMPLETED'`` and return it.

    Args:
        id: Job identifier (the ``'id'`` field of the submit response). The
            name shadows the builtin but is kept for backward compatibility.
        interval: Seconds to sleep between polls.
        request_timeout: Per-request timeout in seconds passed to ``requests``.
        max_wait: Overall limit in seconds; ``None`` (default) waits
            indefinitely, as before.
        failed_statuses: Status values after which the job can never reach
            ``COMPLETED``; seeing one raises instead of polling forever.
            NOTE(review): these names are assumed (RunPod-style API) and are
            not confirmed by this file — check against the real API.

    Returns:
        The decoded status JSON of the completed job.

    Raises:
        RuntimeError: On a non-200 response, or when the job enters one of
            *failed_statuses*.
        TimeoutError: If *max_wait* elapses before the job completes.
    """
    url = f"{sd_api_base}/txt2img/status/{id}"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {sd_api_key}',
    }
    deadline = None if max_wait is None else time.monotonic() + max_wait
    while True:
        response = requests.get(url, headers=headers, timeout=request_timeout)
        if response.status_code != 200:
            raise RuntimeError(f"Error in GET request: {response.text}")
        result = response.json()
        status = result['status']
        if status == 'COMPLETED':
            return result
        if status in failed_statuses:
            raise RuntimeError(f"Job {id} ended with status {status}: {result}")
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"Job {id} not completed after {max_wait}s (last status: {status})")
        time.sleep(interval)  # wait before polling again
# 将Base64编码的图片数据转换为可显示的图片
def display_images(output_json):
images_data = output_json['output']['images']
images = []
for base64_data in images_data:
image_data = base64.b64decode(base64_data)
image = Image.open(io.BytesIO(image_data))
images.append(image)
return images
# Gradio callback: submit the job, wait for completion, return gallery images.
def gradio_interface(input_json):
    """Run one txt2img job for the JSON request typed into the textbox.

    Returns:
        The list of images to show in the output gallery.

    Raises:
        gr.Error: If the input is not valid JSON or the submit response has no
            job id. Gradio shows this message to the user rather than a
            generic error.
    """
    post_response = send_post_request(input_json)
    print(post_response)
    # send_post_request reports invalid JSON by *returning* a message string;
    # indexing that string with ['id'] would otherwise raise a bare TypeError.
    if isinstance(post_response, str):
        raise gr.Error(post_response)
    if not isinstance(post_response, dict) or 'id' not in post_response:
        raise gr.Error(f"Unexpected response from txt2img/run: {post_response}")
    status_response = poll_status(post_response['id'])
    return display_images(status_response)
# Build the Gradio UI: a JSON request textbox in, an image gallery out.
iface = gr.Interface(
    fn=gradio_interface,
    # The callback parses this text with json.loads, so the placeholder shows
    # a JSON request body instead of inviting free-form text.
    inputs=gr.Textbox(
        lines=2,
        label="txt2img request (JSON)",
        placeholder='{"prompt": "a dog"}',
    ),
    outputs="gallery",
)
# Start the Gradio application.
iface.launch()