|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from torchvision import models |
|
from transformers import BertTokenizer, BertModel |
|
import pandas as pd |
|
from datasets import load_dataset |
|
from torch.utils.data import DataLoader, Dataset, random_split |
|
from sklearn.preprocessing import LabelEncoder |
|
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from tqdm import tqdm |
|
import os |
|
import logging |
|
|
|
|
|
# Log to both a file and the console. The file handler is opened as UTF-8 so
# that non-ASCII text (model names are echoed into the classification reports
# that get logged) is written the same way on every platform instead of
# depending on the locale's default encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('model_training.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

# Directory that receives plots, CSV reports and checkpoints.
os.makedirs('output', exist_ok=True)
|
|
|
|
|
def _has_model_name(example):
    """Return True when the row's ``Model`` field is present and not blank."""
    model_name = example['Model']
    if model_name is None:
        return False
    return bool(model_name.strip())


# Load the first 10k training rows and keep only those labelled with a model.
logging.info("Loading and filtering dataset...")
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
dataset = dataset.filter(_has_model_name)

if not len(dataset):
    raise ValueError("Dataset is empty after filtering!")

logging.info(f"Dataset size after filtering: {len(dataset)}")

# Shared by CustomDataset (training) and predict (inference).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
|
class CustomDataset(Dataset):
    """Expose dataset rows as ``(image_tensor, tokenized_prompt, label_id)`` samples.

    Each row must provide ``image``, ``prompt`` and ``Model``. Model names are
    integer-encoded by a LabelEncoder fitted on the whole dataset, so label ids
    stay consistent across any later train/val/test split; ``unique_models[i]``
    is the name for label id ``i``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Resize + ImageNet normalisation, matching the pretrained ResNet input.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])
        self.unique_models = self.label_encoder.classes_

        logging.info(f"Number of unique models: {len(self.unique_models)}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the transformed image, tokenizer output (tensors of shape (1, 512)) and label id."""
        try:
            # Fetch the row once rather than indexing the dataset per column.
            row = self.dataset[idx]
            # Assumes ``image`` is a PIL image. RGBA, palette and greyscale
            # images must become 3-channel RGB, or the 3-value Normalize fails.
            image = self.transform(row['image'].convert('RGB'))
            # A missing prompt would make the tokenizer raise; use empty text.
            prompt = row['prompt'] or ''
            text = tokenizer(
                prompt,
                padding='max_length',
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )
            label = self.labels[idx]
            return image, text, label
        except Exception as e:
            logging.error(f"Error processing item {idx}: {str(e)}")
            raise
|
|
|
class ImageModel(nn.Module):
    """ImageNet-pretrained ResNet-18 whose head is replaced by a 512-d ReLU projection."""

    def __init__(self):
        super(ImageModel, self).__init__()
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum; ``ResNet18_Weights.DEFAULT`` selects the same
        # ImageNet weights. Fall back to the old flag on releases without it.
        resnet_weights = getattr(models, 'ResNet18_Weights', None)
        if resnet_weights is not None:
            self.model = models.resnet18(weights=resnet_weights.DEFAULT)
        else:
            self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        """Map an image batch ``x`` to non-negative 512-d features."""
        return nn.functional.relu(self.model(x))
|
|
|
class TextModel(nn.Module):
    """Encode tokenized prompts with BERT and project the pooled output to 512 features."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the hidden size of bert-base.
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        """Run BERT on the tokenizer mapping ``x`` and return ReLU'd 512-d features."""
        pooled = self.bert(**x).pooler_output
        projected = self.fc(pooled)
        return nn.functional.relu(projected)
|
|
|
class CombinedModel(nn.Module):
    """Classify by concatenating 512-d image and 512-d text features (1024 total)."""

    def __init__(self, num_classes):
        super().__init__()
        # Submodules are registered in the same order as before, so
        # state_dict keys and parameter order are unchanged.
        self.image_model = ImageModel()
        self.text_model = TextModel()
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class logits for an image batch and its tokenized prompts."""
        fused = torch.cat(
            [self.image_model(image), self.text_model(text)],
            dim=1,
        )
        return self.fc(self.dropout(fused))
|
|
|
class ModelTrainerEvaluator:
    """Train, validate and test a classifier on a random 70/15/15 split of a dataset.

    The dataset must yield ``(image, text, label)`` triples, where ``text`` is
    a mapping of tokenizer tensors (as produced by ``CustomDataset``). It must
    also expose ``unique_models``, the class names indexed by label id. Plots,
    CSV reports and the best checkpoint are written to ``output_dir``.
    """

    def __init__(self, model, dataset, batch_size=32, learning_rate=0.001,
                 num_workers=4, seed=None, output_dir='output'):
        """Move ``model`` to the best device and build optimiser, scheduler and loaders.

        Args:
            model: Module called as ``model(images, texts)`` that returns class logits.
            dataset: Sample source as described in the class docstring.
            batch_size: Batch size used by all three loaders.
            learning_rate: Initial AdamW learning rate.
            num_workers: DataLoader worker processes per loader.
            seed: Optional seed for the train/val/test split. ``None`` uses
                PyTorch's global RNG (a different split on every run).
            output_dir: Directory for plots, reports and the checkpoint.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logging.info(f"Using device: {self.device}")

        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        # ``verbose=`` is deprecated on PyTorch LR schedulers; learning-rate
        # reductions are logged explicitly by ``_step_scheduler`` instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.1,
            patience=2,
        )

        total_size = len(dataset)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size

        split_kwargs = {}
        if seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

        self.train_loader = self._make_loader(train_dataset, True, num_workers)
        self.val_loader = self._make_loader(val_dataset, False, num_workers)
        self.test_loader = self._make_loader(test_dataset, False, num_workers)

        self.unique_models = dataset.unique_models

    def _make_loader(self, subset, shuffle, num_workers):
        """Build a DataLoader over ``subset`` with the trainer's batch size."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def _batch_to_device(self, batch):
        """Unpack an ``(images, texts, labels)`` batch and move it to ``self.device``."""
        images, texts, labels = batch
        # CustomDataset tokenizes one prompt at a time with return_tensors='pt',
        # so after collation each tensor is (batch, 1, seq_len). Drop the
        # singleton axis before handing the mapping to BERT.
        texts = {key: value.squeeze(1).to(self.device) for key, value in texts.items()}
        return images.to(self.device), texts, labels.to(self.device)

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or NaN when no batch completed."""
        return total / count if count else float('nan')

    def _output_path(self, title, suffix):
        """Map a human-readable ``title`` to a file path inside ``output_dir``."""
        return os.path.join(self.output_dir, f'{title.lower().replace(" ", "_")}{suffix}')

    def train_epoch(self):
        """Run one optimisation pass over the training split.

        Batches that raise are logged and skipped (best effort).

        Returns:
            ``(mean_loss, predictions, actual_labels)``. ``mean_loss`` is
            averaged over the batches that completed (NaN if none did). The
            two lists hold per-sample predicted and true label ids for those
            batches.
        """
        self.model.train()
        total_loss = 0.0
        completed = 0
        predictions, actual_labels = [], []

        progress_bar = tqdm(self.train_loader, desc="Training")
        for batch_idx, batch in enumerate(progress_bar):
            try:
                images, texts, labels = self._batch_to_device(batch)
                self.optimizer.zero_grad()
                outputs = self.model(images, texts)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                batch_loss = loss.item()
            except Exception as e:
                logging.exception(f"Error in batch {batch_idx}: {e}")
                continue

            total_loss += batch_loss
            completed += 1
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            actual_labels.extend(labels.cpu().numpy())
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg_loss': f'{total_loss / completed:.4f}'
            })

        return self._mean(total_loss, completed), predictions, actual_labels

    def evaluate(self, loader, mode="Validation"):
        """Compute loss and predictions on ``loader`` without updating weights.

        Returns ``(mean_loss, predictions, actual_labels)`` in the same form
        as ``train_epoch``. ``mode`` only labels the progress bar and log lines.
        """
        self.model.eval()
        total_loss = 0.0
        completed = 0
        predictions, actual_labels = [], []

        with torch.no_grad():
            progress_bar = tqdm(loader, desc=mode)
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    images, texts, labels = self._batch_to_device(batch)
                    outputs = self.model(images, texts)
                    batch_loss = self.criterion(outputs, labels).item()
                except Exception as e:
                    logging.exception(f"Error in {mode} batch {batch_idx}: {e}")
                    continue

                total_loss += batch_loss
                completed += 1
                predictions.extend(outputs.argmax(dim=1).cpu().numpy())
                actual_labels.extend(labels.cpu().numpy())
                progress_bar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'avg_loss': f'{total_loss / completed:.4f}'
                })

        return self._mean(total_loss, completed), predictions, actual_labels

    def plot_confusion_matrix(self, y_true, y_pred, title):
        """Save a confusion-matrix heatmap of ``y_true`` vs ``y_pred`` as a PNG named after ``title``."""
        if len(y_true) == 0:
            logging.warning(f"No predictions available; skipping confusion matrix '{title}'")
            return

        cm = confusion_matrix(y_true, y_pred)
        fig = plt.figure(figsize=(15, 15))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
            plt.title(title)
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            filename = self._output_path(title, '.png')
            plt.savefig(filename)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        logging.info(f"Saved confusion matrix to {filename}")

    def generate_evaluation_report(self, y_true, y_pred, title):
        """Log and save a per-class classification report for one split.

        Returns:
            ``(accuracy, df_report)``. For empty input this is
            ``(nan, empty DataFrame)`` and nothing is written.
        """
        if len(y_true) == 0:
            logging.warning(f"No predictions available; skipping report '{title}'")
            return float('nan'), pd.DataFrame()

        # Pin the label set to every known class. A split rarely contains
        # all of them, and without ``labels=`` sklearn rejects a
        # ``target_names`` whose length differs from the classes present.
        label_ids = np.arange(len(self.unique_models))
        target_names = [str(name) for name in self.unique_models]
        report_kwargs = dict(labels=label_ids, target_names=target_names, zero_division=0)

        report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
        df_report = pd.DataFrame(report).transpose()

        filename = self._output_path(title, '_report.csv')
        df_report.to_csv(filename)
        logging.info(f"Saved classification report to {filename}")

        accuracy = accuracy_score(y_true, y_pred)

        logging.info(f"\n{title} Results:")
        logging.info(f"Accuracy: {accuracy:.4f}")
        logging.info("\nClassification Report:")
        logging.info("\n" + classification_report(y_true, y_pred, **report_kwargs))

        return accuracy, df_report

    def _step_scheduler(self, val_loss):
        """Feed ``val_loss`` to the plateau scheduler and log any learning-rate change."""
        if np.isnan(val_loss):
            # NaN means no validation batch completed; there is no signal to act on.
            logging.warning("Validation loss is NaN; skipping LR scheduler step.")
            return
        previous_lrs = [group['lr'] for group in self.optimizer.param_groups]
        self.scheduler.step(val_loss)
        for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, self.optimizer.param_groups)):
            if group['lr'] != old_lr:
                logging.info(
                    f"Reducing learning rate of group {group_idx} "
                    f"from {old_lr:.4e} to {group['lr']:.4e}"
                )

    def _plot_history(self, train_accuracies, val_accuracies, train_losses, val_losses):
        """Save side-by-side accuracy and loss curves to ``training_history.png``."""
        fig = plt.figure(figsize=(12, 4))
        try:
            plt.subplot(1, 2, 1)
            plt.plot(train_accuracies, label='Training Accuracy')
            plt.plot(val_accuracies, label='Validation Accuracy')
            plt.title('Model Accuracy over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()

            plt.subplot(1, 2, 2)
            plt.plot(train_losses, label='Training Loss')
            plt.plot(val_losses, label='Validation Loss')
            plt.title('Model Loss over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()

            plt.tight_layout()
            plt.savefig(os.path.join(self.output_dir, 'training_history.png'))
        finally:
            plt.close(fig)

    def train_and_evaluate(self, num_epochs=5):
        """Train for ``num_epochs``, keep the best-validation checkpoint, then test it.

        Each epoch writes reports and confusion matrices for the train and
        validation splits. A training-history plot and the final test report
        follow. If no epoch improved on an infinite validation loss, no
        checkpoint is saved and the test set is evaluated with the final
        in-memory weights.
        """
        best_val_loss = float('inf')
        checkpoint_path = os.path.join(self.output_dir, 'best_model.pth')
        saved_checkpoint = False
        train_accuracies, val_accuracies = [], []
        train_losses, val_losses = [], []

        logging.info(f"Starting training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            logging.info(f"\nEpoch {epoch+1}/{num_epochs}")

            train_loss, train_preds, train_labels = self.train_epoch()
            train_accuracy, _ = self.generate_evaluation_report(
                train_labels, train_preds, f"Training_Epoch_{epoch+1}"
            )
            self.plot_confusion_matrix(
                train_labels, train_preds, f"Training_Confusion_Matrix_Epoch_{epoch+1}"
            )

            val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
            val_accuracy, _ = self.generate_evaluation_report(
                val_labels, val_preds, f"Validation_Epoch_{epoch+1}"
            )
            self.plot_confusion_matrix(
                val_labels, val_preds, f"Validation_Confusion_Matrix_Epoch_{epoch+1}"
            )

            self._step_scheduler(val_loss)

            train_accuracies.append(train_accuracy)
            val_accuracies.append(val_accuracy)
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            logging.info(f"\nTraining Loss: {train_loss:.4f}")
            logging.info(f"Validation Loss: {val_loss:.4f}")

            # NaN compares False here, so epochs without a completed
            # validation batch never overwrite the best checkpoint.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                }, checkpoint_path)
                saved_checkpoint = True
                logging.info(f"Saved new best model with validation loss: {val_loss:.4f}")

        self._plot_history(train_accuracies, val_accuracies, train_losses, val_losses)

        logging.info("\nPerforming final evaluation on test set...")
        if saved_checkpoint:
            # Only reload a checkpoint written by this run. A stale file from
            # an earlier run could belong to a model with a different class count.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logging.warning("No best checkpoint was saved during training; evaluating the final weights.")

        test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
        self.generate_evaluation_report(test_labels, test_preds, "Final_Test")
        self.plot_confusion_matrix(test_labels, test_preds, "Final_Test_Confusion_Matrix")
|
|
|
def predict(image, model=None, class_names=None, prompt="Sample prompt", top_k=5):
    """Return up to ``top_k`` recommended model names for ``image``.

    Args:
        image: PIL image (as delivered by ``gr.Image(type="pil")``), or None
            when nothing was uploaded.
        model: Trained module called as ``model(images, texts)`` -> logits.
        class_names: Class names indexed by label id, e.g.
            ``CustomDataset.unique_models``.
        prompt: Text paired with the image. The UI collects no prompt, so a
            placeholder is used by default.
        top_k: Number of recommendations, capped at the number of classes.

    Returns:
        List of class-name strings, best first; empty when ``image`` is None.

    Raises:
        ValueError: if ``model`` or ``class_names`` is not supplied.
    """
    if model is None or class_names is None:
        raise ValueError("predict() requires the trained model and its class names")
    if image is None:
        return []

    device = next(model.parameters()).device
    # Same preprocessing as CustomDataset, so inference matches training.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    model.eval()
    with torch.no_grad():
        pixels = preprocess(image.convert('RGB')).unsqueeze(0).to(device)
        encoded = tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        )
        text_input = {key: tensor.to(device) for key, tensor in encoded.items()}
        logits = model(pixels, text_input)
        k = min(top_k, logits.shape[1])
        _, indices = torch.topk(logits, k, dim=1)

    # Map label ids back through the label encoder's classes, not dataset rows.
    return [str(class_names[i]) for i in indices[0].tolist()]


def main():
    """Build the dataset and model, train and evaluate, then serve a Gradio demo."""
    try:
        logging.info("Creating custom dataset...")
        custom_dataset = CustomDataset(dataset)

        logging.info("Initializing model...")
        model = CombinedModel(num_classes=len(custom_dataset.unique_models))

        logging.info("Setting up trainer/evaluator...")
        trainer = ModelTrainerEvaluator(
            model=model,
            dataset=custom_dataset,
            batch_size=32,
            learning_rate=0.001
        )

        logging.info("Starting training process...")
        trainer.train_and_evaluate(num_epochs=5)

        class_names = custom_dataset.unique_models

        def recommend(image):
            """Gradio callback: format the recommendations one per line for the Textbox."""
            return "\n".join(predict(image, model=model, class_names=class_names))

        # Only offer example images that actually exist on disk.
        example_paths = []
        for path in ("example_image1.jpg", "example_image2.jpg"):
            if os.path.exists(path):
                example_paths.append([path])
            else:
                logging.warning(f"Example image not found, skipping: {path}")

        logging.info("Setting up Gradio interface...")
        interface = gr.Interface(
            fn=recommend,
            inputs=gr.Image(type="pil"),
            outputs=gr.Textbox(label="Recommended Models"),
            title="AI Image Model Recommender",
            description="Upload an AI-generated image to receive model recommendations.",
            examples=example_paths or None,
            analytics_enabled=False
        )

        logging.info("Launching Gradio interface...")
        interface.launch(share=True)

    except Exception as e:
        logging.error(f"Error in main function: {str(e)}")
        raise
|
|
|
if __name__ == "__main__":
    # Script entry point: log a fatal error and re-raise so the exit status
    # reflects the failure; a Ctrl-C is logged and exits quietly.
    try:
        main()
    except Exception as err:
        logging.error(f"Fatal error: {err}")
        raise
    except KeyboardInterrupt:
        # KeyboardInterrupt derives from BaseException, not Exception, so it
        # is never caught by the clause above.
        logging.info("Process interrupted by user")
|
|