import os

# BUGFIX: was `import pandas as np` — the sketch preprocessing below calls
# `np.array(...)`, which is NumPy's API, not pandas'.
import numpy as np
import streamlit as st
import torch
from PIL import Image
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from src.datasets.dresscode import DressCodeDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Avoid tokenizer fork warnings and keep wandb from spawning a subprocess.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"


def run_inference(prompt, sketch_image=None, category="dresses", seed=None, mixed_precision="fp16"):
    """Generate a fashion image from a text prompt and/or a rough sketch.

    Args:
        prompt: Text description of the garment to generate.
        sketch_image: Optional PIL.Image of a rough sketch; resized to
            512x384 and converted to a float tensor before conditioning.
        category: Garment category name (wrapped in a list for the dataset).
        seed: Optional int for reproducible generation.
        mixed_precision: "fp16", "bf16", or "fp32" (mapped to "no").

    Returns:
        The first generated image produced by the MGD pipeline.
    """
    # BUGFIX: Accelerate only accepts "no" / "fp16" / "bf16"; the UI offers
    # "fp32", which must be translated to "no" (full precision).
    if mixed_precision == "fp32":
        mixed_precision = "no"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    device = accelerator.device

    # Load tokenizer, text encoder, VAE, and the validation noise scheduler.
    tokenizer = CLIPTokenizer.from_pretrained("microsoft/xclip-base-patch32", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("microsoft/xclip-base-patch32", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", subfolder="vae")
    val_scheduler = DDIMScheduler.from_pretrained("ptx0/pseudo-journey-v2", subfolder="scheduler")

    # Load the pretrained MGD UNet from torch.hub.
    unet = torch.hub.load("aimagelab/multimodal-garment-designer", "mgd", pretrained=True)

    # Inference only: freeze VAE and text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Seed for reproducibility (None => nondeterministic run).
    if seed is not None:
        set_seed(seed)

    # Build the test dataset/loader for the requested category.
    # NOTE(review): "path_to_dataset" is a placeholder — must point at the
    # actual DressCode dataset root before this can run.
    test_dataset = DressCodeDataset(
        dataroot_path="path_to_dataset",
        phase="test",
        category=[category],
        size=(512, 384),
    )
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Move models onto the accelerator device; UNet in eval mode.
    text_encoder.to(device)
    vae.to(device)
    unet.to(device).eval()

    # Preprocess the sketch (resize, tensorize) when one was provided.
    sketch_tensor = None
    if sketch_image is not None:
        sketch_image = sketch_image.resize((512, 384))
        sketch_tensor = torch.tensor(np.array(sketch_image)).unsqueeze(0).float().to(device)

    # NOTE(review): the disentangled pipeline is used unconditionally even
    # though MGDPipe is imported — confirm whether a non-disentangled path
    # should be selectable.
    val_pipe = MGDPipeDisentangled(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=val_scheduler,
    ).to(device)
    val_pipe.enable_attention_slicing()

    # Run generation over the dataloader with the (optional) sketch/prompt.
    generated_images = generate_images_from_mgd_pipe(
        test_dataloader=test_dataloader,
        pipe=val_pipe,
        guidance_scale=7.5,
        seed=seed,
        sketch_image=sketch_tensor,
        prompt=prompt,
    )
    return generated_images[0]  # Assuming single image output


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
st.title("Fashion Image Generator")
st.write("Generate colorful fashion images based on a rough sketch and/or a text prompt.")

# Optional sketch upload.
uploaded_sketch = st.file_uploader("Upload a rough sketch (optional)", type=["png", "jpg", "jpeg"])

# Text prompt and generation options.
prompt = st.text_input("Enter a prompt (optional)", "A red dress with floral patterns")
category = st.text_input("Enter category (optional):", "dresses")
# BUGFIX: st.slider does not accept value=None; use a concrete in-range default.
seed = st.slider("Seed", min_value=1, max_value=100, step=1, value=42)
precision = st.selectbox("Select precision:", ["fp16", "fp32"])

# BUGFIX: initialize so the button handler below never hits a NameError when
# no sketch was uploaded.
sketch_image = None
if uploaded_sketch is not None:
    sketch_image = Image.open(uploaded_sketch)
    st.image(sketch_image, caption="Uploaded Sketch", use_column_width=True)

# Generate on demand, with a spinner while the pipeline runs.
if st.button("Generate Image"):
    with st.spinner("Generating image..."):
        result_image = run_inference(prompt, sketch_image, category, seed, precision)
        st.image(result_image, caption="Generated Image", use_column_width=True)