bgaspra commited on
Commit
67797ef
·
verified ·
1 Parent(s): ba5420b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -369
app.py CHANGED
@@ -6,39 +6,14 @@ from torchvision import models
6
  from transformers import BertTokenizer, BertModel
7
  import pandas as pd
8
  from datasets import load_dataset
9
- from torch.utils.data import DataLoader, Dataset, random_split
10
  from sklearn.preprocessing import LabelEncoder
11
- from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
12
- import seaborn as sns
13
- import matplotlib.pyplot as plt
14
- import numpy as np
15
- from tqdm import tqdm
16
- import os
17
- import logging
18
-
19
- # Set up logging
20
- logging.basicConfig(
21
- level=logging.INFO,
22
- format='%(asctime)s - %(levelname)s - %(message)s',
23
- handlers=[
24
- logging.FileHandler('model_training.log'),
25
- logging.StreamHandler()
26
- ]
27
- )
28
-
29
- # Create output directory for results
30
- os.makedirs('output', exist_ok=True)
31
 
32
  # Load dataset and filter out null/none values
33
- logging.info("Loading and filtering dataset...")
34
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
 
35
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
36
 
37
- if len(dataset) == 0:
38
- raise ValueError("Dataset is empty after filtering!")
39
-
40
- logging.info(f"Dataset size after filtering: {len(dataset)}")
41
-
42
  # Preprocess text data
43
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
44
 
@@ -48,329 +23,63 @@ class CustomDataset(Dataset):
48
  self.transform = transforms.Compose([
49
  transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
51
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
52
- std=[0.229, 0.224, 0.225])
53
  ])
54
  self.label_encoder = LabelEncoder()
55
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
56
- self.unique_models = self.label_encoder.classes_
57
-
58
- logging.info(f"Number of unique models: {len(self.unique_models)}")
59
 
60
  def __len__(self):
61
  return len(self.dataset)
62
 
63
  def __getitem__(self, idx):
64
- try:
65
- image = self.transform(self.dataset[idx]['image'])
66
- text = tokenizer(
67
- self.dataset[idx]['prompt'],
68
- padding='max_length',
69
- truncation=True,
70
- max_length=512,
71
- return_tensors='pt'
72
- )
73
- label = self.labels[idx]
74
- return image, text, label
75
- except Exception as e:
76
- logging.error(f"Error processing item {idx}: {str(e)}")
77
- raise
78
 
 
79
  class ImageModel(nn.Module):
80
  def __init__(self):
81
  super(ImageModel, self).__init__()
82
  self.model = models.resnet18(pretrained=True)
83
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
84
-
85
  def forward(self, x):
86
- x = self.model(x)
87
- return nn.functional.relu(x)
88
 
 
89
  class TextModel(nn.Module):
90
  def __init__(self):
91
  super(TextModel, self).__init__()
92
  self.bert = BertModel.from_pretrained('bert-base-uncased')
93
  self.fc = nn.Linear(768, 512)
94
-
95
  def forward(self, x):
96
- outputs = self.bert(**x)
97
- x = outputs.pooler_output
98
- x = self.fc(x)
99
- return nn.functional.relu(x)
100
 
 
101
  class CombinedModel(nn.Module):
102
- def __init__(self, num_classes):
103
  super(CombinedModel, self).__init__()
104
  self.image_model = ImageModel()
105
  self.text_model = TextModel()
106
- self.dropout = nn.Dropout(0.2)
107
- self.fc = nn.Linear(1024, num_classes)
108
-
109
  def forward(self, image, text):
110
  image_features = self.image_model(image)
111
  text_features = self.text_model(text)
112
  combined = torch.cat((image_features, text_features), dim=1)
113
- combined = self.dropout(combined)
114
  return self.fc(combined)
115
 
116
- class ModelTrainerEvaluator:
117
- def __init__(self, model, dataset, batch_size=32, learning_rate=0.001):
118
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
119
- logging.info(f"Using device: {self.device}")
120
-
121
- self.model = model.to(self.device)
122
- self.batch_size = batch_size
123
- self.criterion = nn.CrossEntropyLoss()
124
- self.optimizer = torch.optim.AdamW(
125
- model.parameters(),
126
- lr=learning_rate,
127
- weight_decay=0.01
128
- )
129
- self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
130
- self.optimizer,
131
- mode='min',
132
- factor=0.1,
133
- patience=2,
134
- verbose=True
135
- )
136
-
137
- # Split dataset
138
- total_size = len(dataset)
139
- train_size = int(0.7 * total_size)
140
- val_size = int(0.15 * total_size)
141
- test_size = total_size - train_size - val_size
142
-
143
- train_dataset, val_dataset, test_dataset = random_split(
144
- dataset, [train_size, val_size, test_size]
145
- )
146
-
147
- self.train_loader = DataLoader(
148
- train_dataset,
149
- batch_size=batch_size,
150
- shuffle=True,
151
- num_workers=4
152
- )
153
- self.val_loader = DataLoader(
154
- val_dataset,
155
- batch_size=batch_size,
156
- num_workers=4
157
- )
158
- self.test_loader = DataLoader(
159
- test_dataset,
160
- batch_size=batch_size,
161
- num_workers=4
162
- )
163
-
164
- self.unique_models = dataset.unique_models
165
-
166
- def train_epoch(self):
167
- self.model.train()
168
- total_loss = 0
169
- predictions = []
170
- actual_labels = []
171
-
172
- progress_bar = tqdm(self.train_loader, desc="Training")
173
- for batch_idx, batch in enumerate(progress_bar):
174
- try:
175
- images, texts, labels = batch
176
- images = images.to(self.device)
177
- labels = labels.to(self.device)
178
-
179
- # Move text tensors to device
180
- texts = {k: v.squeeze(1).to(self.device) for k, v in texts.items()}
181
-
182
- self.optimizer.zero_grad()
183
- outputs = self.model(images, texts)
184
- loss = self.criterion(outputs, labels)
185
-
186
- loss.backward()
187
- # Gradient clipping
188
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
189
- self.optimizer.step()
190
-
191
- total_loss += loss.item()
192
-
193
- _, preds = torch.max(outputs, 1)
194
- predictions.extend(preds.cpu().numpy())
195
- actual_labels.extend(labels.cpu().numpy())
196
-
197
- # Update progress bar
198
- progress_bar.set_postfix({
199
- 'loss': f'{loss.item():.4f}',
200
- 'avg_loss': f'{total_loss/(batch_idx+1):.4f}'
201
- })
202
-
203
- except Exception as e:
204
- logging.error(f"Error in batch {batch_idx}: {str(e)}")
205
- continue
206
-
207
- return total_loss / len(self.train_loader), predictions, actual_labels
208
-
209
- def evaluate(self, loader, mode="Validation"):
210
- self.model.eval()
211
- total_loss = 0
212
- predictions = []
213
- actual_labels = []
214
-
215
- with torch.no_grad():
216
- progress_bar = tqdm(loader, desc=mode)
217
- for batch_idx, batch in enumerate(progress_bar):
218
- try:
219
- images, texts, labels = batch
220
- images = images.to(self.device)
221
- labels = labels.to(self.device)
222
- texts = {k: v.squeeze(1).to(self.device) for k, v in texts.items()}
223
-
224
- outputs = self.model(images, texts)
225
- loss = self.criterion(outputs, labels)
226
-
227
- total_loss += loss.item()
228
-
229
- _, preds = torch.max(outputs, 1)
230
- predictions.extend(preds.cpu().numpy())
231
- actual_labels.extend(labels.cpu().numpy())
232
-
233
- progress_bar.set_postfix({
234
- 'loss': f'{loss.item():.4f}',
235
- 'avg_loss': f'{total_loss/(batch_idx+1):.4f}'
236
- })
237
-
238
- except Exception as e:
239
- logging.error(f"Error in {mode} batch {batch_idx}: {str(e)}")
240
- continue
241
-
242
- return total_loss / len(loader), predictions, actual_labels
243
-
244
- def plot_confusion_matrix(self, y_true, y_pred, title):
245
- cm = confusion_matrix(y_true, y_pred)
246
- plt.figure(figsize=(15, 15))
247
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
248
- plt.title(title)
249
- plt.ylabel('True Label')
250
- plt.xlabel('Predicted Label')
251
-
252
- # Save plot
253
- filename = f'output/{title.lower().replace(" ", "_")}.png'
254
- plt.savefig(filename)
255
- plt.close()
256
- logging.info(f"Saved confusion matrix to {filename}")
257
-
258
- def generate_evaluation_report(self, y_true, y_pred, title):
259
- report = classification_report(
260
- y_true,
261
- y_pred,
262
- target_names=self.unique_models,
263
- output_dict=True
264
- )
265
- df_report = pd.DataFrame(report).transpose()
266
-
267
- # Save report
268
- filename = f'output/{title.lower().replace(" ", "_")}_report.csv'
269
- df_report.to_csv(filename)
270
- logging.info(f"Saved classification report to {filename}")
271
-
272
- accuracy = accuracy_score(y_true, y_pred)
273
-
274
- logging.info(f"\n{title} Results:")
275
- logging.info(f"Accuracy: {accuracy:.4f}")
276
- logging.info("\nClassification Report:")
277
- logging.info("\n" + classification_report(y_true, y_pred, target_names=self.unique_models))
278
-
279
- return accuracy, df_report
280
-
281
- def train_and_evaluate(self, num_epochs=5):
282
- best_val_loss = float('inf')
283
- train_accuracies = []
284
- val_accuracies = []
285
- train_losses = []
286
- val_losses = []
287
-
288
- logging.info(f"Starting training for {num_epochs} epochs...")
289
-
290
- for epoch in range(num_epochs):
291
- logging.info(f"\nEpoch {epoch+1}/{num_epochs}")
292
-
293
- # Training
294
- train_loss, train_preds, train_labels = self.train_epoch()
295
- train_accuracy, _ = self.generate_evaluation_report(
296
- train_labels,
297
- train_preds,
298
- f"Training_Epoch_{epoch+1}"
299
- )
300
- self.plot_confusion_matrix(
301
- train_labels,
302
- train_preds,
303
- f"Training_Confusion_Matrix_Epoch_{epoch+1}"
304
- )
305
-
306
- # Validation
307
- val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
308
- val_accuracy, _ = self.generate_evaluation_report(
309
- val_labels,
310
- val_preds,
311
- f"Validation_Epoch_{epoch+1}"
312
- )
313
- self.plot_confusion_matrix(
314
- val_labels,
315
- val_preds,
316
- f"Validation_Confusion_Matrix_Epoch_{epoch+1}"
317
- )
318
-
319
- # Update learning rate scheduler
320
- self.scheduler.step(val_loss)
321
-
322
- train_accuracies.append(train_accuracy)
323
- val_accuracies.append(val_accuracy)
324
- train_losses.append(train_loss)
325
- val_losses.append(val_loss)
326
-
327
- logging.info(f"\nTraining Loss: {train_loss:.4f}")
328
- logging.info(f"Validation Loss: {val_loss:.4f}")
329
-
330
- # Save best model
331
- if val_loss < best_val_loss:
332
- best_val_loss = val_loss
333
- torch.save({
334
- 'epoch': epoch,
335
- 'model_state_dict': self.model.state_dict(),
336
- 'optimizer_state_dict': self.optimizer.state_dict(),
337
- 'val_loss': val_loss,
338
- }, 'output/best_model.pth')
339
- logging.info(f"Saved new best model with validation loss: {val_loss:.4f}")
340
-
341
- # Plot training history
342
- plt.figure(figsize=(12, 4))
343
-
344
- # Plot accuracies
345
- plt.subplot(1, 2, 1)
346
- plt.plot(train_accuracies, label='Training Accuracy')
347
- plt.plot(val_accuracies, label='Validation Accuracy')
348
- plt.title('Model Accuracy over Epochs')
349
- plt.xlabel('Epoch')
350
- plt.ylabel('Accuracy')
351
- plt.legend()
352
-
353
- # Plot losses
354
- plt.subplot(1, 2, 2)
355
- plt.plot(train_losses, label='Training Loss')
356
- plt.plot(val_losses, label='Validation Loss')
357
- plt.title('Model Loss over Epochs')
358
- plt.xlabel('Epoch')
359
- plt.ylabel('Loss')
360
- plt.legend()
361
-
362
- plt.tight_layout()
363
- plt.savefig('output/training_history.png')
364
- plt.close()
365
-
366
- # Final test evaluation using best model
367
- logging.info("\nPerforming final evaluation on test set...")
368
- checkpoint = torch.load('output/best_model.pth')
369
- self.model.load_state_dict(checkpoint['model_state_dict'])
370
- test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
371
- self.generate_evaluation_report(test_labels, test_preds, "Final_Test")
372
- self.plot_confusion_matrix(test_labels, test_preds, "Final_Test_Confusion_Matrix")
373
 
 
374
  def predict(image):
375
  model.eval()
376
  with torch.no_grad():
@@ -387,57 +96,14 @@ def predict(image):
387
  recommended_models = [dataset['Model'][i] for i in indices[0]]
388
  return recommended_models
389
 
390
- def main():
391
- try:
392
- # Create dataset
393
- logging.info("Creating custom dataset...")
394
- custom_dataset = CustomDataset(dataset)
395
-
396
- # Create model
397
- logging.info("Initializing model...")
398
- model = CombinedModel(num_classes=len(custom_dataset.unique_models))
399
-
400
- # Create trainer/evaluator
401
- logging.info("Setting up trainer/evaluator...")
402
- trainer = ModelTrainerEvaluator(
403
- model=model,
404
- dataset=custom_dataset,
405
- batch_size=32,
406
- learning_rate=0.001
407
- )
408
-
409
- # Train and evaluate
410
- logging.info("Starting training process...")
411
- trainer.train_and_evaluate(num_epochs=5)
412
-
413
- # Create Gradio interface
414
- logging.info("Setting up Gradio interface...")
415
- interface = gr.Interface(
416
- fn=predict,
417
- inputs=gr.Image(type="pil"),
418
- outputs=gr.Textbox(label="Recommended Models"),
419
- title="AI Image Model Recommender",
420
- description="Upload an AI-generated image to receive model recommendations.",
421
- examples=[
422
- ["example_image1.jpg"],
423
- ["example_image2.jpg"]
424
- ],
425
- analytics_enabled=False
426
- )
427
-
428
- # Launch the interface
429
- logging.info("Launching Gradio interface...")
430
- interface.launch(share=True)
431
-
432
- except Exception as e:
433
- logging.error(f"Error in main function: {str(e)}")
434
- raise
435
 
436
- if __name__ == "__main__":
437
- try:
438
- main()
439
- except KeyboardInterrupt:
440
- logging.info("Process interrupted by user")
441
- except Exception as e:
442
- logging.error(f"Fatal error: {str(e)}")
443
- raise
 
6
  from transformers import BertTokenizer, BertModel
7
  import pandas as pd
8
  from datasets import load_dataset
9
+ from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load dataset and filter out null/none values
 
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
+ # Filter out entries where Model is None or empty
15
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
16
 
 
 
 
 
 
17
  # Preprocess text data
18
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
19
 
 
23
  self.transform = transforms.Compose([
24
  transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
 
 
26
  ])
27
  self.label_encoder = LabelEncoder()
28
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
 
 
 
29
 
30
  def __len__(self):
31
  return len(self.dataset)
32
 
33
  def __getitem__(self, idx):
34
+ image = self.transform(self.dataset[idx]['image'])
35
+ text = tokenizer(
36
+ self.dataset[idx]['prompt'],
37
+ padding='max_length',
38
+ truncation=True,
39
+ return_tensors='pt'
40
+ )
41
+ label = self.labels[idx]
42
+ return image, text, label
 
 
 
 
 
43
 
44
+ # Define CNN for image processing
45
  class ImageModel(nn.Module):
46
  def __init__(self):
47
  super(ImageModel, self).__init__()
48
  self.model = models.resnet18(pretrained=True)
49
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
50
+
51
  def forward(self, x):
52
+ return self.model(x)
 
53
 
54
+ # Define MLP for text processing
55
  class TextModel(nn.Module):
56
  def __init__(self):
57
  super(TextModel, self).__init__()
58
  self.bert = BertModel.from_pretrained('bert-base-uncased')
59
  self.fc = nn.Linear(768, 512)
60
+
61
  def forward(self, x):
62
+ output = self.bert(**x)
63
+ return self.fc(output.pooler_output)
 
 
64
 
65
+ # Combined model
66
  class CombinedModel(nn.Module):
67
+ def __init__(self):
68
  super(CombinedModel, self).__init__()
69
  self.image_model = ImageModel()
70
  self.text_model = TextModel()
71
+ self.fc = nn.Linear(1024, len(dataset['Model']))
72
+
 
73
  def forward(self, image, text):
74
  image_features = self.image_model(image)
75
  text_features = self.text_model(text)
76
  combined = torch.cat((image_features, text_features), dim=1)
 
77
  return self.fc(combined)
78
 
79
+ # Instantiate model
80
+ model = CombinedModel()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # Define predict function
83
  def predict(image):
84
  model.eval()
85
  with torch.no_grad():
 
96
  recommended_models = [dataset['Model'][i] for i in indices[0]]
97
  return recommended_models
98
 
99
+ # Set up Gradio interface
100
+ interface = gr.Interface(
101
+ fn=predict,
102
+ inputs=gr.Image(type="pil"),
103
+ outputs=gr.Textbox(label="Recommended Models"),
104
+ title="AI Image Model Recommender",
105
+ description="Upload an AI-generated image to receive model recommendations."
106
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ # Launch the app
109
+ interface.launch()