import gradio as gr
from gradio_client import Client, handle_file
import re
import time
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Get Hugging Face token from environment variable
hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
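# Assumes a local .env file (kept out of version control) supplies the token.
# A minimal .env would contain a single line such as:
#   HUGGING_FACE_HUB_TOKEN=hf_xxxxxxxxxxxxxxxx   (placeholder value)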
# Initialize client with auth
client = Client(
    "levihsu/OOTDiffusion",
    hf_token=hf_token
)
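# If the upstream Space changes its API, client.view_api() prints the available
# endpoints and their parameters; the call below assumes "/process_hd" is still
# exposed with the argument names used here.
# client.view_api()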
def generate_outfit(model_image, garment_image, n_samples=1, n_steps=20, image_scale=2, seed=-1):
    if model_image is None or garment_image is None:
        # Return None for the image output and the error message for the text output
        return None, "Please upload both model and garment images"

    max_retries = 3  # You might want to adjust retries, but 3 is reasonable
    for attempt in range(max_retries):
        try:
            # Use the client to predict
            result = client.predict(
                vton_img=handle_file(model_image),
                garm_img=handle_file(garment_image),
                n_samples=n_samples,
                n_steps=n_steps,
                image_scale=image_scale,
                seed=seed,
                api_name="/process_hd"
            )
            # --- Improved Result Handling ---
            output_image_path = None
            if isinstance(result, (list, tuple)) and len(result) > 0:
                # Often results are tuples/lists; assume the image path is the first element
                potential_path = result[0]
                if isinstance(potential_path, str) and os.path.exists(potential_path):
                    # Looks like a valid path returned by the client
                    output_image_path = potential_path
                elif isinstance(potential_path, dict) and 'image' in potential_path:
                    # Handle dict case
                    output_image_path = potential_path['image']
            elif isinstance(result, dict) and 'image' in result:
                output_image_path = result['image']
            elif isinstance(result, str) and os.path.exists(result):
                # Handle direct string path case
                output_image_path = result

            if output_image_path:
                # Return the image path and None for the error message
                return output_image_path, None
            else:
                # Log the unexpected result format for debugging if needed
                print(f"Warning: Unexpected result format from API: {result}")
                # Return None for image, and an informative error
                return None, f"API returned unexpected format: {type(result)}"
            # --- End Improved Result Handling ---
        except Exception as e:
            error_msg = str(e)
            # Check for quota error specifically
            if "exceeded your GPU quota" in error_msg or "queue is full" in error_msg.lower():
                # Try to extract wait time, provide default if not found
                wait_time_match = re.search(r'retry in (\d+:\d+:\d+|\d+\.?\d*\s*s)', error_msg)  # Handle seconds too
                wait_time_str = "an unknown period (try again in 5-10 mins)"
                wait_seconds = 300  # Default 5 mins
                if wait_time_match:
                    wait_time_str = wait_time_match.group(1)
                    try:
                        if 's' in wait_time_str:
                            wait_seconds = int(float(wait_time_str.replace('s', '').strip()))
                        else:
                            # Convert HH:MM:SS to seconds
                            parts = list(map(int, wait_time_str.split(':')))
                            wait_seconds = sum(p * 60**i for i, p in enumerate(reversed(parts)))
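                            # Worked example: "0:05:30" -> parts [0, 5, 30]
                            # -> 30*1 + 5*60 + 0*3600 = 330 seconds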
                    except ValueError:
                        pass  # Keep default wait_seconds

                # Only sleep if it's not the last attempt
                if attempt < max_retries - 1:
                    print(f"GPU/Queue issue detected. Waiting {wait_seconds} seconds before retry {attempt + 2}/{max_retries}...")
                    time.sleep(wait_seconds + 2)  # Add a small buffer
                    continue  # Go to the next attempt
                else:
                    # Return None for image, and the specific quota error on last attempt
                    return None, f"GPU quota exceeded or queue full. Please wait ~{wait_time_str} before trying again."
            else:
                # For any other exception, return None for image and the general error message
                print(f"An unexpected error occurred: {error_msg}")  # Log the full error server-side
                # Return a user-friendly error and None for the image path
                return None, f"An error occurred processing the request. Details: {error_msg}"

    # If all retries fail (e.g., due to persistent quota issues)
    return None, "Failed to generate outfit after multiple retries due to server issues."
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
## Outfit Diffusion - Try On Virtual Outfits

⚠️ **Note**: This demo runs on a free, shared GPU with usage limits, so errors can occur during periods of high demand or temporary issues on the upstream service.
- **Try lower settings:** Use fewer Steps (e.g., 10-20) and a lower Scale (e.g., 1-3) for faster processing and lower resource use.
- **Be patient:** If you get quota/queue errors, wait the suggested time before retrying.
- **Check the Space:** You can visit the [OOTDiffusion Space](https://huggingface.co/spaces/levihsu/OOTDiffusion) directly to check its status.
""")
    with gr.Row():
        with gr.Column(scale=1):
            model_image = gr.Image(
                label="Model Image (Person)",
                type="filepath",
                height=400  # Adjusted height slightly
            )
            # (Keep examples)
            model_examples = [...]
            gr.Examples(examples=model_examples, inputs=model_image, label="Model Examples")

            garment_image = gr.Image(
                label="Garment Image (Clothing)",
                type="filepath",
                height=400  # Adjusted height slightly
            )
            # (Keep examples)
            garment_examples = [...]
            gr.Examples(examples=garment_examples, inputs=garment_image, label="Garment Examples")

        with gr.Column(scale=1):  # Give equal space initially
            output_image = gr.Image(label="Generated Output", height=400)  # Match height
            error_text = gr.Markdown(value="")  # Display errors here, start empty
    with gr.Row():
        # Group controls together
        n_samples = gr.Slider(
            label="Number of Samples", minimum=1, maximum=4, step=1, value=1  # Max 4 recommended by the Space
        )
        n_steps = gr.Slider(
            label="Steps", minimum=10, maximum=40, step=1, value=20  # Adjusted default/range
        )
        image_scale = gr.Slider(
            label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.5, value=2.0  # Adjusted default/range
        )
        seed = gr.Number(
            label="Seed (-1 for random)", value=-1, precision=0  # Ensure integer seed
        )

    generate_button = gr.Button("Generate Outfit", variant="primary")  # Make the button stand out
    # Set up the action for the button
    generate_button.click(
        fn=generate_outfit,
        inputs=[model_image, garment_image, n_samples, n_steps, image_scale, seed],
        outputs=[output_image, error_text]  # Map outputs correctly
    )

# Launch the app with error visibility enabled
demo.launch(show_error=True, debug=True)  # debug=True also surfaces more logs locally
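# Optional: launching with demo.launch(share=True, show_error=True) creates a
# temporary public link, which can help when testing from another machine.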