import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import DDPMScheduler, StableDiffusionPipeline
import streamlit as st

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define your custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt

# Function to fine-tune the model
def fine_tune_model(images, prompts, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Scale pixels to [-1, 1], as the VAE expects
    ])

    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Load Stable Diffusion model
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)

    # Load model components; use the pipeline's own tokenizer so it matches the text encoder
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer

    # Freeze everything except the UNet
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)  # Define the optimizer

    # Noise scheduler for the diffusion training objective
    noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

    # Fine-tuning loop
    for epoch in range(num_epochs):
        for i, (batch_images, batch_prompts) in enumerate(dataloader):
            batch_images = batch_images.to(device)  # Move images to GPU if available

            # Tokenize the prompts
            inputs = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                # Encode images to latents and prompts to embeddings (both frozen)
                latents = vae.encode(batch_images).latent_dist.sample() * vae.config.scaling_factor  # 0.18215
                text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Add noise at a random timestep for each sample in the batch
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (latents.size(0),), device=device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict the noise and regress it against the true noise
            pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                st.write(f"Epoch {epoch+1}/{num_epochs}, Step {i+1}/{len(dataloader)}, Loss: {loss.item():.4f}")

    st.success("Fine-tuning completed!")
    return pipeline

# Function to convert tensor to PIL Image
def tensor_to_pil(tensor):
    tensor = (tensor.squeeze().cpu() * 0.5 + 0.5).clamp(0, 1)  # Undo the [-1, 1] normalization
    return transforms.ToPILImage()(tensor)

# Function to generate images
def generate_images(pipeline, prompt):
    with torch.no_grad():
        # Generate image from the prompt
        output = pipeline(prompt)
    return output.images[0]  # The first generated image, already a PIL Image
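# Streamlit reruns the whole script on every interaction, so a pipeline kept only
# in memory can be lost between runs. A minimal sketch of persisting and restoring
# the fine-tuned UNet, assuming diffusers' save_pretrained/from_pretrained and a
# hypothetical local directory "finetuned_unet"; the helper names are illustrative:
from diffusers import UNet2DConditionModel

def save_finetuned_unet(pipeline, save_dir="finetuned_unet"):
    # Write the UNet weights and config to disk
    pipeline.unet.save_pretrained(save_dir)

def load_finetuned_unet(pipeline, save_dir="finetuned_unet"):
    # Swap the base UNet for the fine-tuned weights
    pipeline.unet = UNet2DConditionModel.from_pretrained(save_dir).to(device)
    return pipeline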
# Streamlit app layout
st.title("Fine-Tune Stable Diffusion with Your Images")

# Upload images
uploaded_files = st.file_uploader("Upload your images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])

# Input prompts
prompts = []
images = []
if uploaded_files:
    for i, file in enumerate(uploaded_files):
        image = Image.open(file).convert("RGB")  # Convert uploaded file to PIL Image
        images.append(image)
        prompt = st.text_input(f"Enter a prompt for {file.name}", key=f"prompt_{i}")  # Unique key per widget
        prompts.append(prompt)

# Start fine-tuning (require a prompt for every image)
if st.button("Start Fine-Tuning") and uploaded_files and all(prompts):
    st.session_state["pipeline"] = fine_tune_model(images, prompts)  # Keep the fine-tuned pipeline for generation

# Generate new images
st.subheader("Generate New Images")
new_prompt = st.text_input("Enter a prompt to generate a new image")
if st.button("Generate Image"):
    if new_prompt:
        with st.spinner("Generating image..."):
            # Use the fine-tuned pipeline if available; otherwise fall back to the base model
            pipeline = st.session_state.get("pipeline")
            if pipeline is None:
                pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
            image = generate_images(pipeline, new_prompt)

        st.image(image, caption="Generated Image")  # Display the generated image

        # Save the generated image for download
        image.save("generated_image.png")
        with open("generated_image.png", "rb") as f:
            st.download_button(label="Download Image", data=f.read(), file_name="generated_image.png")
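# Usage: save this file as app.py and launch it with `streamlit run app.py`.
# Full-precision fine-tuning of the SD2 UNet needs a large-memory GPU; if you hit
# CUDA out-of-memory errors, reduce batch_size in the DataLoader or the image
# resolution in the Resize transform.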