Upload train_dreambooth_lora_sd3.py
Browse files
train_dreambooth_lora_sd3.py
CHANGED
@@ -791,7 +791,11 @@ class DreamBoothDataset(Dataset):
|
|
791 |
if class_data_root is not None:
|
792 |
self.class_data_root = Path(class_data_root)
|
793 |
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
794 |
-
|
|
|
|
|
|
|
|
|
795 |
if class_num is not None:
|
796 |
self.num_class_images = min(len(self.class_images_path), class_num)
|
797 |
else:
|
|
|
791 |
if class_data_root is not None:
|
792 |
self.class_data_root = Path(class_data_root)
|
793 |
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
794 |
+
|
795 |
+
#self.class_images_path = list(self.class_data_root.iterdir())
|
796 |
+
|
797 |
+
self.class_images_path = [p for p in self.class_data_root.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
|
798 |
+
|
799 |
if class_num is not None:
|
800 |
self.num_class_images = min(len(self.class_images_path), class_num)
|
801 |
else:
|