"""Gradio app that generates a coloured T-shirt and overlays an AI-generated design on it.

Pipeline: text-to-image via the Hugging Face Inference API for both the cloth and
the design -> DeepLabV3 segmentation to locate the garment -> the design is scaled
to the garment's bounding box, centred, and clipped to the segmentation mask.
"""

import os
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFilter
from torchvision import models

# AI model repo for design generation (Hugging Face Hub model id).
repo = "artificialguybr/TshirtDesignRedmond-V2"

# Hosted inference endpoint for `repo`.
API_URL = f"https://api-inference.huggingface.co/models/{repo}"

# Optional Hugging Face access token read from the environment; when set it is sent
# as a Bearer token, otherwise requests are anonymous (as in the original code).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Seconds to wait for the inference endpoint before giving up, so a hung request
# cannot block the Gradio worker forever. Generous because diffusion is slow.
REQUEST_TIMEOUT = 120


def _generate_image(prompt, mode, what):
    """POST *prompt* to the inference endpoint and decode the returned image.

    Args:
        prompt: Text prompt for the text-to-image model.
        mode: PIL mode the decoded image is converted to ("RGB" or "RGBA").
        what: Short label used in the error message ("cloth" / "design").

    Returns:
        The generated image as a PIL image in *mode*.

    Raises:
        Exception: If the endpoint answers with a non-200 status.
        requests.RequestException: On network errors or timeout.
    """
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    payload = {"inputs": prompt, "parameters": {"num_inference_steps": 30}}
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
    )
    if response.status_code != 200:
        # Include the start of the body: the API usually explains the failure there.
        raise Exception(
            f"Error generating {what}: {response.status_code} {response.text[:200]}"
        )
    return Image.open(BytesIO(response.content)).convert(mode)


def generate_cloth(color_prompt):
    """Generate an RGB image of a plain T-shirt in the colour *color_prompt*."""
    prompt = f"A plain {color_prompt} colored T-shirt hanging on a plain wall."
    return _generate_image(prompt, "RGB", "cloth")


def generate_design(design_prompt):
    """Generate an RGBA design image described by *design_prompt*."""
    prompt = f"A bold {design_prompt} design with vibrant colors, highly detailed."
    return _generate_image(prompt, "RGBA", "design")


# Load pretrained DeepLabV3 model for garment segmentation (runs once, at import).
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=...`; kept so older torchvision installs keep working.
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Label index treated as "garment" in the segmentation output.
# NOTE(review): torchvision's DeepLabV3 weights use the Pascal VOC label set, where
# index 15 is "person" -- there is no T-shirt class. Confirm this actually detects a
# shirt hanging on a wall; get_bounding_box raises when nothing is found.
SEGMENTATION_CLASS = 15

# Preprocessing built once rather than per call. The fixed 520x520 resize does NOT
# preserve aspect ratio; the resulting mask is resized back to the input size.
_PREPROCESS = T.Compose([
    T.Resize((520, 520)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def get_tshirt_mask(image):
    """Segment the garment in *image*.

    Args:
        image: PIL image of the garment (any mode; converted to RGB).

    Returns:
        An "L"-mode PIL mask the same size as *image*: 255 on the garment, 0 elsewhere,
        with softened edges from post_process_mask.
    """
    image = image.convert("RGB")  # Ensure 3 channels
    input_tensor = _PREPROCESS(image).unsqueeze(0)
    with torch.no_grad():
        output = segmentation_model(input_tensor)["out"][0]

    # Per-pixel most likely label, then a binary 0/255 mask for the chosen class.
    labels = output.argmax(0).byte().cpu().numpy()
    raw_mask = Image.fromarray((labels == SEGMENTATION_CLASS).astype("uint8") * 255)
    processed_mask = post_process_mask(raw_mask)
    return processed_mask.resize(image.size)


def post_process_mask(mask):
    """Clean up a binary mask: close small holes and soften its edges.

    Args:
        mask: Single-channel PIL mask with values 0/255.

    Returns:
        The refined mask as a PIL image.
    """
    mask_np = np.array(mask)
    kernel = np.ones((5, 5), np.uint8)
    # Dilating twice then eroding once closes small holes/gaps and leaves the mask
    # slightly larger than the raw segmentation.
    mask_np = cv2.dilate(mask_np, kernel, iterations=2)
    mask_np = cv2.erode(mask_np, kernel, iterations=1)
    # Blur so the design fades out at the garment boundary instead of a hard edge.
    return Image.fromarray(mask_np).filter(ImageFilter.GaussianBlur(3))


def get_bounding_box(mask):
    """Return the inclusive ``(x_min, y_min, x_max, y_max)`` box of the non-zero pixels.

    Args:
        mask: Single-channel PIL image or 2-D array.

    Returns:
        Tuple of ints in PIL (x, y) order, usable with ImageDraw and paste offsets.

    Raises:
        Exception: If the mask has no non-zero pixel.
    """
    mask_np = np.asarray(mask)
    # np.nonzero yields (row, col) == (y, x); unpack explicitly so x and y are not
    # swapped (the previous version returned rows as x and columns as y).
    ys, xs = np.nonzero(mask_np)
    if xs.size == 0:
        raise Exception("No T-shirt detected in the image.")
    return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))


def visualize_mask(image, mask, path="debug_visualization.png"):
    """Tint the mask area red, draw its bounding box, save to *path*, and return it.

    Args:
        image: PIL image the mask was computed from.
        mask: Single-channel mask; resized to *image* if the sizes differ.
        path: Where the debug visualization is written.

    Returns:
        The RGBA visualization image.
    """
    base = image.convert("RGBA")
    alpha_mask = mask.convert("L")
    if alpha_mask.size != base.size:
        alpha_mask = alpha_mask.resize(base.size)

    # Half-strength red wherever the mask is set (the blurred edge fades smoothly).
    tint = Image.new("RGBA", base.size, (255, 0, 0, 0))
    tint.putalpha(alpha_mask.point(lambda v: v // 2))
    blended = Image.alpha_composite(base, tint)

    ImageDraw.Draw(blended).rectangle(
        get_bounding_box(alpha_mask), outline="red", width=3
    )
    blended.save(path)  # Save debug image
    return blended


def overlay_design(cloth_image, design_image, mask=None, scale=0.6):
    """Paste *design_image* centred on the garment and clip it to the garment mask.

    Args:
        cloth_image: Image of the garment.
        design_image: Design to print; its alpha channel (if any) is respected.
        mask: Optional precomputed garment mask (as returned by get_tshirt_mask).
            Computed from *cloth_image* when omitted.
        scale: Fraction of the garment bounding box's width and height that the
            design occupies.

    Returns:
        RGBA image of the garment with the design composited on top.

    Raises:
        Exception: If no garment pixels are found in the mask.
    """
    cloth_image = cloth_image.convert("RGBA")
    design_image = design_image.convert("RGBA")

    if mask is None:
        mask = get_tshirt_mask(cloth_image)
    mask = mask.convert("L")
    if mask.size != cloth_image.size:
        mask = mask.resize(cloth_image.size)

    x_min, y_min, x_max, y_max = get_bounding_box(mask)
    tshirt_width = x_max - x_min
    tshirt_height = y_max - y_min

    # Clamp to >= 1 px so a tiny detection cannot request a zero-size resize.
    design_width = max(1, int(tshirt_width * scale))
    design_height = max(1, int(tshirt_height * scale))
    resized_design = design_image.resize((design_width, design_height))

    # Centre the design within the garment's bounding box.
    design_position = (
        x_min + (tshirt_width - design_width) // 2,
        y_min + (tshirt_height - design_height) // 2,
    )

    # Place the design on a transparent full-size layer, using its own alpha as paste mask.
    transparent_layer = Image.new("RGBA", cloth_image.size, (0, 0, 0, 0))
    transparent_layer.paste(resized_design, design_position, resized_design)

    # Keep the design only where the garment mask is set (soft edges blend partially).
    masked_design = Image.composite(
        transparent_layer, Image.new("RGBA", cloth_image.size), mask
    )
    return Image.alpha_composite(cloth_image, masked_design)


def debug_intermediate_outputs(cloth_image, mask):
    """Save the generated cloth and its mask to the working directory for debugging."""
    cloth_image.save("debug_cloth_image.png")
    mask.save("debug_tshirt_mask.png")


def design_tshirt(color_prompt, design_prompt):
    """Gradio callback: generate cloth and design, then composite the design onto the cloth.

    Args:
        color_prompt: Colour description for the T-shirt.
        design_prompt: Description of the design to print.

    Returns:
        RGBA PIL image of the finished T-shirt.

    Raises:
        Exception: If generation fails, or (wrapped, with the cause chained) if
            segmentation/compositing fails.
    """
    cloth_image = generate_cloth(color_prompt)
    design_image = generate_design(design_prompt)
    try:
        mask = get_tshirt_mask(cloth_image)
        debug_intermediate_outputs(cloth_image, mask)  # Debugging
        visualize_mask(cloth_image, mask)  # Save visualization
        # Reuse the mask so the segmentation network runs only once per request.
        return overlay_design(cloth_image, design_image, mask=mask)
    except Exception as e:
        raise Exception(f"Error in design process: {str(e)}") from e


# Gradio UI
with gr.Blocks() as interface:
    gr.Markdown("# **AI Cloth Designer**")
    gr.Markdown(
        "Generate custom T-shirts by specifying a color and adding a design "
        "that perfectly fits the T-shirt."
    )
    with gr.Row():
        with gr.Column():
            color_prompt = gr.Textbox(label="Cloth Color", placeholder="E.g., Red, Blue")
            design_prompt = gr.Textbox(
                label="Design Details",
                placeholder="E.g., Abstract art, Nature patterns",
            )
            generate_button = gr.Button("Generate T-Shirt")
        with gr.Column():
            output_image = gr.Image(label="Final T-Shirt Design")

    generate_button.click(
        design_tshirt,
        inputs=[color_prompt, design_prompt],
        outputs=output_image,
    )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    interface.launch(debug=True)