mjbuehler commited on
Commit
10e4e6b
·
verified ·
1 Parent(s): b831b2e

Upload train_dreambooth_lora_sd3.py

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora_sd3.py +5 -1
train_dreambooth_lora_sd3.py CHANGED
@@ -791,7 +791,11 @@ class DreamBoothDataset(Dataset):
791
  if class_data_root is not None:
792
  self.class_data_root = Path(class_data_root)
793
  self.class_data_root.mkdir(parents=True, exist_ok=True)
794
- self.class_images_path = list(self.class_data_root.iterdir())
 
 
 
 
795
  if class_num is not None:
796
  self.num_class_images = min(len(self.class_images_path), class_num)
797
  else:
 
791
  if class_data_root is not None:
792
  self.class_data_root = Path(class_data_root)
793
  self.class_data_root.mkdir(parents=True, exist_ok=True)
794
+
795
+ #self.class_images_path = list(self.class_data_root.iterdir())
796
+
797
+ self.class_images_path = [p for p in self.class_data_root.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
798
+
799
  if class_num is not None:
800
  self.num_class_images = min(len(self.class_images_path), class_num)
801
  else: