File size: 3,766 Bytes
1cbcd7d
 
 
 
be37c92
1cbcd7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e88f579
1cbcd7d
 
 
e88f579
1cbcd7d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from huggingface_hub import cached_download, hf_hub_url
from PIL import Image
import os
import gradio as gr
import spaces
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPModel

def _sorted_files(folder, extensions):
    """Return the sorted full paths of files in *folder* whose lowercased names end with *extensions*."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(extensions)
    )


def _stem(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


class _ImageTextDataset(Dataset):
    """Dataset of (image, prompt) pairs read from an image folder and a prompt folder.

    Each image is paired with the ``.txt`` file in *text_folder* that has the same
    file stem (``cat.png`` <-> ``cat.txt``). When at least one stem matches, images
    with no same-stem prompt are skipped. When no stems match at all but both
    folders hold the same number of files, files are paired positionally in
    sorted-name order instead.

    Args:
        image_folder (str): Folder containing ``.png`` / ``.jpg`` / ``.jpeg`` images.
        text_folder (str): Folder containing ``.txt`` prompt files.
        transform (callable, optional): Applied to each RGB ``PIL.Image``.

    Raises:
        ValueError: If no image/prompt pairs can be formed.
    """

    def __init__(self, image_folder, text_folder, transform=None):
        image_paths = _sorted_files(image_folder, (".png", ".jpg", ".jpeg"))
        text_paths = _sorted_files(text_folder, (".txt",))
        text_by_stem = {_stem(path): path for path in text_paths}

        self.pairs = [
            (image_path, text_by_stem[_stem(image_path)])
            for image_path in image_paths
            if _stem(image_path) in text_by_stem
        ]
        if not self.pairs and len(image_paths) == len(text_paths):
            # Positional pairing, but over a deterministic sorted order rather
            # than os.listdir's arbitrary order.
            self.pairs = list(zip(image_paths, text_paths))
        if not self.pairs:
            raise ValueError(
                f"No image/prompt pairs found: {len(image_paths)} images in "
                f"{image_folder!r}, {len(text_paths)} prompts in {text_folder!r}."
            )
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        image_path, text_path = self.pairs[idx]
        # Context manager closes the underlying file handle once decoded.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        with open(text_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        return image, text


@spaces.GPU()
def train_image_generation_model(
    image_folder,
    text_folder,
    model_name="image_generation_model",
    epochs=10,
    batch_size=8,
    learning_rate=1e-5,
):
    """Fine-tune CLIP (``openai/clip-vit-base-patch32``) on an image/prompt dataset.

    Despite the name, this does not train a generative model: it fine-tunes
    CLIP's image and text encoders with CLIP's contrastive image-text loss and
    saves the resulting weights.

    Args:
        image_folder (str): Path to the folder containing training images.
        text_folder (str): Path to the folder containing one ``.txt`` prompt per
            image (matched by file stem, see ``_ImageTextDataset``).
        model_name (str, optional): Base name for the saved file. Blank values fall
            back to "image_generation_model"; any directory part is dropped.
        epochs (int, optional): Number of passes over the dataset. Defaults to 10.
        batch_size (int, optional): Pairs per optimization step. Defaults to 8.
        learning_rate (float, optional): Adam learning rate. Defaults to 1e-5.

    Returns:
        str: Path to the saved ``.pt`` state-dict file in the current working directory.

    Raises:
        ValueError: If no image/prompt pairs can be formed from the folders.
    """
    # The name comes from a free-text UI field: default it when blank and keep only
    # the final path component so the output cannot be written outside the cwd.
    safe_name = os.path.basename((model_name or "").strip()) or "image_generation_model"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load CLIP model
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    clip_model.train()  # from_pretrained returns the model in eval mode

    # Resize to CLIP's input size and normalize with CLIP's image mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
    ])

    dataset = _ImageTextDataset(image_folder, text_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for i, (images, texts) in enumerate(dataloader):
            # Pad/truncate so prompts of different lengths (including ones longer
            # than the tokenizer's maximum length) batch into a single tensor.
            tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt")
            # return_loss=True makes CLIPModel compute its symmetric contrastive
            # loss over normalized, logit-scaled image/text similarities.
            outputs = clip_model(
                input_ids=tokens["input_ids"].to(device),
                attention_mask=tokens["attention_mask"].to(device),
                pixel_values=images.to(device),
                return_loss=True,
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch} | Iteration: {i} | Loss: {loss.item()}")

    # Save CPU copies of the weights so the file loads on machines without a GPU.
    model_path = os.path.join(os.getcwd(), safe_name + ".pt")
    torch.save({key: value.cpu() for key, value in clip_model.state_dict().items()}, model_path)

    return model_path

def _copy_uploads(files, destination):
    """Copy files uploaded through a Gradio ``File`` input into *destination*.

    Args:
        files: The value Gradio passed for the input: a single upload or a list
            of uploads, or None/empty when nothing was uploaded.
        destination (str): Existing directory to copy the files into.

    Raises:
        gr.Error: Shown in the UI when nothing was uploaded.
    """
    import shutil

    if not files:
        raise gr.Error("Please upload both the image folder and the text prompts folder.")
    if not isinstance(files, (list, tuple)):
        files = [files]
    for item in files:
        # Uploads may arrive as path strings or as file-like objects exposing
        # ``.name`` depending on the Gradio version.
        source = getattr(item, "name", item)
        shutil.copy(source, os.path.join(destination, os.path.basename(source)))


def _train_from_uploads(image_files, text_files, model_name):
    """Gradio callback: stage the uploaded folders on disk, then train on them.

    ``train_image_generation_model`` expects folder paths, while a directory
    upload yields individual files, so each upload is copied into its own
    temporary directory (removed afterwards; the model file is saved elsewhere).
    """
    import tempfile

    with tempfile.TemporaryDirectory() as image_dir, tempfile.TemporaryDirectory() as text_dir:
        _copy_uploads(image_files, image_dir)
        _copy_uploads(text_files, text_dir)
        return train_image_generation_model(
            image_dir, text_dir, model_name or "image_generation_model"
        )


# Define Gradio interface
iface = gr.Interface(
    fn=_train_from_uploads,
    inputs=[
        gr.File(label="Image Folder", file_count="directory"),
        gr.File(label="Text Prompts Folder", file_count="directory"),
        gr.Textbox(label="Model Name", value="image_generation_model"),
    ],
    outputs=gr.File(label="Model File"),
    title="Image Generation Model Trainer",
    description="Upload a folder of images and their corresponding text prompts to train a model.\n The images folder should contain image files. The prompts folder should contain .txt files. Each text file holds the prompt for one image in the images folder.",
)

if __name__ == "__main__":
    iface.launch(share=True)