Change-cloth-AI / app.py
BEfunnuga's picture
Update app.py
6a03b21 verified
raw
history blame
7.66 kB
import gradio as gr
from gradio_client import Client, handle_file
import re
import time
import os
from dotenv import load_dotenv
# Pull any variables defined in a local .env file into the process environment
# before reading configuration from it.
load_dotenv()

# Hugging Face access token used to authenticate against the Hub; this is None
# when HUGGING_FACE_HUB_TOKEN is not set.
hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Single client for the remote OOTDiffusion Space, created once at import time
# and shared by every generation request below.
client = Client("levihsu/OOTDiffusion", hf_token=hf_token)
# Matches the wait hint in quota errors, e.g. "retry in 0:01:30" or "retry in 45s".
_RETRY_IN_RE = re.compile(r'retry in (\d+:\d+:\d+|\d+\.?\d*\s*s)')


def _extract_image_path(result):
    """Return the output image path from a ``/process_hd`` result, or None.

    Handles the shapes this app accepts:
    * a non-empty list/tuple whose first item is an existing file path, or a
      dict carrying an ``"image"`` key (only the first item is used);
    * a dict carrying an ``"image"`` key;
    * an existing file path string.
    """
    if isinstance(result, (list, tuple)) and result:
        first = result[0]
        if isinstance(first, str) and os.path.exists(first):
            return first
        if isinstance(first, dict) and 'image' in first:
            return first['image']
        return None
    if isinstance(result, dict) and 'image' in result:
        return result['image']
    if isinstance(result, str) and os.path.exists(result):
        return result
    return None


def _parse_retry_wait(error_msg, default_seconds=300):
    """Extract the server-suggested wait from a quota/queue error message.

    Returns ``(label, seconds)``: ``label`` is the human-readable wait shown to
    the user and ``seconds`` is how long to sleep. Falls back to
    ``default_seconds`` when no hint is present or the hint cannot be parsed.
    """
    match = _RETRY_IN_RE.search(error_msg)
    if not match:
        return "an unknown period (try again in 5-10 mins)", default_seconds
    label = match.group(1)
    try:
        if 's' in label:
            return label, int(float(label.replace('s', '').strip()))
        # HH:MM:SS -> seconds (rightmost field is seconds, then minutes, hours).
        parts = [int(part) for part in label.split(':')]
        return label, sum(value * 60 ** power for power, value in enumerate(reversed(parts)))
    except ValueError:
        return label, default_seconds


def generate_outfit(model_image, garment_image, n_samples=1, n_steps=20, image_scale=2, seed=-1,
                    *, max_retries=3, max_wait_seconds=600):
    """Run the remote OOTDiffusion ``/process_hd`` endpoint on a person + garment image.

    Args:
        model_image: Filepath of the person image, or None if not uploaded.
        garment_image: Filepath of the garment image, or None if not uploaded.
        n_samples: Number of samples requested from the remote endpoint.
        n_steps: Number of steps requested from the remote endpoint.
        image_scale: Scale value forwarded unchanged to the remote endpoint.
        seed: Seed forwarded to the endpoint; None (a cleared number box) is
            sent as -1.
        max_retries: Total attempts made while the Space reports a GPU-quota
            or full-queue error.
        max_wait_seconds: Longest server-suggested wait this function will
            sleep through before retrying. Longer waits are returned to the
            user immediately so a request handler is not blocked for hours.

    Returns:
        ``(image_path, None)`` on success, or ``(None, error_message)`` on
        failure — one value per UI output (image, error text).
    """
    if model_image is None or garment_image is None:
        # Return None for the image output and the error message for the text output
        return None, "Please upload both model and garment images"

    # UI widgets may deliver floats for integer settings, and a cleared
    # gr.Number yields None; normalize before calling the remote API.
    n_samples = int(n_samples)
    n_steps = int(n_steps)
    seed = -1 if seed is None else int(seed)

    for attempt in range(1, max_retries + 1):
        try:
            result = client.predict(
                vton_img=handle_file(model_image),
                garm_img=handle_file(garment_image),
                n_samples=n_samples,
                n_steps=n_steps,
                image_scale=image_scale,
                seed=seed,
                api_name="/process_hd",
            )
        except Exception as e:
            error_msg = str(e)
            is_capacity_error = ("exceeded your GPU quota" in error_msg
                                 or "queue is full" in error_msg.lower())
            if not is_capacity_error:
                # Log the full error server-side and surface it to the user.
                print(f"An unexpected error occurred: {error_msg}")
                return None, f"An error occurred processing the request. Details: {error_msg}"

            wait_label, wait_seconds = _parse_retry_wait(error_msg)
            quota_message = f"GPU quota exceeded or queue full. Please wait ~{wait_label} before trying again."
            if attempt >= max_retries:
                return None, quota_message
            if wait_seconds > max_wait_seconds:
                # Sleeping this long would tie up the request; let the user retry later.
                print(f"GPU/Queue issue detected. Suggested wait of {wait_seconds} seconds "
                      f"exceeds the {max_wait_seconds}-second limit; not retrying.")
                return None, quota_message
            print(f"GPU/Queue issue detected. Waiting {wait_seconds} seconds before retry {attempt + 1}/{max_retries}...")
            time.sleep(wait_seconds + 2)  # Add a small buffer
            continue

        output_image_path = _extract_image_path(result)
        if output_image_path:
            return output_image_path, None
        # Log the unexpected result format for debugging.
        print(f"Warning: Unexpected result format from API: {result}")
        return None, f"API returned unexpected format: {type(result)}"

    # Only reached when max_retries < 1 (no attempt was made).
    return None, "Failed to generate outfit after multiple retries due to server issues."
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
## Outfit Diffusion - Try On Virtual Outfits
⚠️ **Note**: This demo uses a free, shared GPU resource which has limits. Errors can occur due to high demand or temporary issues on the service.
- **Try lower settings:** Use Steps (e.g., 10-20) and Scale (e.g., 1-3) for faster processing and lower resource use.
- **Be patient:** If you get quota/queue errors, wait the suggested time before retrying.
- **Check the Space:** You can visit the [OOTDiffusion Space](https://huggingface.co/spaces/levihsu/OOTDiffusion) directly to check its status.
""")
    with gr.Row():
        # Left column: the two input images, each with optional examples.
        with gr.Column(scale=1):
            model_image = gr.Image(
                label="Model Image (Person)",
                type="filepath",  # generate_outfit passes this path to handle_file
                height=400,
            )
            # Example person images (file paths or URLs). The previous `[...]`
            # placeholder evaluated to [Ellipsis], which gr.Examples would try
            # to process as an image at startup. Add real entries here; the
            # Examples widget is only rendered when the list is non-empty.
            model_examples = []
            if model_examples:
                gr.Examples(examples=model_examples, inputs=model_image, label="Model Examples")

            garment_image = gr.Image(
                label="Garment Image (Clothing)",
                type="filepath",
                height=400,
            )
            # Example garment images; same placeholder note as above.
            garment_examples = []
            if garment_examples:
                gr.Examples(examples=garment_examples, inputs=garment_image, label="Garment Examples")

        # Right column: generated image plus a Markdown area for error messages.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Output", height=400)
            error_text = gr.Markdown(value="")  # starts empty; filled on failure

    # Generation controls, passed positionally to generate_outfit.
    with gr.Row():
        n_samples = gr.Slider(
            label="Number of Samples", minimum=1, maximum=4, step=1, value=1  # Max 4 recommended by space
        )
        n_steps = gr.Slider(
            label="Steps", minimum=10, maximum=40, step=1, value=20
        )
        image_scale = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.5, value=2.0
        )
        seed = gr.Number(
            label="Seed (-1 for random)", value=-1, precision=0  # precision=0 -> integer seed
        )
    generate_button = gr.Button("Generate Outfit", variant="primary")

    # generate_outfit returns (image_path, error_message), mapped onto the two outputs.
    generate_button.click(
        fn=generate_outfit,
        inputs=[model_image, garment_image, n_samples, n_steps, image_scale, seed],
        outputs=[output_image, error_text],
    )

# show_error=True surfaces handler exceptions in the browser; debug=True blocks
# the main thread so errors are also printed to the local console.
demo.launch(show_error=True, debug=True)