File size: 4,728 Bytes
906c9b0
c02513c
 
 
 
 
 
 
67797ef
c02513c
4ab8df2
 
 
 
c02513c
4ab8df2
c02513c
4ab8df2
 
 
c02513c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67797ef
 
 
 
 
 
 
 
 
c02513c
67797ef
3690a76
 
 
 
 
67797ef
3690a76
67797ef
3690a76
67797ef
3690a76
 
 
 
 
67797ef
3690a76
67797ef
 
3690a76
67797ef
3690a76
4ab8df2
3690a76
 
 
4ab8df2
67797ef
3690a76
 
 
 
 
c02513c
4ab8df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67797ef
4ab8df2
 
 
3690a76
67797ef
3690a76
 
 
 
 
 
 
 
 
 
 
 
 
4ab8df2
3690a76
e897bc2
67797ef
 
 
 
 
 
 
 
e897bc2
4ab8df2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Pull the first 10,000 training rows of the CivitAI prompt dataset.
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')


def _has_model(example):
    """Return True when the row's 'Model' field (used as the label) is not None."""
    return example['Model'] is not None


# Drop rows without a 'Model' value so every remaining row can be label-encoded.
filtered_dataset = dataset.filter(_has_model)

# Tokenizer shared by CustomDataset and the inference path below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class CustomDataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, tokenized_prompt, label)`` items.

    Each row of the wrapped dataset is read for its ``'image'``, ``'prompt'``
    and ``'Model'`` fields. ``'Model'`` values are label-encoded once at
    construction; ``self.label_encoder`` can map predicted ids back to names.
    """

    def __init__(self, dataset, text_tokenizer=None):
        """Wrap *dataset* and fit the label encoder on its ``'Model'`` column.

        Args:
            dataset: Indexable dataset whose rows expose 'image', 'prompt'
                and 'Model', and which supports ``dataset['Model']`` column access.
            text_tokenizer: Callable tokenizer for prompts. Defaults to the
                module-level ``tokenizer`` so existing callers are unaffected.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the transformed image, per-sample token tensors and the label id."""
        # Fetch the row once instead of indexing (and decoding) it twice.
        row = self.dataset[idx]

        image = row['image']
        # ResNet expects 3 channels; RGBA / grayscale images would otherwise
        # yield 4- or 1-channel tensors after ToTensor.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        image = self.transform(image)

        encoded = self.tokenizer(
            row['prompt'],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dimension of size 1; drop it
        # so DataLoader's default collate stacks to (batch, seq_len) rather than
        # (batch, 1, seq_len), which BertModel cannot consume.
        text = {key: value.squeeze(0) for key, value in encoded.items()}

        label = self.labels[idx]
        return image, text, label

# Define CNN for image processing
class ImageModel(nn.Module):
    """ResNet-18 image encoder whose final ``fc`` layer projects to 512 features.

    NOTE(review): ``pretrained=True`` is deprecated in newer torchvision releases
    in favour of ``weights=...``; it is kept here for compatibility with
    whatever torchvision version this script targets.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the classification head for a 512-d projection that feeds fusion.
        backbone.fc = nn.Linear(backbone.fc.in_features, 512)
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features

# Define BERT-based text encoder (BERT + linear projection)
class TextModel(nn.Module):
    """BERT text encoder whose ``pooler_output`` is projected to 512 features."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the pooled-output width of bert-base.
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        # x is a mapping of tokenizer tensors, expanded as keyword arguments.
        pooled = self.bert(**x).pooler_output
        return self.fc(pooled)

# Combined model
class CombinedModel(nn.Module):
    """Classifier over concatenated image and text features.

    The image and text branches each emit 512 features; their 1024-wide
    concatenation feeds one linear layer producing ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        self.fc = nn.Linear(512 + 512, num_classes)

    def forward(self, image, text):
        # Image features first, then text features, along the feature axis.
        fused = torch.cat([self.image_model(image), self.text_model(text)], dim=1)
        return self.fc(fused)

def evaluate_model(model, test_loader, device, output_path='confusion_matrix.png'):
    """Evaluate *model* on *test_loader*, save a confusion-matrix heatmap and print a report.

    Args:
        model: Module called as ``model(images, texts)`` returning per-class
            logits; assumed to already be on *device* (it is not moved here).
        test_loader: Iterable of ``(images, texts, labels)`` batches where
            ``texts`` is a mapping of tensors and ``labels`` is a tensor.
        device: Device the images and text tensors are moved to.
        output_path: File the heatmap is saved to.

    Returns:
        The confusion matrix (rows = true label, columns = predicted label),
        ordered by the sorted union of label ids seen in labels/predictions.

    Raises:
        ValueError: If *test_loader* yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, texts, labels in test_loader:
            images = images.to(device)
            texts = {k: v.to(device) for k, v in texts.items()}

            outputs = model(images, texts)
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            # Labels are only consumed by sklearn, so keep them on the CPU.
            all_labels.extend(labels.tolist())

    if not all_labels:
        raise ValueError("test_loader yielded no samples")

    # confusion_matrix indexes rows/columns by the sorted labels present in the
    # data; reuse that set as tick labels so axes show real class ids instead
    # of 0..k-1 positions (which diverge when some classes are absent).
    class_ids = np.unique(np.concatenate([all_labels, all_preds]))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_ids, yticklabels=class_ids)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(classification_report(all_labels, all_preds))
    return cm

# Wrap the filtered rows and size the classifier head from the fitted labels.
# LabelEncoder assigns one id per distinct 'Model' value, so the number of
# fitted classes equals the number of distinct encoded labels.
# NOTE(review): `model` is freshly initialised and no training or checkpoint
# loading is visible here, so its projection/fusion layers hold random
# weights — confirm weights are loaded before trusting predictions.
dataset = CustomDataset(filtered_dataset)
num_classes = len(dataset.label_encoder.classes_)
model = CombinedModel(num_classes)

# Define predict function
def predict(image, prompt="Sample prompt", top_k=5):
    """Return the model names the combined classifier ranks highest for *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading.
        prompt: Text fed to the BERT branch. The UI collects no prompt, so the
            original placeholder string is the default.
        top_k: Maximum number of recommendations; clamped to the class count.

    Returns:
        List of model names (decoded via ``dataset.label_encoder``), highest
        score first; an empty list when no image was supplied.
    """
    if image is None:
        return []

    model.eval()
    with torch.no_grad():
        # Uploaded PNGs are often RGBA or grayscale; ResNet needs 3 channels.
        # Resize the PIL image before ToTensor, matching CustomDataset's transform.
        rgb_image = image.convert('RGB')
        pixels = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])(rgb_image).unsqueeze(0)

        text_input = tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        output = model(pixels, text_input)

        # torch.topk raises if k exceeds the size of the class dimension.
        k = min(top_k, output.size(1))
        _, indices = torch.topk(output, k)
        recommended_models = dataset.label_encoder.inverse_transform(
            indices[0].cpu().numpy()
        ).tolist()
    return recommended_models

# Gradio front end: a PIL image goes in, predict()'s recommendations come out as text.
interface = gr.Interface(
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
)


def main():
    """Launch the Gradio app serving the recommender interface."""
    interface.launch()


if __name__ == "__main__":
    main()