Spaces:
Sleeping
Sleeping
import os | |
import io | |
import json | |
import base64 | |
import random | |
import urllib.request | |
import urllib.parse | |
import websocket | |
import uuid | |
from dotenv import load_dotenv | |
from flask import Flask, request, jsonify, render_template, send_file | |
from PIL import Image | |
from werkzeug.utils import secure_filename | |
# Pull variables from a local .env file into the process environment before
# any configuration below is read.
load_dotenv()

# Flask application instance for this service.
app = Flask(__name__)

# Lower-case file extensions accepted for uploaded images.
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png', 'webp'}

# Base URLs of the backend HTTP API and its websocket endpoint; each is None
# when the corresponding environment variable is not set.
server_address, ws_address = (os.getenv(var) for var in ("SERVER_ADDRESS", "WS_ADDRESS"))

# One random client id, created once at import time and shared by all requests.
client_id = str(uuid.uuid4())
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in one of the permitted image extensions.

    Only the text after the last dot is compared, case-insensitively, so
    ``"a.tar.PNG"`` is accepted while ``"png"`` (no dot) and ``"photo."``
    (empty extension) are rejected.

    Args:
        filename: Client-supplied file name. ``None`` or ``""`` is rejected
            instead of raising.
        allowed_extensions: Optional collection of lower-case extensions to
            accept. Defaults to the module-level ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the extension is allowed.
    """
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
def save_base64_image(b64_string, directory="/tmp"):
    """Decode a base64-encoded image and write it to a uniquely named file.

    Accepts either a data URI (``data:image/jpeg;base64,<payload>``) or a bare
    base64 payload with no header. The file extension is taken from the data
    URI's MIME subtype and falls back to ``png`` when there is no header or the
    subtype is empty.

    Args:
        b64_string: Data URI or bare base64 string.
        directory: Directory to write the file into (default ``/tmp``).

    Returns:
        str: path of the written file (``<directory>/<uuid4>.<ext>``).

    Raises:
        binascii.Error: if the payload is not valid base64 (e.g. bad padding).
    """
    if ',' in b64_string:
        # Data URI: everything before the first comma is the header.
        header, encoded = b64_string.split(',', 1)
    else:
        # Bare base64 payload with no data URI header.
        header, encoded = '', b64_string
    image_data = base64.b64decode(encoded)

    ext = header.split('/')[1].split(';')[0] if '/' in header else ''
    if not ext:
        ext = 'png'

    image_path = os.path.join(directory, f"{uuid.uuid4()}.{ext}")
    with open(image_path, 'wb') as f:
        f.write(image_data)
    return image_path
def make_request(url, data=None, headers=None):
    """Send an HTTP request and return the decoded JSON response body.

    Args:
        url: Absolute URL to request.
        data: Optional request body as bytes.
        headers: Optional mapping of header names to values.

    Returns:
        The parsed JSON body, or ``None`` if the request failed with an HTTP
        or URL error. The error details are printed to stdout.

    Raises:
        json.JSONDecodeError: if a successful response body is not valid JSON.
    """
    # urllib.request.Request iterates headers.items(), so None must become {}.
    req = urllib.request.Request(url, data=data, headers=headers or {})
    try:
        with urllib.request.urlopen(req) as response:
            response_body = response.read().decode()
            return json.loads(response_body)
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it must be handled first.
        print(f"HTTPError: {e.code}, {e.reason}")
        # The error body may not be UTF-8; diagnostics must never raise here.
        print(e.read().decode(errors="replace"))
    except urllib.error.URLError as e:
        print(f"URLError: {e.reason}")
    return None
def queue_prompt(prompt, token):
    """Submit a workflow to the backend ``/prompt`` endpoint.

    The JSON body carries the workflow plus this process's ``client_id``, and
    the request is authorised with ``Bearer <token>``.

    Returns:
        The parsed JSON reply from make_request, or None if that request failed.
    """
    request_headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json',
    }
    body = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    endpoint = f"{server_address}/prompt"
    return make_request(endpoint, data=body, headers=request_headers)
def get_image(filename, subfolder, image_type, token):
    """Download one output image's raw bytes from the backend ``/view`` endpoint.

    Args:
        filename, subfolder, image_type: Identify the image; sent URL-encoded
            as the ``filename``, ``subfolder`` and ``type`` query parameters.
        token: Bearer token for the ``Authorization`` header.

    Returns:
        bytes: the response body.

    Raises:
        urllib.error.HTTPError: re-raised after the status and body are printed.
        urllib.error.URLError: on connection failures (not caught here).
    """
    query = urllib.parse.urlencode({'filename': filename, 'subfolder': subfolder, 'type': image_type})
    req = urllib.request.Request(f"{server_address}/view?{query}")
    req.add_header("Authorization", f"Bearer {token}")
    try:
        # The context manager closes the HTTP response instead of leaking the socket.
        with urllib.request.urlopen(req) as response:
            return response.read()
    except urllib.error.HTTPError as e:
        print(f"HTTP Error: {e.code} - {e.reason}")
        print(e.read())
        raise
def get_history(prompt_id, token):
    """Fetch the execution history for *prompt_id* from ``/history/<prompt_id>``.

    Returns:
        The parsed JSON reply from make_request, or None if that request failed.
    """
    history_url = f"{server_address}/history/{prompt_id}"
    return make_request(
        history_url,
        headers={
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/json',
        },
    )
def get_images(ws, prompt, token):
    """Queue *prompt*, wait for it to finish on *ws*, then download its images.

    Args:
        ws: An already-connected websocket that receives status messages for
            this process's client id.
        prompt: Workflow dict to execute.
        token: Bearer token for the backend HTTP API.

    Returns:
        dict mapping each output node id in the prompt's history to a list of
        raw image bytes. Nodes without images map to an empty list.

    Raises:
        RuntimeError: if the prompt could not be queued or no history was
            returned for it.
    """
    queued = queue_prompt(prompt, token)
    if not queued or 'prompt_id' not in queued:
        # make_request returns None on HTTP/URL errors; fail with a clear message
        # instead of an opaque TypeError/KeyError.
        raise RuntimeError(f"Failed to queue prompt: {queued!r}")
    prompt_id = queued['prompt_id']

    # Block until an 'executing' message with node=None arrives for our prompt,
    # which this code treats as "execution finished". Non-text frames are skipped.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message['data']
        if data['node'] is None and data['prompt_id'] == prompt_id:
            break

    history = get_history(prompt_id, token)
    if not history or prompt_id not in history:
        raise RuntimeError(f"No history returned for prompt {prompt_id}")

    output_images = {}
    for node_id, node_output in history[prompt_id]['outputs'].items():
        output_images[node_id] = [
            get_image(image['filename'], image['subfolder'], image['type'], token)
            for image in node_output.get('images', [])
        ]
    return output_images
def fetch_video(video_data, token):
    """Download a generated video's bytes from the backend ``/download`` endpoint.

    Args:
        video_data: Mapping with at least a ``'filename'`` key.
        token: Bearer token for the ``Authorization`` header.

    Returns:
        bytes: the response body.

    Raises:
        urllib.error.HTTPError / urllib.error.URLError: on request failure.
    """
    # URL-encode the file name (as get_image does) so spaces, '&' or '#'
    # cannot corrupt the query string.
    query = urllib.parse.urlencode({'file': video_data['filename']})
    req = urllib.request.Request(f"{server_address}/download?{query}")
    req.add_header("Authorization", f"Bearer {token}")
    # The context manager closes the HTTP response after reading it.
    with urllib.request.urlopen(req) as response:
        return response.read()
# Landing ("home welcome") page view
def home():
    """Render and return the ``home.html`` template."""
    # NOTE(review): no @app.route decorator is visible on this view; Flask will
    # not dispatch to it unless it is registered elsewhere. Confirm.
    template_name = 'home.html'
    return render_template(template_name)
# Generate image route
def generate_image():
    """Generate images from a text prompt using the Flux1-dev workflow file.

    Expects a JSON body with ``text_prompt`` and an ``Authorization`` header
    carrying a base64-encoded token, optionally prefixed with ``Bearer ``. The
    decoded token is used for both the websocket and the HTTP calls.

    Returns:
        A JSON response ``{"images": [<base64 PNG>, ...]}`` on success, or
        ``({"error": ...}, status)``: 400 for a missing or undecodable token or
        a missing prompt, 500 if the websocket connection fails.
    """
    # NOTE(review): no @app.route decorator is visible on this view. Confirm how it is registered.
    # silent=True: a missing or non-JSON body yields a 400 below instead of an exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    token = request.headers.get('Authorization')
    if token is None:
        return jsonify({'error': 'No token provided'}), 400
    if token.startswith("Bearer "):
        token = token.split(" ")[1]
    try:
        token = base64.b64decode(token).decode("utf-8")
    except ValueError:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        return jsonify({'error': 'Invalid token encoding'}), 400

    if 'text_prompt' not in data:
        return jsonify({'error': 'No text prompt provided'}), 400
    text_prompt = data['text_prompt']

    # The workflow template lives next to this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(current_dir, 'workflows/flux1_dev_checkpoint_workflow_api.json')
    with open(file_path, 'r', encoding='utf-8') as file:
        prompt = json.load(file)

    # Node "6" receives the user's prompt text; node "31" receives a random
    # 15-digit seed so repeated prompts give different images.
    prompt["6"]["inputs"]["text"] = text_prompt
    prompt["31"]["inputs"]["seed"] = random.randint(100000000000000, 999999999999999)

    ws = websocket.WebSocket()
    try:
        ws.connect(f"{ws_address}?clientId={client_id}&token={token}",
                   header={"Authorization": f"Bearer {token}"})
    except (websocket.WebSocketException, OSError) as e:
        # OSError covers socket-level failures such as connection refused or
        # DNS errors, which are not WebSocketException subclasses.
        return jsonify({'error': f'WebSocket connection failed: {str(e)}'}), 500
    try:
        images = get_images(ws, prompt, token)
    finally:
        # Release the socket even when queuing, polling or downloading raises.
        ws.close()

    # Re-encode every returned image as PNG and base64 it for the JSON reply.
    output_images_base64 = []
    for node_images in images.values():
        for image_data in node_images:
            buffered = io.BytesIO()
            Image.open(io.BytesIO(image_data)).save(buffered, format="PNG")
            output_images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
    return jsonify({'images': output_images_base64})
# Serve a file from disk as a PNG image
def get_image_file(filename):
    """Send the file at *filename* to the client with MIME type ``image/png``.

    SECURITY NOTE(review): *filename* goes straight to send_file with no
    base-directory restriction. If this view is exposed with a URL path
    parameter, a client could read any file the process can access (for
    example via ``../`` segments). Consider send_from_directory with a fixed
    directory. Confirm the intended image location first.
    """
    png_mime = 'image/png'
    return send_file(filename, mimetype=png_mime)
# Generate image to video route
def image_to_video():
    """Turn an input image plus a text prompt into a video via the CogVideoX workflow file.

    Accepts either:
      * a JSON body with ``text_prompt`` and ``base64_image`` (data URI or bare
        base64), or
      * a multipart/form-data request with a ``text_prompt`` field and an
        ``image`` file upload.

    The ``Authorization`` header's token (``Bearer `` prefix optional) is
    forwarded as-is. Unlike generate_image, it is not base64-decoded.

    Returns:
        The generated MP4 as a file attachment on success, or
        ``({"error": ...}, status)``: 400 for missing or invalid input, 500 if
        the websocket connection fails.
    """
    # NOTE(review): no @app.route decorator is visible on this view. Confirm how it is registered.
    # request.json cannot be used here: multipart uploads are not JSON, so the
    # file-upload branch would be unreachable. Fall back to form fields instead.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = request.form

    token = request.headers.get('Authorization')
    if token is None:
        return jsonify({'error': 'No token provided'}), 400
    if token.startswith("Bearer "):
        token = token.split(" ")[1]
    # NOTE(review): generate_image base64-decodes the token but this route does
    # not. Confirm which form the backend expects here.

    text_prompt = data.get('text_prompt')
    if not text_prompt:
        return jsonify({'error': 'Text prompt is required'}), 400

    # Input image: an uploaded file takes precedence over a base64 string.
    image_file = request.files.get('image')
    base64_image = data.get('base64_image')
    if image_file:
        if not allowed_file(image_file.filename):
            return jsonify({'error': 'Unsupported image format'}), 400
        filename = secure_filename(image_file.filename)
        image_path = f"/tmp/{uuid.uuid4()}_{filename}"
        image_file.save(image_path)
    elif base64_image:
        try:
            image_path = save_base64_image(base64_image)
        except Exception as e:
            return jsonify({'error': f'Invalid base64 image data: {str(e)}'}), 400
    else:
        return jsonify({'error': 'Image is required (either file or base64)'}), 400
    # NOTE(review): the saved input image is never deleted, and only its local
    # path is sent in the workflow. That works only if the backend can read this
    # host's filesystem. Confirm.

    # Load the workflow template that lives next to this file and fill in its inputs.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(current_dir, 'workflows/cogvideox_image_to_video_workflow_api.json')
    print(f"Modified workflow: {file_path}", flush=True)
    with open(file_path, 'r', encoding='utf-8') as file:
        workflow = json.load(file)
    workflow["30"]["inputs"]["prompt"] = text_prompt  # Text prompt
    workflow["36"]["inputs"]["upload"] = image_path  # Image path
    workflow["31"]["inputs"]["prompt"] = "Low quality, watermark, strange motion"  # Negative prompt
    # random.randint needs int bounds: float bounds such as 1e14 raise
    # TypeError on Python 3.12+.
    workflow["57"]["inputs"]["seed"] = random.randint(10**14, 9 * 10**14)

    ws = websocket.WebSocket()
    try:
        ws.connect(f"{ws_address}?clientId={client_id}&token={token}",
                   header={"Authorization": f"Bearer {token}"})
    except (websocket.WebSocketException, OSError) as e:
        return jsonify({'error': f'WebSocket connection failed: {str(e)}'}), 500
    try:
        # NOTE(review): generate_image queues work via HTTP POST /prompt, but this
        # route sends the workflow over the websocket and waits for a
        # 'completed' message. Confirm the backend supports this protocol.
        ws.send(json.dumps({"workflow": workflow}))
        while True:
            out = ws.recv()
            if not isinstance(out, str):
                continue  # skip binary frames; they are not JSON status messages
            message = json.loads(out)
            if message.get('type') == 'completed':
                video_data = message['data']
                break
    finally:
        ws.close()

    video_content = fetch_video(video_data, token)
    return send_file(
        io.BytesIO(video_content),
        mimetype='video/mp4',
        as_attachment=True,
        download_name='generated_video.mp4'
    )
if __name__ == '__main__':
    # Development server: listen on all interfaces, port 7860. Debug mode is
    # intentionally not enabled.
    run_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**run_options)