# NOTE(review): `cached_download` was removed in recent huggingface_hub releases
# (0.26+), which makes this import fail there; neither name is used in this file.
# Consider dropping it or pinning huggingface_hub.
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image
import os
import gradio as gr
import spaces
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPModel

_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
_TEXT_EXTENSIONS = (".txt",)


def _resolve_files(source, extensions):
    """Return the files referenced by *source* whose name ends with one of *extensions*.

    *source* may be a directory path, a single file path, a file-like object
    exposing ``.name`` (older Gradio versions hand these out), or a list/tuple
    of any of those (``gr.File(file_count="directory")`` yields a list of paths).
    Results are sorted by basename so ordering is deterministic.
    """
    if source is None:
        return []
    items = source if isinstance(source, (list, tuple)) else [source]
    paths = []
    for item in items:
        # pathlib.Path also has a `.name` (the basename), so check path types first.
        if isinstance(item, (str, os.PathLike)):
            path = os.fspath(item)
        else:
            path = os.fspath(getattr(item, "name"))
        if os.path.isdir(path):
            paths.extend(
                os.path.join(path, entry)
                for entry in os.listdir(path)
                if os.path.isfile(os.path.join(path, entry))
            )
        else:
            paths.append(path)
    matching = [p for p in paths if p.lower().endswith(extensions)]
    return sorted(matching, key=lambda p: (os.path.basename(p), p))


def _stem(path):
    """Return the file name of *path* without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _pair_images_with_prompts(image_paths, text_paths):
    """Match every image to its prompt file.

    Images are paired with the ``.txt`` file sharing their stem
    (``cat.png`` <-> ``cat.txt``). If the stems do not line up but the counts are
    equal, files are paired in sorted-basename order instead.

    Raises:
        ValueError: if there are no images, or the files cannot be paired.
    """
    if not image_paths:
        raise ValueError("No .png/.jpg/.jpeg images were found in the image folder.")
    prompts_by_stem = {_stem(p): p for p in text_paths}
    if all(_stem(p) in prompts_by_stem for p in image_paths):
        return [(p, prompts_by_stem[_stem(p)]) for p in image_paths]
    if len(image_paths) == len(text_paths):
        return list(zip(image_paths, text_paths))
    raise ValueError(
        f"Could not pair {len(image_paths)} images with {len(text_paths)} prompt files; "
        "give each image a .txt prompt with the same file name."
    )


class _ImageTextDataset(Dataset):
    """Yields ``(image_tensor, prompt_text)`` for a list of ``(image_path, text_path)`` pairs."""

    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        image_path, text_path = self.pairs[idx]
        # Context manager closes the underlying file handle once decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return image, text


@spaces.GPU()
def train_image_generation_model(
    image_folder,
    text_folder,
    model_name="image_generation_model",
    *,
    epochs=10,
    batch_size=8,
    learning_rate=1e-5,
):
    """Fine-tune CLIP on image/prompt pairs and save the resulting weights.

    Despite its name, this does not train a generative model: it fine-tunes
    ``openai/clip-vit-base-patch32`` with CLIP's contrastive image-text loss.

    Args:
        image_folder: Folder path, file path, file object, or list of these
            pointing at the training images (.png/.jpg/.jpeg).
        text_folder: Same, for the .txt prompt files (one per image, ideally
            with the same file stem as its image).
        model_name (str, optional): Base name of the saved file
            (``<model_name>.pt`` in the current working directory).
        epochs (int, optional): Passes over the dataset. Defaults to 10.
        batch_size (int, optional): Pairs per optimisation step. Defaults to 8.
        learning_rate (float, optional): Adam learning rate. Defaults to 1e-5.

    Returns:
        str: Path to the saved state dict (tensors stored on CPU).

    Raises:
        ValueError: if no images are found or images and prompts cannot be paired.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pairs = _pair_images_with_prompts(
        _resolve_files(image_folder, _IMAGE_EXTENSIONS),
        _resolve_files(text_folder, _TEXT_EXTENSIONS),
    )

    clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(device)
    clip_model.train()  # from_pretrained returns the model in eval mode
    tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)

    # CLIP's published preprocessing statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    dataloader = DataLoader(
        _ImageTextDataset(pairs, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for i, (images, texts) in enumerate(dataloader):
            optimizer.zero_grad()
            # Padding lets prompts of different lengths share a batch; truncation
            # keeps them within CLIP's maximum text length.
            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt")
            # return_loss=True computes CLIP's symmetric contrastive loss on
            # normalised embeddings scaled by the learned logit_scale.
            outputs = clip_model(
                input_ids=tokens["input_ids"].to(device),
                attention_mask=tokens["attention_mask"].to(device),
                pixel_values=images.to(device),
                return_loss=True,
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item()}")

    model_path = os.path.join(os.getcwd(), model_name + ".pt")
    # Store CPU tensors so the checkpoint loads on machines without a GPU.
    torch.save({k: v.detach().cpu() for k, v in clip_model.state_dict().items()}, model_path)
    return model_path


# Define Gradio interface. file_count="directory" makes each input deliver the
# list of files inside an uploaded folder rather than a single file.
iface = gr.Interface(
    fn=train_image_generation_model,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description="Upload a folder of images and their corresponding text prompts to train a model.",
)

iface.launch(share=True)