Spaces:
Running
Running
import os | |
import io | |
import json | |
import base64 | |
import random | |
import urllib.request | |
import urllib.parse | |
import websocket | |
import uuid | |
from dotenv import load_dotenv | |
from flask import Flask, request, jsonify, render_template | |
from PIL import Image | |
# Populate os.environ from a .env file before the settings below are read.
load_dotenv()

# The Flask application that the view functions in this module belong to.
app = Flask(__name__)

# Backend endpoints: SERVER_ADDRESS is used as the HTTP base URL and
# WS_ADDRESS as the websocket URL. os.getenv yields None when a variable is
# unset — NOTE(review): nothing below guards against that; confirm the
# deployment always defines both.
server_address, ws_address = (os.getenv(key) for key in ("SERVER_ADDRESS", "WS_ADDRESS"))

# Per-process identifier sent with queued prompts and the websocket connection.
client_id = f"{uuid.uuid4()}"
def make_request(url, data=None, headers=None):
    """Send a request to *url* and return the response body decoded as JSON.

    Args:
        url: Absolute URL to request.
        data: Optional request body as bytes. urllib sends a POST when a body
            is given and a GET otherwise.
        headers: Optional mapping of header names to values; ``None`` means
            no extra headers.

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        urllib.error.URLError: on connection failure (``HTTPError``, a
            subclass, for non-2xx HTTP responses).
        json.JSONDecodeError: if the body is not valid JSON.
    """
    # urllib.request.Request iterates ``headers.items()``, so forwarding the
    # default None unchanged would raise AttributeError.
    req = urllib.request.Request(url, data=data, headers=headers or {})
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def queue_prompt(prompt, token):
    """Submit a workflow to the backend's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable workflow graph to execute.
        token: Credential sent as ``Authorization: Bearer <token>``.

    Returns:
        The backend's decoded JSON reply (callers read its ``prompt_id``).
    """
    endpoint = f"{server_address}/prompt"
    body = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    return make_request(
        endpoint,
        data=body,
        headers={'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'},
    )
def get_image(filename, subfolder, image_type, token):
    """Download one output image from the backend's ``/view`` endpoint.

    Args:
        filename: Sent as the ``filename`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        image_type: Sent as the ``type`` query parameter.
        token: Credential sent as ``Authorization: Bearer <token>``.

    Returns:
        The raw response body as bytes.

    Raises:
        urllib.error.HTTPError: re-raised after its status and body are printed.
    """
    query = urllib.parse.urlencode({'filename': filename, 'subfolder': subfolder, 'type': image_type})
    req = urllib.request.Request(f"{server_address}/view?{query}")
    req.add_header("Authorization", f"Bearer {token}")
    try:
        # Use the response as a context manager so the connection is closed
        # once the body has been read (previously it was never closed).
        with urllib.request.urlopen(req) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        print(f"HTTP Error: {e.code} - {e.reason}")
        print(e.read())
        raise
def get_history(prompt_id, token):
    """Fetch the backend's execution history for *prompt_id*.

    Args:
        prompt_id: Identifier returned by ``queue_prompt``.
        token: Credential sent as ``Authorization: Bearer <token>``.

    Returns:
        The decoded JSON reply; callers index it by ``prompt_id``.
    """
    history_url = f"{server_address}/history/{prompt_id}"
    request_headers = dict(Authorization=f'Bearer {token}')
    request_headers['Content-Type'] = 'application/json'
    return make_request(history_url, headers=request_headers)
def get_images(ws, prompt, token):
    """Run *prompt* on the backend and collect the images it produces.

    The workflow is queued over HTTP. Websocket messages are then consumed
    until an ``executing`` event arrives for this prompt with ``node`` set to
    None, which is treated as "execution finished". After that, the prompt's
    history is fetched and every image listed under each output node is
    downloaded.

    Args:
        ws: An already-connected websocket client exposing ``recv()``.
        prompt: Workflow graph to execute.
        token: Backend credential forwarded to every HTTP call.

    Returns:
        Dict mapping each output node id to a list of raw image byte strings
        (an empty list for nodes that list no images).
    """
    prompt_id = queue_prompt(prompt, token)['prompt_id']

    # Block until the completion marker for our prompt arrives. Non-text
    # frames and other message types are skipped.
    done = False
    while not done:
        frame = ws.recv()
        if not isinstance(frame, str):
            continue
        event = json.loads(frame)
        if event['type'] != 'executing':
            continue
        payload = event['data']
        done = payload['node'] is None and payload['prompt_id'] == prompt_id

    outputs = get_history(prompt_id, token)[prompt_id]['outputs']
    collected = {}
    for node_id, node_output in outputs.items():
        collected[node_id] = [
            get_image(img['filename'], img['subfolder'], img['type'], token)
            for img in node_output.get('images', [])
        ]
    return collected
# View for the landing ("home") page.
def home():
    """Render the ``home.html`` template.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered with ``app``.
    """
    page = 'home.html'
    return render_template(page)
def _extract_token(auth_header):
    """Return the backend credential carried by an ``Authorization`` header value.

    Strips an optional ``Bearer `` prefix, then base64-decodes the remainder as UTF-8.

    Raises:
        ValueError: if the value is not decodable (``binascii.Error`` and
            ``UnicodeDecodeError`` are both ``ValueError`` subclasses).
    """
    encoded = auth_header[len("Bearer "):] if auth_header.startswith("Bearer ") else auth_header
    return base64.b64decode(encoded.strip()).decode("utf-8")


def _load_workflow(text_prompt):
    """Load ``workflow_api.json`` (next to this file) and fill in per-request inputs.

    Node ids "6", "7" and "3" are hard-wired to that workflow file. "6" gets
    the user's text, "7" a fixed text that reads like a negative prompt, and
    "3" a fresh random seed. Confirm those roles against workflow_api.json.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(current_dir, 'workflow_api.json')
    with open(file_path, 'r', encoding='utf-8') as file:
        prompt = json.load(file)
    prompt["6"]["inputs"]["text"] = text_prompt
    prompt["7"]["inputs"]["text"] = "text, watermark, low quality, extra hands, extra legs."
    prompt["3"]["inputs"]["seed"] = random.randint(1, 9999999999999)
    return prompt


def _images_to_png_base64(images):
    """Re-encode every image in *images* as PNG and return base64 strings.

    *images* maps node id -> list of raw image bytes (as returned by
    ``get_images``). Output order follows the mapping's iteration order.
    """
    encoded = []
    for image_list in images.values():
        for image_data in image_list:
            buffered = io.BytesIO()
            Image.open(io.BytesIO(image_data)).save(buffered, format="PNG")
            encoded.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
    return encoded


# Generate image route
def generate_image():
    """Generate images for the request's ``text_prompt`` and return them as base64 PNGs.

    NOTE(review): no ``@app.route`` decorator is visible for this view. Confirm
    how it is registered; it reads a JSON body, so presumably it is a POST.

    Request:
        Header ``Authorization``: ``Bearer <base64 token>`` (prefix optional).
            The decoded token is forwarded to the backend.
        JSON body: ``{"text_prompt": "..."}``.

    Returns:
        200 with ``{"images": [<base64 PNG>, ...]}`` on success, or 400 with
        ``{"error": ...}`` when the token or prompt is missing or invalid.
    """
    # The credential must come from the caller. A hard-coded bearer token
    # previously lived here. It leaked a secret in source and made the
    # "no token" check unreachable.
    auth_header = request.headers.get('Authorization')
    if not auth_header:
        return jsonify({'error': 'No token provided'}), 400
    try:
        token = _extract_token(auth_header)
    except ValueError:
        return jsonify({'error': 'Invalid token'}), 400

    # silent=True returns None for a missing or non-JSON body instead of
    # raising. Such a body gets the same 400 as one lacking "text_prompt".
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text_prompt' not in data:
        return jsonify({'error': 'No text prompt provided'}), 400

    prompt = _load_workflow(data['text_prompt'])

    ws = websocket.WebSocket()
    # urlencode escapes characters such as '+', '&' or '#' that the decoded
    # token may contain and that would otherwise corrupt the query string.
    ws_query = urllib.parse.urlencode({'clientId': client_id, 'token': token})
    ws.connect(f"{ws_address}?{ws_query}")
    try:
        images = get_images(ws, prompt, token)
    finally:
        # Close the socket even if generation fails part-way.
        ws.close()

    return jsonify({'images': _images_to_png_base64(images)})
def get_image_file(filename):
    """Send the file at *filename* to the client with MIME type ``image/png``.

    NOTE(review): no ``@app.route`` decorator is visible for this view. If it
    is ever routed with *filename* taken from the URL, restrict it to a known
    directory (e.g. ``flask.send_from_directory``). ``send_file`` trusts
    whatever path it is given.
    """
    # send_file is missing from the module's flask import line; without this
    # local import the call below raised NameError.
    from flask import send_file
    return send_file(filename, mimetype='image/png')
if __name__ == '__main__':
    # Serve on all network interfaces, port 7860. debug is deliberately left
    # off, presumably so the interactive debugger is never exposed.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port)