import base64
import binascii
import io
import json
import os
import random
import tempfile
import urllib.error
import urllib.parse
import urllib.request
import uuid

import websocket
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request, send_file, send_from_directory
from PIL import Image
from werkzeug.utils import secure_filename

# Load environment variables from the .env file (must run before os.getenv below).
load_dotenv()

# Initialize Flask app
app = Flask(__name__)

# Image extensions accepted for uploads and base64 data URIs.
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png', 'webp'}

# Backend HTTP and websocket base addresses, read from the environment.
# NOTE(review): the /prompt, /history, /view endpoints look like a ComfyUI-style
# API — confirm against the deployed backend.
server_address = os.getenv("SERVER_ADDRESS")
ws_address = os.getenv("WS_ADDRESS")

# One client ID per process; sent with every queued prompt and websocket connection.
client_id = str(uuid.uuid4())

# Directory where uploaded / decoded images are written and from which
# /get_image serves files.
UPLOAD_DIR = tempfile.gettempdir()


def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def save_base64_image(b64_string):
    """Decode a base64 image and write it to UPLOAD_DIR.

    Accepts either a data URI (``data:image/<ext>;base64,<payload>``) or a bare
    base64 payload. The file extension comes from the data URI's MIME subtype
    when that subtype is in ALLOWED_EXTENSIONS; otherwise ``png`` is used.

    Returns:
        The path of the written file.

    Raises:
        ValueError: (``binascii.Error``) if the payload is not valid base64.
    """
    header, sep, encoded = b64_string.partition(',')
    if not sep:
        # No data-URI prefix: the whole string is the payload.
        header, encoded = '', b64_string
    image_data = base64.b64decode(encoded)

    ext = 'png'
    if '/' in header:
        candidate = header.split('/')[1].split(';')[0].lower()
        # Only trust known image extensions coming from client-supplied headers.
        if candidate in ALLOWED_EXTENSIONS:
            ext = candidate

    image_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.{ext}")
    with open(image_path, 'wb') as f:
        f.write(image_data)
    return image_path


def make_request(url, data=None, headers=None):
    """Send an HTTP request and return the response body parsed as JSON.

    A request with *data* is sent as POST, otherwise as GET (urllib default).

    Returns:
        The decoded JSON value, or None (after printing diagnostics) when the
        request fails with an HTTP/URL error or the body is not valid JSON.
    """
    # urllib.request.Request iterates headers.items(), so None is not accepted.
    req = urllib.request.Request(url, data=data, headers=headers or {})
    try:
        with urllib.request.urlopen(req) as response:
            response_body = response.read().decode()
    except urllib.error.HTTPError as e:
        print(f"HTTPError: {e.code}, {e.reason}")
        print(e.read().decode())  # Print detailed error response
        return None
    except urllib.error.URLError as e:
        print(f"URLError: {e.reason}")
        return None
    try:
        return json.loads(response_body)
    except json.JSONDecodeError as e:
        print(f"Invalid JSON response from {url}: {e}")
        return None


def queue_prompt(prompt, token):
    """POST *prompt* to ``/prompt`` for this client; return the parsed JSON or None."""
    payload = {"prompt": prompt, "client_id": client_id}
    data = json.dumps(payload).encode('utf-8')
    headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json',
    }
    return make_request(f"{server_address}/prompt", data=data, headers=headers)


def get_image(filename, subfolder, image_type, token):
    """Download one output image's raw bytes from ``/view``.

    Raises:
        urllib.error.HTTPError: re-raised after printing the error details.
    """
    url_values = {'filename': filename, 'subfolder': subfolder, 'type': image_type}
    url = f"{server_address}/view?{urllib.parse.urlencode(url_values)}"
    req = urllib.request.Request(url)
    req.add_header("Authorization", f"Bearer {token}")
    try:
        with urllib.request.urlopen(req) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        print(f"HTTP Error: {e.code} - {e.reason}")
        print(e.read())
        raise


def get_history(prompt_id, token):
    """GET ``/history/<prompt_id>``; return the parsed JSON or None on failure."""
    headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json',
    }
    return make_request(f"{server_address}/history/{prompt_id}", headers=headers)


def get_images(ws, prompt, token):
    """Queue *prompt*, wait on *ws* until it finishes, and download its images.

    Returns:
        A dict mapping output node id -> list of raw image bytes.

    Raises:
        RuntimeError: if the prompt could not be queued or its history fetched.
    """
    queued = queue_prompt(prompt, token)
    if not isinstance(queued, dict) or 'prompt_id' not in queued:
        raise RuntimeError(f"Failed to queue prompt: {queued!r}")
    prompt_id = queued['prompt_id']

    while True:
        out = ws.recv()
        # Binary frames are skipped; only JSON text messages carry status.
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message['data']
        # An 'executing' message with node None for our prompt marks completion.
        if data['node'] is None and data['prompt_id'] == prompt_id:
            break

    history = get_history(prompt_id, token)
    if not isinstance(history, dict) or prompt_id not in history:
        raise RuntimeError(f"No history returned for prompt {prompt_id}")

    output_images = {}
    for node_id, node_output in history[prompt_id]['outputs'].items():
        output_images[node_id] = [
            get_image(image['filename'], image['subfolder'], image['type'], token)
            for image in node_output.get('images', [])
        ]
    return output_images


def fetch_video(video_data, token):
    """Download the video named by ``video_data['filename']`` from ``/download``."""
    query = urllib.parse.urlencode({'file': video_data['filename']})
    req = urllib.request.Request(f"{server_address}/download?{query}")
    req.add_header("Authorization", f"Bearer {token}")
    with urllib.request.urlopen(req) as response:
        return response.read()


def _bearer_token():
    """Return the Authorization header value without a ``Bearer `` prefix, or None if absent."""
    token = request.headers.get('Authorization')
    if token is not None and token.startswith("Bearer "):
        token = token.split(" ")[1]
    return token


def _workflow_path(name):
    """Return the path of workflow file *name* in the ``workflows`` dir next to this module."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(current_dir, 'workflows', name)


def _load_workflow(file_path):
    """Read and parse the workflow JSON at *file_path*."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _connect_ws(token):
    """Open a websocket to ws_address for this client, authenticated with *token*.

    Raises:
        websocket.WebSocketException or OSError: if the connection fails.
    """
    # URL-encode so tokens containing '&', '=', '+' etc. do not corrupt the query.
    query = urllib.parse.urlencode({'clientId': client_id, 'token': token})
    ws = websocket.WebSocket()
    ws.connect(f"{ws_address}?{query}", header={"Authorization": f"Bearer {token}"})
    return ws


def _remove_quietly(path):
    """Best-effort delete of a temporary file; a failure here must not fail the request."""
    try:
        os.remove(path)
    except OSError:
        pass


# Default route for home welcome
@app.route('/')
def home():
    """Render the landing page."""
    return render_template('home.html')


# Generate image route
@app.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate images from ``text_prompt`` using the Flux1.dev workflow.

    Expects a JSON body ``{"text_prompt": ...}`` and an ``Authorization`` header
    carrying a base64-encoded token (optionally prefixed with ``Bearer ``).

    Returns JSON ``{"images": [<base64 PNG>, ...]}``, or ``{"error": ...}`` with
    400 (bad input), 500 (websocket connect failure) or 502 (backend failure).
    """
    data = request.get_json(silent=True)

    token = _bearer_token()
    if token is None:
        return jsonify({'error': 'No token provided'}), 400
    # The client sends the token base64-encoded; decode it before forwarding.
    try:
        token = base64.b64decode(token).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError):
        return jsonify({'error': 'Invalid token encoding'}), 400

    if not isinstance(data, dict) or 'text_prompt' not in data:
        return jsonify({'error': 'No text prompt provided'}), 400
    text_prompt = data['text_prompt']

    prompt = _load_workflow(_workflow_path('flux1_dev_checkpoint_workflow_api.json'))
    # Node ids are specific to this workflow file.
    prompt["6"]["inputs"]["text"] = text_prompt
    # For model Flux1.dev: random 15-digit integer seed on node "31".
    prompt["31"]["inputs"]["seed"] = random.randint(10**14, 10**15 - 1)

    try:
        ws = _connect_ws(token)
    except (websocket.WebSocketException, OSError) as e:
        return jsonify({'error': f'WebSocket connection failed: {str(e)}'}), 500
    try:
        images = get_images(ws, prompt, token)
    except (RuntimeError, urllib.error.URLError, websocket.WebSocketException) as e:
        return jsonify({'error': f'Image generation failed: {e}'}), 502
    finally:
        ws.close()

    output_images_base64 = []
    for node_images in images.values():
        for image_data in node_images:
            # Re-encode every output as PNG so the client receives one format.
            buffered = io.BytesIO()
            Image.open(io.BytesIO(image_data)).save(buffered, format="PNG")
            output_images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))

    return jsonify({'images': output_images_base64})


# Get image route
@app.route('/get_image/<filename>', methods=['GET'])
def get_image_file(filename):
    """Serve a previously saved image from UPLOAD_DIR with an image/png mimetype.

    send_from_directory refuses paths that escape UPLOAD_DIR (404), so clients
    cannot read arbitrary files from the server.
    """
    return send_from_directory(UPLOAD_DIR, filename, mimetype='image/png')


def _wait_for_video_completion(ws):
    """Block until *ws* delivers a ``{"type": "completed"}`` JSON message; return its ``data``."""
    while True:
        out = ws.recv()
        # Binary frames cannot be JSON status messages; skip them.
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message.get('type') == 'completed':
            return message['data']


def _generate_video(text_prompt, image_path, token):
    """Run the image-to-video workflow for *image_path* and return the Flask response."""
    file_path = _workflow_path('cogvideox_image_to_video_workflow_api.json')
    print(f"Modified workflow: {file_path}", flush=True)
    workflow = _load_workflow(file_path)
    workflow["30"]["inputs"]["prompt"] = text_prompt  # Text prompt
    workflow["36"]["inputs"]["upload"] = image_path  # Image path
    workflow["31"]["inputs"]["prompt"] = "Low quality, watermark, strange motion"  # Negative prompt
    # Integer bounds: float arguments to randint raise TypeError on Python 3.12+.
    workflow["57"]["inputs"]["seed"] = random.randint(10**14, 9 * 10**14)

    try:
        ws = _connect_ws(token)
    except (websocket.WebSocketException, OSError) as e:
        return jsonify({'error': f'WebSocket connection failed: {str(e)}'}), 500
    try:
        # NOTE(review): unlike /generate_image (POST /prompt + 'executing'
        # messages), this submits over the websocket and waits for a
        # 'completed' message — confirm the backend supports this protocol.
        ws.send(json.dumps({"workflow": workflow}))
        video_data = _wait_for_video_completion(ws)
    except websocket.WebSocketException as e:
        return jsonify({'error': f'Video generation failed: {e}'}), 502
    finally:
        ws.close()

    try:
        video_content = fetch_video(video_data, token)
    except urllib.error.URLError as e:
        return jsonify({'error': f'Failed to download video: {e}'}), 502

    return send_file(
        io.BytesIO(video_content),
        mimetype='video/mp4',
        as_attachment=True,
        download_name='generated_video.mp4',
    )


# Generate image to video route
@app.route('/image_to_video', methods=['POST'])
def image_to_video():
    """Generate an MP4 from an input image and ``text_prompt``.

    The body may be JSON (``text_prompt``, ``base64_image``) or multipart form
    data (``text_prompt`` field plus an ``image`` file). Requires an
    ``Authorization`` header; unlike /generate_image the token is used as-is
    (not base64-decoded).

    Returns the video as an attachment, or ``{"error": ...}`` JSON on failure.
    """
    # JSON bodies and multipart uploads are both supported; request.files is
    # only populated for multipart requests, which have no JSON body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = request.form

    token = _bearer_token()
    if token is None:
        return jsonify({'error': 'No token provided'}), 400

    text_prompt = data.get('text_prompt')
    if not text_prompt:
        return jsonify({'error': 'Text prompt is required'}), 400

    # Handle uploaded image or base64-encoded image
    image_file = request.files.get('image')
    base64_image = data.get('base64_image')
    if image_file:
        if not allowed_file(image_file.filename):
            return jsonify({'error': 'Unsupported image format'}), 400
        filename = secure_filename(image_file.filename)
        image_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}_{filename}")
        image_file.save(image_path)
    elif base64_image:
        try:
            if not isinstance(base64_image, str):
                raise ValueError("expected a base64 string")
            image_path = save_base64_image(base64_image)
        except ValueError as e:
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400
    else:
        return jsonify({'error': 'Image is required (either file or base64)'}), 400

    try:
        return _generate_video(text_prompt, image_path, token)
    finally:
        # The input image is only referenced by this request's workflow;
        # remove it so /tmp does not accumulate one file per request.
        _remove_quietly(image_path)


if __name__ == '__main__':
    # Debug mode intentionally disabled (binding to 0.0.0.0).
    app.run(host='0.0.0.0', port=7860)