bgaspra commited on
Commit
3690a76
·
verified ·
1 Parent(s): 3ea3100

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -90
app.py CHANGED
@@ -13,11 +13,32 @@ import seaborn as sns
13
  import matplotlib.pyplot as plt
14
  import numpy as np
15
  from tqdm import tqdm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Load dataset and filter out null/none values
 
18
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
19
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
20
 
 
 
 
 
 
21
  # Preprocess text data
22
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
23
 
@@ -27,39 +48,93 @@ class CustomDataset(Dataset):
27
  self.transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
 
 
30
  ])
31
  self.label_encoder = LabelEncoder()
32
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
33
-
34
- # Save unique model names for later use
35
  self.unique_models = self.label_encoder.classes_
 
 
36
 
37
  def __len__(self):
38
  return len(self.dataset)
39
 
40
  def __getitem__(self, idx):
41
- image = self.transform(self.dataset[idx]['image'])
42
- text = tokenizer(
43
- self.dataset[idx]['prompt'],
44
- padding='max_length',
45
- truncation=True,
46
- return_tensors='pt'
47
- )
48
- label = self.labels[idx]
49
- return image, text, label
 
 
 
 
 
50
 
51
- # Model classes remain the same as before
52
- # ... (ImageModel, TextModel, CombinedModel classes stay unchanged)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  class ModelTrainerEvaluator:
55
  def __init__(self, model, dataset, batch_size=32, learning_rate=0.001):
56
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
57
  self.model = model.to(self.device)
58
  self.batch_size = batch_size
59
  self.criterion = nn.CrossEntropyLoss()
60
- self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- # Split dataset into train, validation, and test
63
  total_size = len(dataset)
64
  train_size = int(0.7 * total_size)
65
  val_size = int(0.15 * total_size)
@@ -69,9 +144,22 @@ class ModelTrainerEvaluator:
69
  dataset, [train_size, val_size, test_size]
70
  )
71
 
72
- self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
73
- self.val_loader = DataLoader(val_dataset, batch_size=batch_size)
74
- self.test_loader = DataLoader(test_dataset, batch_size=batch_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  self.unique_models = dataset.unique_models
77
 
@@ -81,49 +169,75 @@ class ModelTrainerEvaluator:
81
  predictions = []
82
  actual_labels = []
83
 
84
- for batch in tqdm(self.train_loader, desc="Training"):
85
- images, texts, labels = batch
86
- images = images.to(self.device)
87
- labels = labels.to(self.device)
88
-
89
- # Forward pass
90
- self.optimizer.zero_grad()
91
- outputs = self.model(images, texts)
92
- loss = self.criterion(outputs, labels)
93
-
94
- # Backward pass
95
- loss.backward()
96
- self.optimizer.step()
97
-
98
- total_loss += loss.item()
99
-
100
- # Store predictions
101
- _, preds = torch.max(outputs, 1)
102
- predictions.extend(preds.cpu().numpy())
103
- actual_labels.extend(labels.cpu().numpy())
104
-
105
- return total_loss / len(self.train_loader), predictions, actual_labels
106
-
107
- def evaluate(self, loader, mode="Validation"):
108
- self.model.eval()
109
- total_loss = 0
110
- predictions = []
111
- actual_labels = []
112
-
113
- with torch.no_grad():
114
- for batch in tqdm(loader, desc=mode):
115
  images, texts, labels = batch
116
  images = images.to(self.device)
117
  labels = labels.to(self.device)
118
 
 
 
 
 
119
  outputs = self.model(images, texts)
120
  loss = self.criterion(outputs, labels)
121
 
 
 
 
 
 
122
  total_loss += loss.item()
123
 
124
  _, preds = torch.max(outputs, 1)
125
  predictions.extend(preds.cpu().numpy())
126
  actual_labels.extend(labels.cpu().numpy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  return total_loss / len(loader), predictions, actual_labels
129
 
@@ -134,22 +248,33 @@ class ModelTrainerEvaluator:
134
  plt.title(title)
135
  plt.ylabel('True Label')
136
  plt.xlabel('Predicted Label')
137
- plt.savefig(f'{title.lower().replace(" ", "_")}.png')
 
 
 
138
  plt.close()
 
139
 
140
  def generate_evaluation_report(self, y_true, y_pred, title):
141
- report = classification_report(y_true, y_pred,
142
- target_names=self.unique_models,
143
- output_dict=True)
 
 
 
144
  df_report = pd.DataFrame(report).transpose()
145
- df_report.to_csv(f'{title.lower().replace(" ", "_")}_report.csv')
 
 
 
 
146
 
147
  accuracy = accuracy_score(y_true, y_pred)
148
 
149
- print(f"\n{title} Results:")
150
- print(f"Accuracy: {accuracy:.4f}")
151
- print("\nClassification Report:")
152
- print(classification_report(y_true, y_pred, target_names=self.unique_models))
153
 
154
  return accuracy, df_report
155
 
@@ -157,74 +282,162 @@ class ModelTrainerEvaluator:
157
  best_val_loss = float('inf')
158
  train_accuracies = []
159
  val_accuracies = []
 
 
 
 
160
 
161
  for epoch in range(num_epochs):
162
- print(f"\nEpoch {epoch+1}/{num_epochs}")
163
 
164
  # Training
165
  train_loss, train_preds, train_labels = self.train_epoch()
166
  train_accuracy, _ = self.generate_evaluation_report(
167
- train_labels, train_preds, f"Training Epoch {epoch+1}"
 
 
168
  )
169
  self.plot_confusion_matrix(
170
- train_labels, train_preds, f"Training Confusion Matrix Epoch {epoch+1}"
 
 
171
  )
172
 
173
  # Validation
174
  val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
175
  val_accuracy, _ = self.generate_evaluation_report(
176
- val_labels, val_preds, f"Validation Epoch {epoch+1}"
 
 
177
  )
178
  self.plot_confusion_matrix(
179
- val_labels, val_preds, f"Validation Confusion Matrix Epoch {epoch+1}"
 
 
180
  )
181
 
 
 
 
182
  train_accuracies.append(train_accuracy)
183
  val_accuracies.append(val_accuracy)
 
 
184
 
185
- print(f"\nTraining Loss: {train_loss:.4f}")
186
- print(f"Validation Loss: {val_loss:.4f}")
187
 
188
  # Save best model
189
  if val_loss < best_val_loss:
190
  best_val_loss = val_loss
191
- torch.save(self.model.state_dict(), 'best_model.pth')
 
 
 
 
 
 
192
 
193
  # Plot training history
194
- plt.figure(figsize=(10, 6))
 
 
 
195
  plt.plot(train_accuracies, label='Training Accuracy')
196
  plt.plot(val_accuracies, label='Validation Accuracy')
197
  plt.title('Model Accuracy over Epochs')
198
  plt.xlabel('Epoch')
199
  plt.ylabel('Accuracy')
200
  plt.legend()
201
- plt.savefig('training_history.png')
 
 
 
 
 
 
 
 
 
 
 
202
  plt.close()
203
 
204
- # Final test evaluation
205
- self.model.load_state_dict(torch.load('best_model.pth'))
 
 
206
  test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
207
- self.generate_evaluation_report(test_labels, test_preds, "Final Test")
208
- self.plot_confusion_matrix(test_labels, test_preds, "Final Test Confusion Matrix")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- # Usage example
211
  def main():
212
- # Create dataset
213
- custom_dataset = CustomDataset(dataset)
214
-
215
- # Create model
216
- model = CombinedModel()
217
-
218
- # Create trainer/evaluator
219
- trainer = ModelTrainerEvaluator(
220
- model=model,
221
- dataset=custom_dataset,
222
- batch_size=32,
223
- learning_rate=0.001
224
- )
225
-
226
- # Train and evaluate
227
- trainer.train_and_evaluate(num_epochs=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  if __name__ == "__main__":
230
- main()
 
 
 
 
 
 
 
13
  import matplotlib.pyplot as plt
14
  import numpy as np
15
  from tqdm import tqdm
16
+ import os
17
+ import logging
18
+
19
+ # Set up logging
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format='%(asctime)s - %(levelname)s - %(message)s',
23
+ handlers=[
24
+ logging.FileHandler('model_training.log'),
25
+ logging.StreamHandler()
26
+ ]
27
+ )
28
+
29
+ # Create output directory for results
30
+ os.makedirs('output', exist_ok=True)
31
 
32
  # Load dataset and filter out null/none values
33
+ logging.info("Loading and filtering dataset...")
34
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
35
  dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
36
 
37
+ if len(dataset) == 0:
38
+ raise ValueError("Dataset is empty after filtering!")
39
+
40
+ logging.info(f"Dataset size after filtering: {len(dataset)}")
41
+
42
  # Preprocess text data
43
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
44
 
 
48
  self.transform = transforms.Compose([
49
  transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
51
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
52
+ std=[0.229, 0.224, 0.225])
53
  ])
54
  self.label_encoder = LabelEncoder()
55
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
 
 
56
  self.unique_models = self.label_encoder.classes_
57
+
58
+ logging.info(f"Number of unique models: {len(self.unique_models)}")
59
 
60
  def __len__(self):
61
  return len(self.dataset)
62
 
63
  def __getitem__(self, idx):
64
+ try:
65
+ image = self.transform(self.dataset[idx]['image'])
66
+ text = tokenizer(
67
+ self.dataset[idx]['prompt'],
68
+ padding='max_length',
69
+ truncation=True,
70
+ max_length=512,
71
+ return_tensors='pt'
72
+ )
73
+ label = self.labels[idx]
74
+ return image, text, label
75
+ except Exception as e:
76
+ logging.error(f"Error processing item {idx}: {str(e)}")
77
+ raise
78
 
79
+ class ImageModel(nn.Module):
80
+ def __init__(self):
81
+ super(ImageModel, self).__init__()
82
+ self.model = models.resnet18(pretrained=True)
83
+ self.model.fc = nn.Linear(self.model.fc.in_features, 512)
84
+
85
+ def forward(self, x):
86
+ x = self.model(x)
87
+ return nn.functional.relu(x)
88
+
89
+ class TextModel(nn.Module):
90
+ def __init__(self):
91
+ super(TextModel, self).__init__()
92
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
93
+ self.fc = nn.Linear(768, 512)
94
+
95
+ def forward(self, x):
96
+ outputs = self.bert(**x)
97
+ x = outputs.pooler_output
98
+ x = self.fc(x)
99
+ return nn.functional.relu(x)
100
+
101
+ class CombinedModel(nn.Module):
102
+ def __init__(self, num_classes):
103
+ super(CombinedModel, self).__init__()
104
+ self.image_model = ImageModel()
105
+ self.text_model = TextModel()
106
+ self.dropout = nn.Dropout(0.2)
107
+ self.fc = nn.Linear(1024, num_classes)
108
+
109
+ def forward(self, image, text):
110
+ image_features = self.image_model(image)
111
+ text_features = self.text_model(text)
112
+ combined = torch.cat((image_features, text_features), dim=1)
113
+ combined = self.dropout(combined)
114
+ return self.fc(combined)
115
 
116
  class ModelTrainerEvaluator:
117
  def __init__(self, model, dataset, batch_size=32, learning_rate=0.001):
118
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
119
+ logging.info(f"Using device: {self.device}")
120
+
121
  self.model = model.to(self.device)
122
  self.batch_size = batch_size
123
  self.criterion = nn.CrossEntropyLoss()
124
+ self.optimizer = torch.optim.AdamW(
125
+ model.parameters(),
126
+ lr=learning_rate,
127
+ weight_decay=0.01
128
+ )
129
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
130
+ self.optimizer,
131
+ mode='min',
132
+ factor=0.1,
133
+ patience=2,
134
+ verbose=True
135
+ )
136
 
137
+ # Split dataset
138
  total_size = len(dataset)
139
  train_size = int(0.7 * total_size)
140
  val_size = int(0.15 * total_size)
 
144
  dataset, [train_size, val_size, test_size]
145
  )
146
 
147
+ self.train_loader = DataLoader(
148
+ train_dataset,
149
+ batch_size=batch_size,
150
+ shuffle=True,
151
+ num_workers=4
152
+ )
153
+ self.val_loader = DataLoader(
154
+ val_dataset,
155
+ batch_size=batch_size,
156
+ num_workers=4
157
+ )
158
+ self.test_loader = DataLoader(
159
+ test_dataset,
160
+ batch_size=batch_size,
161
+ num_workers=4
162
+ )
163
 
164
  self.unique_models = dataset.unique_models
165
 
 
169
  predictions = []
170
  actual_labels = []
171
 
172
+ progress_bar = tqdm(self.train_loader, desc="Training")
173
+ for batch_idx, batch in enumerate(progress_bar):
174
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  images, texts, labels = batch
176
  images = images.to(self.device)
177
  labels = labels.to(self.device)
178
 
179
+ # Move text tensors to device
180
+ texts = {k: v.squeeze(1).to(self.device) for k, v in texts.items()}
181
+
182
+ self.optimizer.zero_grad()
183
  outputs = self.model(images, texts)
184
  loss = self.criterion(outputs, labels)
185
 
186
+ loss.backward()
187
+ # Gradient clipping
188
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
189
+ self.optimizer.step()
190
+
191
  total_loss += loss.item()
192
 
193
  _, preds = torch.max(outputs, 1)
194
  predictions.extend(preds.cpu().numpy())
195
  actual_labels.extend(labels.cpu().numpy())
196
+
197
+ # Update progress bar
198
+ progress_bar.set_postfix({
199
+ 'loss': f'{loss.item():.4f}',
200
+ 'avg_loss': f'{total_loss/(batch_idx+1):.4f}'
201
+ })
202
+
203
+ except Exception as e:
204
+ logging.error(f"Error in batch {batch_idx}: {str(e)}")
205
+ continue
206
+
207
+ return total_loss / len(self.train_loader), predictions, actual_labels
208
+
209
+ def evaluate(self, loader, mode="Validation"):
210
+ self.model.eval()
211
+ total_loss = 0
212
+ predictions = []
213
+ actual_labels = []
214
+
215
+ with torch.no_grad():
216
+ progress_bar = tqdm(loader, desc=mode)
217
+ for batch_idx, batch in enumerate(progress_bar):
218
+ try:
219
+ images, texts, labels = batch
220
+ images = images.to(self.device)
221
+ labels = labels.to(self.device)
222
+ texts = {k: v.squeeze(1).to(self.device) for k, v in texts.items()}
223
+
224
+ outputs = self.model(images, texts)
225
+ loss = self.criterion(outputs, labels)
226
+
227
+ total_loss += loss.item()
228
+
229
+ _, preds = torch.max(outputs, 1)
230
+ predictions.extend(preds.cpu().numpy())
231
+ actual_labels.extend(labels.cpu().numpy())
232
+
233
+ progress_bar.set_postfix({
234
+ 'loss': f'{loss.item():.4f}',
235
+ 'avg_loss': f'{total_loss/(batch_idx+1):.4f}'
236
+ })
237
+
238
+ except Exception as e:
239
+ logging.error(f"Error in {mode} batch {batch_idx}: {str(e)}")
240
+ continue
241
 
242
  return total_loss / len(loader), predictions, actual_labels
243
 
 
248
  plt.title(title)
249
  plt.ylabel('True Label')
250
  plt.xlabel('Predicted Label')
251
+
252
+ # Save plot
253
+ filename = f'output/{title.lower().replace(" ", "_")}.png'
254
+ plt.savefig(filename)
255
  plt.close()
256
+ logging.info(f"Saved confusion matrix to {filename}")
257
 
258
  def generate_evaluation_report(self, y_true, y_pred, title):
259
+ report = classification_report(
260
+ y_true,
261
+ y_pred,
262
+ target_names=self.unique_models,
263
+ output_dict=True
264
+ )
265
  df_report = pd.DataFrame(report).transpose()
266
+
267
+ # Save report
268
+ filename = f'output/{title.lower().replace(" ", "_")}_report.csv'
269
+ df_report.to_csv(filename)
270
+ logging.info(f"Saved classification report to {filename}")
271
 
272
  accuracy = accuracy_score(y_true, y_pred)
273
 
274
+ logging.info(f"\n{title} Results:")
275
+ logging.info(f"Accuracy: {accuracy:.4f}")
276
+ logging.info("\nClassification Report:")
277
+ logging.info("\n" + classification_report(y_true, y_pred, target_names=self.unique_models))
278
 
279
  return accuracy, df_report
280
 
 
282
  best_val_loss = float('inf')
283
  train_accuracies = []
284
  val_accuracies = []
285
+ train_losses = []
286
+ val_losses = []
287
+
288
+ logging.info(f"Starting training for {num_epochs} epochs...")
289
 
290
  for epoch in range(num_epochs):
291
+ logging.info(f"\nEpoch {epoch+1}/{num_epochs}")
292
 
293
  # Training
294
  train_loss, train_preds, train_labels = self.train_epoch()
295
  train_accuracy, _ = self.generate_evaluation_report(
296
+ train_labels,
297
+ train_preds,
298
+ f"Training_Epoch_{epoch+1}"
299
  )
300
  self.plot_confusion_matrix(
301
+ train_labels,
302
+ train_preds,
303
+ f"Training_Confusion_Matrix_Epoch_{epoch+1}"
304
  )
305
 
306
  # Validation
307
  val_loss, val_preds, val_labels = self.evaluate(self.val_loader)
308
  val_accuracy, _ = self.generate_evaluation_report(
309
+ val_labels,
310
+ val_preds,
311
+ f"Validation_Epoch_{epoch+1}"
312
  )
313
  self.plot_confusion_matrix(
314
+ val_labels,
315
+ val_preds,
316
+ f"Validation_Confusion_Matrix_Epoch_{epoch+1}"
317
  )
318
 
319
+ # Update learning rate scheduler
320
+ self.scheduler.step(val_loss)
321
+
322
  train_accuracies.append(train_accuracy)
323
  val_accuracies.append(val_accuracy)
324
+ train_losses.append(train_loss)
325
+ val_losses.append(val_loss)
326
 
327
+ logging.info(f"\nTraining Loss: {train_loss:.4f}")
328
+ logging.info(f"Validation Loss: {val_loss:.4f}")
329
 
330
  # Save best model
331
  if val_loss < best_val_loss:
332
  best_val_loss = val_loss
333
+ torch.save({
334
+ 'epoch': epoch,
335
+ 'model_state_dict': self.model.state_dict(),
336
+ 'optimizer_state_dict': self.optimizer.state_dict(),
337
+ 'val_loss': val_loss,
338
+ }, 'output/best_model.pth')
339
+ logging.info(f"Saved new best model with validation loss: {val_loss:.4f}")
340
 
341
  # Plot training history
342
+ plt.figure(figsize=(12, 4))
343
+
344
+ # Plot accuracies
345
+ plt.subplot(1, 2, 1)
346
  plt.plot(train_accuracies, label='Training Accuracy')
347
  plt.plot(val_accuracies, label='Validation Accuracy')
348
  plt.title('Model Accuracy over Epochs')
349
  plt.xlabel('Epoch')
350
  plt.ylabel('Accuracy')
351
  plt.legend()
352
+
353
+ # Plot losses
354
+ plt.subplot(1, 2, 2)
355
+ plt.plot(train_losses, label='Training Loss')
356
+ plt.plot(val_losses, label='Validation Loss')
357
+ plt.title('Model Loss over Epochs')
358
+ plt.xlabel('Epoch')
359
+ plt.ylabel('Loss')
360
+ plt.legend()
361
+
362
+ plt.tight_layout()
363
+ plt.savefig('output/training_history.png')
364
  plt.close()
365
 
366
+ # Final test evaluation using best model
367
+ logging.info("\nPerforming final evaluation on test set...")
368
+ checkpoint = torch.load('output/best_model.pth')
369
+ self.model.load_state_dict(checkpoint['model_state_dict'])
370
  test_loss, test_preds, test_labels = self.evaluate(self.test_loader, "Test")
371
+ self.generate_evaluation_report(test_labels, test_preds, "Final_Test")
372
+ self.plot_confusion_matrix(test_labels, test_preds, "Final_Test_Confusion_Matrix")
373
+
374
+ def predict(image):
375
+ model.eval()
376
+ with torch.no_grad():
377
+ image = transforms.ToTensor()(image).unsqueeze(0)
378
+ image = transforms.Resize((224, 224))(image)
379
+ text_input = tokenizer(
380
+ "Sample prompt",
381
+ return_tensors='pt',
382
+ padding=True,
383
+ truncation=True
384
+ )
385
+ output = model(image, text_input)
386
+ _, indices = torch.topk(output, 5)
387
+ recommended_models = [dataset['Model'][i] for i in indices[0]]
388
+ return recommended_models
389
 
 
390
  def main():
391
+ try:
392
+ # Create dataset
393
+ logging.info("Creating custom dataset...")
394
+ custom_dataset = CustomDataset(dataset)
395
+
396
+ # Create model
397
+ logging.info("Initializing model...")
398
+ model = CombinedModel(num_classes=len(custom_dataset.unique_models))
399
+
400
+ # Create trainer/evaluator
401
+ logging.info("Setting up trainer/evaluator...")
402
+ trainer = ModelTrainerEvaluator(
403
+ model=model,
404
+ dataset=custom_dataset,
405
+ batch_size=32,
406
+ learning_rate=0.001
407
+ )
408
+
409
+ # Train and evaluate
410
+ logging.info("Starting training process...")
411
+ trainer.train_and_evaluate(num_epochs=5)
412
+
413
+ # Create Gradio interface
414
+ logging.info("Setting up Gradio interface...")
415
+ interface = gr.Interface(
416
+ fn=predict,
417
+ inputs=gr.Image(type="pil"),
418
+ outputs=gr.Textbox(label="Recommended Models"),
419
+ title="AI Image Model Recommender",
420
+ description="Upload an AI-generated image to receive model recommendations.",
421
+ examples=[
422
+ ["example_image1.jpg"],
423
+ ["example_image2.jpg"]
424
+ ],
425
+ analytics_enabled=False
426
+ )
427
+
428
+ # Launch the interface
429
+ logging.info("Launching Gradio interface...")
430
+ interface.launch(share=True)
431
+
432
+ except Exception as e:
433
+ logging.error(f"Error in main function: {str(e)}")
434
+ raise
435
 
436
  if __name__ == "__main__":
437
+ try:
438
+ main()
439
+ except KeyboardInterrupt:
440
+ logging.info("Process interrupted by user")
441
+ except Exception as e:
442
+ logging.error(f"Fatal error: {str(e)}")
443
+ raise