"""Train an image + prompt classifier that recommends the Stable Diffusion
model ('Model' column) used to create an image, then serve it with Gradio.

Pipeline: load and filter a slice of the CivitAI dataset, label-encode the
'Model' column, train a ResNet-18 + BERT fusion classifier, write reports and
plots to ``output/``, and launch a Gradio interface around the trained model.
"""
import logging
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('model_training.log'),
        logging.StreamHandler()
    ]
)

# Create output directory for results
OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

DATASET_NAME = 'thefcraft/civitai-stable-diffusion-337k'
DATASET_SPLIT = 'train[:10000]'
BERT_NAME = 'bert-base-uncased'
# Channel statistics the pretrained torchvision ImageNet weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def load_filtered_dataset(name=DATASET_NAME, split=DATASET_SPLIT):
    """Load a dataset split and drop rows whose 'Model' label is missing or blank.

    Raises:
        ValueError: if no rows survive the filter.
    """
    logging.info("Loading and filtering dataset...")
    loaded = load_dataset(name, split=split)
    loaded = loaded.filter(
        lambda example: example['Model'] is not None and example['Model'].strip() != ''
    )
    if len(loaded) == 0:
        raise ValueError("Dataset is empty after filtering!")
    logging.info(f"Dataset size after filtering: {len(loaded)}")
    return loaded


# Module-level for backward compatibility: other code may import these names.
dataset = load_filtered_dataset()

# Preprocess text data
tokenizer = BertTokenizer.from_pretrained(BERT_NAME)


def build_image_transform():
    """Return the image preprocessing shared by training and inference.

    Resize to 224x224, convert to a tensor, then apply ImageNet normalisation.
    Using one definition keeps ``predict`` consistent with training.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def _output_path(title, suffix):
    """Build ``output/<title lower-cased, spaces as underscores><suffix>``."""
    slug = title.lower().replace(" ", "_")
    return os.path.join(OUTPUT_DIR, f'{slug}{suffix}')


class CustomDataset(Dataset):
    """Wrap the filtered HF dataset as (image tensor, tokenized prompt, label) items.

    Labels are integer-encoded from the 'Model' column. ``unique_models[i]`` is
    the model name for class id ``i``.
    """

    def __init__(self, dataset, max_length=512):
        self.dataset = dataset
        self.max_length = max_length
        self.transform = build_image_transform()
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])
        self.unique_models = self.label_encoder.classes_
        logging.info(f"Number of unique models: {len(self.unique_models)}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            # Read the row once; each indexed access decodes the row again.
            row = self.dataset[idx]
            # Assumes the 'image' column decodes to PIL images. Convert to RGB
            # because RGBA/greyscale inputs break the 3-channel Normalize.
            image = self.transform(row['image'].convert('RGB'))
            text = tokenizer(
                row['prompt'] or '',  # tolerate missing prompts
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt'
            )
            label = self.labels[idx]
            return image, text, label
        except Exception as e:
            logging.error(f"Error processing item {idx}: {str(e)}")
            raise


class ImageModel(nn.Module):
    """ResNet-18 backbone with its classifier replaced by a 512-d ReLU projection."""

    def __init__(self):
        super().__init__()
        # Same ImageNet weights as the deprecated ``pretrained=True`` flag.
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return nn.functional.relu(self.model(x))


class TextModel(nn.Module):
    """BERT pooled output projected to 512-d with ReLU."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_NAME)
        self.fc = nn.Linear(self.bert.config.hidden_size, 512)

    def forward(self, x):
        # x is the tokenizer mapping (input_ids, attention_mask, ...).
        outputs = self.bert(**x)
        return nn.functional.relu(self.fc(outputs.pooler_output))


class CombinedModel(nn.Module):
    """Concatenate image and text features (512 + 512) and classify into num_classes."""

    def __init__(self, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextModel()
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(self.dropout(combined))


class ModelTrainerEvaluator:
    """Split a CustomDataset 70/15/15, train the model, and write reports and plots."""

    def __init__(self, model, dataset, batch_size=32, learning_rate=0.001,
                 num_workers=4, seed=None):
        """
        Args:
            model: network whose forward takes (images, text mapping).
            dataset: a CustomDataset (``unique_models`` is read from it).
            batch_size, learning_rate: loader and AdamW settings.
            num_workers: DataLoader worker processes.
            seed: optional seed for a reproducible train/val/test split.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logging.info(f"Using device: {self.device}")
        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        # ``verbose`` is deprecated in recent PyTorch; LR reductions are logged
        # explicitly in train_and_evaluate instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.1,
            patience=2
        )

        # Split dataset
        total_size = len(dataset)
        train_size = int(0.7 * total_size)
        val_size = int(0.15 * total_size)
        test_size = total_size - train_size - val_size
        split_kwargs = {}
        if seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )
        logging.info(
            f"Split sizes - train: {train_size}, val: {val_size}, test: {test_size}"
        )

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, num_workers=num_workers
        )

        self.unique_models = dataset.unique_models

    def _label_ids(self):
        """Every class id. Passed explicitly to sklearn so splits missing a class still work."""
        return np.arange(len(self.unique_models))

    @staticmethod
    def _mean_loss(total, count):
        """Average over processed batches; NaN (not 0.0) when nothing was processed."""
        return total / count if count else float('nan')

    def _move_batch(self, batch):
        """Move a collated (images, texts, labels) batch onto ``self.device``."""
        images, texts, labels = batch
        images = images.to(self.device)
        labels = labels.to(self.device)
        # Per-item tokenizer output is (1, seq_len), so collation yields
        # (batch, 1, seq_len). Drop the singleton dimension.
        texts = {k: v.squeeze(1).to(self.device) for k, v in texts.items()}
        return images, texts, labels

    def train_epoch(self):
        """Run one training epoch.

        Returns:
            (mean loss over successful batches, predicted ids, true ids).
            Batches that raise are logged and skipped.
        """
        self.model.train()
        total_loss = 0.0
        processed = 0
        predictions = []
        actual_labels = []

        progress_bar = tqdm(self.train_loader, desc="Training")
        for batch_idx, batch in enumerate(progress_bar):
            try:
                images, texts, labels = self._move_batch(batch)

                self.optimizer.zero_grad()
                outputs = self.model(images, texts)
                loss = self.criterion(outputs, labels)
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.optimizer.step()

                total_loss += loss.item()
                processed += 1
                preds = outputs.argmax(dim=1)
                predictions.extend(preds.cpu().numpy())
                actual_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'avg_loss': f'{total_loss / processed:.4f}'
                })
            except Exception as e:
                logging.error(f"Error in batch {batch_idx}: {str(e)}")
                continue

        return self._mean_loss(total_loss, processed), predictions, actual_labels

    def evaluate(self, loader, mode="Validation"):
        """Evaluate on ``loader`` without gradients.

        Returns:
            (mean loss over successful batches, predicted ids, true ids).
        """
        self.model.eval()
        total_loss = 0.0
        processed = 0
        predictions = []
        actual_labels = []

        with torch.no_grad():
            progress_bar = tqdm(loader, desc=mode)
            for batch_idx, batch in enumerate(progress_bar):
                try:
                    images, texts, labels = self._move_batch(batch)

                    outputs = self.model(images, texts)
                    loss = self.criterion(outputs, labels)

                    total_loss += loss.item()
                    processed += 1
                    preds = outputs.argmax(dim=1)
                    predictions.extend(preds.cpu().numpy())
                    actual_labels.extend(labels.cpu().numpy())

                    progress_bar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'avg_loss': f'{total_loss / processed:.4f}'
                    })
                except Exception as e:
                    logging.error(f"Error in {mode} batch {batch_idx}: {str(e)}")
                    continue

        return self._mean_loss(total_loss, processed), predictions, actual_labels

    def plot_confusion_matrix(self, y_true, y_pred, title):
        """Save a confusion-matrix heatmap covering every class to output/<title>.png."""
        if len(y_true) == 0:
            logging.warning(f"No predictions collected for {title}; skipping confusion matrix.")
            return
        cm = confusion_matrix(y_true, y_pred, labels=self._label_ids())
        filename = _output_path(title, '.png')
        fig = plt.figure(figsize=(15, 15))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
            plt.title(title)
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.savefig(filename)
        finally:
            # Close even if saving fails; leaked figures pile up across epochs.
            plt.close(fig)
        logging.info(f"Saved confusion matrix to {filename}")

    def generate_evaluation_report(self, y_true, y_pred, title):
        """Write a per-class report CSV and log accuracy plus the text report.

        Returns:
            (accuracy, report DataFrame). Returns (NaN, empty DataFrame) when
            there are no predictions.
        """
        if len(y_true) == 0:
            logging.warning(f"No predictions collected for {title}; skipping report.")
            return float('nan'), pd.DataFrame()

        # Explicit labels let a split that lacks some classes still match the
        # full target_names list. Without them sklearn raises ValueError.
        report_kwargs = dict(
            labels=self._label_ids(),
            target_names=[str(name) for name in self.unique_models],
            zero_division=0,
        )
        report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
        df_report = pd.DataFrame(report).transpose()

        filename = _output_path(title, '_report.csv')
        df_report.to_csv(filename)
        logging.info(f"Saved classification report to {filename}")

        accuracy = accuracy_score(y_true, y_pred)
        logging.info(f"\n{title} Results:")
        logging.info(f"Accuracy: {accuracy:.4f}")
        logging.info("\nClassification Report:")
        logging.info("\n" + classification_report(y_true, y_pred, **report_kwargs))

        return accuracy, df_report

    def _plot_training_history(self, history):
        """Save accuracy/loss curves per epoch to output/training_history.png."""
        fig = plt.figure(figsize=(12, 4))
        try:
            plt.subplot(1, 2, 1)
            plt.plot(history['train_accuracy'], label='Training Accuracy')
            plt.plot(history['val_accuracy'], label='Validation Accuracy')
            plt.title('Model Accuracy over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy')
            plt.legend()

            plt.subplot(1, 2, 2)
            plt.plot(history['train_loss'], label='Training Loss')
            plt.plot(history['val_loss'], label='Validation Loss')
            plt.title('Model Loss over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()

            plt.tight_layout()
            plt.savefig(os.path.join(OUTPUT_DIR, 'training_history.png'))
        finally:
            plt.close(fig)

    def train_and_evaluate(self, num_epochs=5):
        """Train for ``num_epochs`` epochs, checkpoint the best model by
        validation loss, then evaluate that checkpoint on the test split."""
        best_val_loss = float('inf')
        best_path = os.path.join(OUTPUT_DIR, 'best_model.pth')
        # Only reload a checkpoint written by THIS run, never a stale file.
        saved_best = False
        history = {'train_accuracy': [], 'val_accuracy': [], 'train_loss': [], 'val_loss': []}

        logging.info(f"Starting training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            logging.info(f"\nEpoch {epoch+1}/{num_epochs}")

            # Training
            train_loss, train_preds, train_labels = self.train_epoch()
            train_accuracy, _ = self.generate_evaluation_report(
                train_labels, train_preds, f"Training_Epoch_{epoch+1}"
            )
            self.plot_confusion_matrix(
                train_labels, train_preds, f"Training_Confusion_Matrix_Epoch_{epoch+1}"
            )

            # Validation
            val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
            val_accuracy, _ = self.generate_evaluation_report(
                val_labels, val_preds, f"Validation_Epoch_{epoch+1}"
            )
            self.plot_confusion_matrix(
                val_labels, val_preds, f"Validation_Confusion_Matrix_Epoch_{epoch+1}"
            )

            # Update learning rate scheduler (logs the change itself)
            previous_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                logging.info(f"Reduced learning rate from {previous_lr:.2e} to {current_lr:.2e}")

            history['train_accuracy'].append(train_accuracy)
            history['val_accuracy'].append(val_accuracy)
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)

            logging.info(f"\nTraining Loss: {train_loss:.4f}")
            logging.info(f"Validation Loss: {val_loss:.4f}")

            # Save best model (a NaN loss never compares less, so it is never saved)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_loss': val_loss,
                }, best_path)
                saved_best = True
                logging.info(f"Saved new best model with validation loss: {val_loss:.4f}")

        self._plot_training_history(history)

        # Final test evaluation using best model
        logging.info("\nPerforming final evaluation on test set...")
        if saved_best:
            checkpoint = torch.load(best_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logging.warning(
                "No checkpoint was saved during this run; evaluating the final weights instead."
            )
        test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
        logging.info(f"Test Loss: {test_loss:.4f}")
        self.generate_evaluation_report(test_labels, test_preds, "Final_Test")
        self.plot_confusion_matrix(test_labels, test_preds, "Final_Test_Confusion_Matrix")


# Registered by main() after training, so the Gradio callback can reach the model.
_inference_model = None
_inference_class_names = None


def predict(image, model=None, class_names=None, prompt="Sample prompt", top_k=5):
    """Return up to ``top_k`` recommended model names for a PIL ``image``.

    Args:
        image: PIL image, e.g. from ``gr.Image(type="pil")``. ``None`` yields [].
        model: trained CombinedModel; defaults to the one registered by main().
        class_names: names indexed by class id (``CustomDataset.unique_models``);
            defaults to the list registered by main().
        prompt: text fed to the BERT branch. Inference has no real prompt, so
            this is a placeholder.
        top_k: maximum number of recommendations.

    Raises:
        RuntimeError: if no model or class names are available.
    """
    model = model if model is not None else _inference_model
    class_names = class_names if class_names is not None else _inference_class_names
    if model is None or class_names is None:
        raise RuntimeError("No trained model registered; run main() before predict().")
    if image is None:
        return []

    # Inputs must live on whatever device the model was trained on.
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        image_tensor = build_image_transform()(image.convert('RGB')).unsqueeze(0).to(device)
        text_input = tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        text_input = {k: v.to(device) for k, v in text_input.items()}
        output = model(image_tensor, text_input)
        # topk fails if k exceeds the number of classes.
        _, indices = torch.topk(output, min(top_k, output.size(1)))

    # Indices are class ids, so map them through the label encoder's classes.
    return [str(class_names[i]) for i in indices[0].tolist()]


def main():
    """Build the dataset and model, train and evaluate, then launch the Gradio UI."""
    global _inference_model, _inference_class_names
    try:
        # Create dataset
        logging.info("Creating custom dataset...")
        custom_dataset = CustomDataset(dataset)

        # Create model
        logging.info("Initializing model...")
        model = CombinedModel(num_classes=len(custom_dataset.unique_models))

        # Create trainer/evaluator
        # NOTE(review): 1e-3 is high for fine-tuning BERT (2e-5 to 5e-5 is typical);
        # kept unchanged to preserve existing behaviour.
        logging.info("Setting up trainer/evaluator...")
        trainer = ModelTrainerEvaluator(
            model=model,
            dataset=custom_dataset,
            batch_size=32,
            learning_rate=0.001
        )

        # Train and evaluate
        logging.info("Starting training process...")
        trainer.train_and_evaluate(num_epochs=5)

        # Make the trained model reachable from predict().
        _inference_model = trainer.model
        _inference_class_names = list(custom_dataset.unique_models)

        def recommend(image):
            """Gradio callback: one recommendation per line."""
            return "\n".join(predict(image))

        # Offer only example files that exist, because missing files break Gradio examples.
        example_paths = ["example_image1.jpg", "example_image2.jpg"]
        examples = [[path] for path in example_paths if os.path.exists(path)] or None

        # Create Gradio interface
        logging.info("Setting up Gradio interface...")
        interface = gr.Interface(
            fn=recommend,
            inputs=gr.Image(type="pil"),
            outputs=gr.Textbox(label="Recommended Models"),
            title="AI Image Model Recommender",
            description="Upload an AI-generated image to receive model recommendations.",
            examples=examples,
            analytics_enabled=False
        )

        # Launch the interface
        logging.info("Launching Gradio interface...")
        interface.launch(share=True)

    except Exception as e:
        logging.error(f"Error in main function: {str(e)}")
        raise


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        logging.info("Process interrupted by user")
    except Exception as e:
        logging.error(f"Fatal error: {str(e)}")
        raise