"""Flask front-end that turns a text prompt into images via a remote backend.

The backend's HTTP and websocket addresses come from the ``SERVER_ADDRESS``
and ``WS_ADDRESS`` environment variables (optionally loaded from ``.env``).
NOTE(review): the ``/prompt``, ``/history``, ``/view`` endpoints and the
``executing`` websocket messages appear to follow the ComfyUI API - confirm.
"""
import base64
import io
import json
import os
import random
import urllib.error
import urllib.parse
import urllib.request
import uuid

import websocket
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request, send_from_directory
from PIL import Image

# Load environment variables from the .env file
load_dotenv()

# Initialize Flask app
app = Flask(__name__)

# Backend addresses; both are None when the variables are unset.
server_address = os.getenv("SERVER_ADDRESS")
ws_address = os.getenv("WS_ADDRESS")

# Directory /get_image/<filename> may serve from (defaults to the app root).
image_dir = os.getenv("IMAGE_DIR") or app.root_path

# Unique id identifying this process to the backend (prompt + websocket).
client_id = str(uuid.uuid4())

# Workflow template used for every /generate_image request.
WORKFLOW_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "workflows",
    "flux1_dev_checkpoint_workflow_api.json",
)
# Node ids inside the workflow JSON that are patched per request.
PROMPT_NODE_ID = "6"
SEED_NODE_ID = "31"
# Bounds of a 15-digit seed (model Flux1.dev).
SEED_MIN = 100000000000000
SEED_MAX = 999999999999999


def make_request(url, data=None, headers=None):
    """Send an HTTP request and return the JSON-decoded response body.

    Args:
        url: Absolute URL to request.
        data: Optional encoded request body; when given urllib issues a POST,
            otherwise a GET.
        headers: Optional mapping of request headers.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError on transport or HTTP
        failures; json.JSONDecodeError if the body is not JSON.
    """
    # urllib.request.Request iterates headers.items(), so None must become {}.
    req = urllib.request.Request(url, data=data, headers=headers or {})
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())


def _auth_headers(token):
    """Return the JSON + bearer-token headers sent to the backend API."""
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }


def queue_prompt(prompt, token):
    """Submit a workflow to the backend's ``/prompt`` endpoint.

    Returns the decoded JSON response; callers read its ``prompt_id`` key.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode("utf-8")
    return make_request(
        f"{server_address}/prompt", data=data, headers=_auth_headers(token)
    )


def get_image(filename, subfolder, image_type, token):
    """Download the raw bytes of one output image from ``/view``.

    Raises:
        urllib.error.HTTPError: if the backend rejects the request; the status
            and response body are logged before re-raising.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": image_type}
    )
    req = urllib.request.Request(f"{server_address}/view?{query}")
    req.add_header("Authorization", f"Bearer {token}")
    try:
        # ``with`` guarantees the HTTP response is closed after reading.
        with urllib.request.urlopen(req) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        app.logger.error("HTTP Error: %s - %s", e.code, e.reason)
        app.logger.error("%s", e.read())
        raise


def get_history(prompt_id, token):
    """Fetch the history record for *prompt_id* from ``/history``."""
    return make_request(
        f"{server_address}/history/{prompt_id}", headers=_auth_headers(token)
    )


def get_images(ws, prompt, token):
    """Queue *prompt*, wait for it to finish, and download its output images.

    Args:
        ws: A connected ``websocket.WebSocket`` registered under ``client_id``.
        prompt: Workflow dict to execute.
        token: Backend bearer token.

    Returns:
        Dict mapping each output node id to a list of raw image bytes
        (an empty list for nodes that produced no images).

    Raises:
        websocket.WebSocketConnectionClosedException: if the socket closes
            before the prompt finishes.
    """
    prompt_id = queue_prompt(prompt, token)["prompt_id"]

    while True:
        out = ws.recv()
        if isinstance(out, bytes):
            continue  # only text frames carry status messages
        if not out:
            # websocket-client's recv() returns '' when the peer closes.
            raise websocket.WebSocketConnectionClosedException(
                "websocket closed before the prompt finished"
            )
        message = json.loads(out)
        if message.get("type") != "executing":
            continue
        data = message["data"]
        # node == None for our prompt id means execution is done.
        if data["node"] is None and data.get("prompt_id") == prompt_id:
            break

    history = get_history(prompt_id, token)[prompt_id]
    output_images = {}
    for node_id, node_output in history["outputs"].items():
        output_images[node_id] = [
            get_image(image["filename"], image["subfolder"], image["type"], token)
            for image in node_output.get("images", [])
        ]
    return output_images


# Default route for home welcome
@app.route("/")
def home():
    """Render the landing page."""
    return render_template("home.html")


def _decode_token(auth_header):
    """Return the backend token carried in an ``Authorization`` header value.

    The header holds ``Bearer <base64>`` (the ``Bearer `` prefix is optional);
    the base64 payload is decoded to UTF-8 text.

    Raises:
        ValueError: if the credential is empty, not valid base64, or not UTF-8
            (``binascii.Error`` and ``UnicodeDecodeError`` both subclass it).
    """
    token = auth_header
    if token.startswith("Bearer "):
        token = token[len("Bearer "):]
    token = token.strip()
    if not token:
        raise ValueError("empty bearer token")
    return base64.b64decode(token).decode("utf-8")


def _load_workflow(text_prompt):
    """Load the workflow template, insert *text_prompt* and a random seed."""
    with open(WORKFLOW_PATH, "r", encoding="utf-8") as file:
        prompt = json.load(file)
    prompt[PROMPT_NODE_ID]["inputs"]["text"] = text_prompt
    # A fresh random 15-digit seed per request.
    prompt[SEED_NODE_ID]["inputs"]["seed"] = random.randint(SEED_MIN, SEED_MAX)
    return prompt


def _to_png_base64(images):
    """Re-encode every image as PNG; return base64 strings in output order."""
    encoded = []
    for image_list in images.values():
        for image_data in image_list:
            with Image.open(io.BytesIO(image_data)) as image:
                buffered = io.BytesIO()
                image.save(buffered, format="PNG")
            encoded.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
    return encoded


# Generate image route
@app.route("/generate_image", methods=["POST"])
def generate_image():
    """Generate images for the ``text_prompt`` field of the JSON body.

    Expects ``Authorization: Bearer <base64 token>``; the decoded token is
    forwarded to the backend. Responds ``{"images": [<base64 PNG>, ...]}``.

    Error responses are JSON ``{"error": ...}``: 400 for a missing or
    undecodable token, a non-JSON body, or a missing / non-string
    ``text_prompt``; 502 when the backend connection or request fails.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return jsonify({"error": "No token provided"}), 400
    try:
        token = _decode_token(auth_header)
    except ValueError:
        return jsonify({"error": "Invalid token"}), 400

    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "text_prompt" not in data:
        return jsonify({"error": "No text prompt provided"}), 400
    text_prompt = data["text_prompt"]
    if not isinstance(text_prompt, str):
        return jsonify({"error": "text_prompt must be a string"}), 400

    prompt = _load_workflow(text_prompt)

    # urlencode escapes token characters such as '$' or '/'.
    query = urllib.parse.urlencode({"clientId": client_id, "token": token})
    ws = websocket.WebSocket()
    try:
        ws.connect(f"{ws_address}?{query}")
        images = get_images(ws, prompt, token)
    except (OSError, websocket.WebSocketException) as e:
        # OSError covers urllib.error.URLError/HTTPError and socket failures.
        app.logger.error("Image generation failed: %s", e)
        return jsonify({"error": "Image backend request failed"}), 502
    finally:
        ws.close()  # safe even if connect() never succeeded

    return jsonify({"images": _to_png_base64(images)})


@app.route("/get_image/<path:filename>", methods=["GET"])
def get_image_file(filename):
    """Serve a PNG file located under ``image_dir``.

    ``send_from_directory`` rejects paths escaping ``image_dir`` (404), and
    only ``.png`` names are accepted so other files (e.g. ``.env``) cannot be
    fetched through this route.
    """
    if not filename.lower().endswith(".png"):
        return jsonify({"error": "Only PNG images can be served"}), 404
    return send_from_directory(image_dir, filename, mimetype="image/png")


if __name__ == "__main__":
    # Debug mode stays off: the Werkzeug debugger permits code execution.
    app.run(host="0.0.0.0", port=7860)