# NOTE: the next three lines are repository-viewer residue (file size, commit
# hashes, line-number gutter) and are not part of the program.
# File size: 4,728 Bytes
# 906c9b0 c02513c 67797ef c02513c 4ab8df2 c02513c 4ab8df2 c02513c 4ab8df2 c02513c 67797ef c02513c 67797ef 3690a76 67797ef 3690a76 67797ef 3690a76 67797ef 3690a76 67797ef 3690a76 67797ef 3690a76 67797ef 3690a76 4ab8df2 3690a76 4ab8df2 67797ef 3690a76 c02513c 4ab8df2 67797ef 4ab8df2 3690a76 67797ef 3690a76 4ab8df2 3690a76 e897bc2 67797ef e897bc2 4ab8df2 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from transformers import BertTokenizer, BertModel
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Fetch the first 10,000 rows of the training split from the Hugging Face hub.
dataset = load_dataset(
    'thefcraft/civitai-stable-diffusion-337k',
    split='train[:10000]',
)


def _has_model_label(example):
    """Return True when the row's 'Model' field is not None."""
    return example['Model'] is not None


# Rows without a 'Model' value are dropped; that column is used as the
# classification label further down.
filtered_dataset = dataset.filter(_has_model_label)

# Module-level BERT tokenizer, referenced by CustomDataset and predict.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
class CustomDataset(Dataset):
    """Torch dataset yielding ``(image, text, label)`` training triples.

    Each item pairs the resized image tensor and the tokenized prompt of one
    row with the integer-encoded value of that row's 'Model' column.

    Attributes are plain instance fields:
        dataset: the wrapped row-indexable dataset.
        transform: resize-to-224x224 followed by tensor conversion.
        label_encoder: fitted ``LabelEncoder`` mapping 'Model' values to ints.
        labels: encoded label for every row, aligned with ``dataset``.
    """

    def __init__(self, dataset):
        """Store the dataset and fit the label encoder on its 'Model' column.

        Args:
            dataset: object supporting ``len()``, integer indexing that
                returns a mapping with 'image' and 'prompt' keys, and
                column access via ``dataset['Model']``.
        """
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, text, label)`` triple for row ``idx``.

        Returns:
            image: tensor produced by ``self.transform``.
            text: dict of 1-D tokenizer tensors (padded to the tokenizer's
                maximum length), ready for default batch collation.
            label: encoded class index for the row.
        """
        # Fetch the row once; indexing the dataset twice repeats the row
        # materialisation for no benefit.
        row = self.dataset[idx]

        # NOTE(review): presumably row['image'] is a PIL image; if the data
        # contains non-RGB images a mode conversion may be needed - confirm.
        image = self.transform(row['image'])

        encoded = tokenizer(
            row['prompt'],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # With return_tensors='pt' every tensor carries a leading batch axis
        # of size 1. Drop it so a DataLoader stacks items to (batch, seq_len)
        # instead of (batch, 1, seq_len), which BERT cannot consume.
        text = {key: value.squeeze(0) for key, value in encoded.items()}

        label = self.labels[idx]
        return image, text, label
# Image branch: ResNet-18 backbone with a 512-wide output layer
class ImageModel(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by a
    linear layer with 512 outputs."""

    def __init__(self):
        super().__init__()
        # NOTE(review): newer torchvision releases favour a `weights=`
        # argument over `pretrained=` - confirm against the pinned version.
        backbone = models.resnet18(pretrained=True)
        # Replace the stock head, keeping its input width.
        backbone.fc = nn.Linear(backbone.fc.in_features, 512)
        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the backbone and return its output."""
        features = self.model(x)
        return features
# Text branch: BERT encoder followed by a linear projection
class TextModel(nn.Module):
    """Encode tokenized text with 'bert-base-uncased' and map the pooled
    output through a 768 -> 512 linear layer."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        """Return the projected pooled BERT output.

        Args:
            x: mapping of tokenizer outputs, unpacked as keyword arguments
                into the BERT model.
        """
        pooled = self.bert(**x).pooler_output
        projected = self.fc(pooled)
        return projected
# Fusion model: concatenates both branches and classifies the result
class CombinedModel(nn.Module):
    """Classifier over the concatenated outputs of ImageModel and TextModel."""

    def __init__(self, num_classes):
        """Build both branches and the final linear classifier.

        Args:
            num_classes: number of output logits.
        """
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        # 1024 inputs: the two branch outputs joined along dim 1.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class logits for a batch of image and text inputs."""
        branches = [self.image_model(image), self.text_model(text)]
        fused = torch.cat(branches, dim=1)
        logits = self.fc(fused)
        return logits
def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and report the results.

    Runs the model in eval mode without gradients, saves a confusion-matrix
    heatmap to 'confusion_matrix.png' and prints a classification report.

    Args:
        model: callable taking ``(images, texts)`` and returning per-class
            scores of shape ``(batch, num_classes)``.
        test_loader: iterable of ``(images, texts, labels)`` batches, where
            ``texts`` is a mapping of tensors.
        device: device the inputs are moved to before the forward pass.
            The model itself is not moved here.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, texts, labels in test_loader:
            images = images.to(device)
            # Items tokenized with return_tensors='pt' collate to
            # (batch, 1, seq_len); squeeze that singleton axis so the text
            # encoder receives (batch, seq_len). Already-2-D tensors pass
            # through unchanged.
            texts = {
                key: (value.squeeze(1) if value.dim() == 3 else value).to(device)
                for key, value in texts.items()
            }

            outputs = model(images, texts)
            predicted = outputs.argmax(dim=1)

            all_preds.extend(predicted.cpu().tolist())
            # Labels are only needed on the host for the metrics below, so
            # they are never sent to the device.
            all_labels.extend(labels.cpu().tolist())

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Plot confusion matrix; close the figure even if saving fails so the
    # figure is not left open.
    plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.savefig('confusion_matrix.png')
    finally:
        plt.close()

    # Print classification report
    print(classification_report(all_labels, all_preds))
# Wrap the filtered rows in the torch dataset (rebinding `dataset`) and size
# the classifier from the number of distinct encoded labels.
# NOTE(review): no checkpoint is loaded in the code visible here, so the
# classifier head starts from its initial weights - confirm this is intended.
dataset = CustomDataset(filtered_dataset)
num_classes = np.unique(dataset.labels).size
model = CombinedModel(num_classes)
# Define predict function
def predict(image):
    """Return the top-ranked model names for an uploaded image.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when the
            request was submitted without an image.

    Returns:
        List of up to five label strings, highest score first. Empty when
        ``image`` is ``None``.
    """
    # Submitting the form without an upload passes None; return an empty
    # result instead of failing inside the image transforms.
    if image is None:
        return []

    model.eval()

    # Same steps and order as CustomDataset (resize the PIL image, then
    # convert to a tensor) so inference inputs match the dataset's.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    with torch.no_grad():
        # Uploads may be RGBA or grayscale; the ResNet stem needs 3 channels.
        pixels = preprocess(image.convert('RGB')).unsqueeze(0)

        # The interface supplies no prompt, so a fixed placeholder is used.
        text_input = tokenizer(
            "Sample prompt",
            return_tensors='pt',
            padding=True,
            truncation=True
        )

        output = model(pixels, text_input)

        # topk raises when k exceeds the number of classes.
        top_k = min(5, output.shape[1])
        _, indices = torch.topk(output, top_k)

    # Decode all indices in a single call, using plain Python ints.
    class_ids = indices[0].tolist()
    recommended_models = dataset.label_encoder.inverse_transform(class_ids).tolist()
    return recommended_models
# Gradio UI: one PIL image in, one text box out, served by predict()
interface = gr.Interface(
    title="AI Image Model Recommender",
    description="Upload an AI-generated image to receive model recommendations.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Recommended Models"),
)
if __name__ == "__main__":
    # Start the Gradio server only when the file is executed as a script.
    # The stray trailing "|" that followed this call was a syntax error.
    interface.launch()