huzey committed
Commit 0eae266 · Parent(s): 6375d85

update directed

Files changed (2)
  1. app.py +15 -7
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1301,6 +1301,9 @@ def run_fn(
         model = load_alignedthreemodel()
     else:
         model = load_model(model_name)
+
+    if directed:  # save qkv for directed, need more memory
+        model.enable_save_qkv()

     if "stable" in model_name.lower() and "diffusion" in model_name.lower():
         model.timestep = layer
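
The new `directed` branch flips `model.enable_save_qkv()`, whose implementation is not shown in this diff. A minimal sketch of what such a switch could look like, assuming the model caches per-layer query/key/value tensors during its attention forward pass (every internal below is hypothetical; only the method name comes from the commit). Holding q/k/v for every layer is also why the inline comment warns about memory:

    import torch

    class QKVSavingModel:
        def __init__(self):
            self.save_qkv = False
            self.qkv_cache = []  # one (q, k, v) tuple per attention call

        def enable_save_qkv(self):
            # switch flipped by run_fn() when `directed` is set
            self.save_qkv = True
            self.qkv_cache.clear()

        def _attention(self, q, k, v):
            # called inside each attention layer's forward pass
            if self.save_qkv:
                # detach to CPU so the cache holds no GPU memory or autograd graph
                self.qkv_cache.append((q.detach().cpu(), k.detach().cpu(), v.detach().cpu()))
            attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
            return attn @ v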
@@ -1611,13 +1614,21 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
     for i in valid_classes:
         idx = np.where(labels == i)[0]
         if is_random:
-            idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
+            if chunk_size < len(idx):
+                idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
+            else:
+                gr.Warning(f"Class {i} has less than {chunk_size} images.")
+                idx = idx[:chunk_size]
         else:
             idx = idx[:chunk_size]
         image_idx.extend(idx.tolist())
     if not is_filter:
         if is_random:
-            image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
+            if num_images < len(dataset):
+                image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
+            else:
+                gr.Warning(f"Dataset has less than {num_images} images.")
+                image_idx = list(range(num_images))
         else:
             image_idx = list(range(num_images))
     key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
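
The added length guards matter because NumPy's `choice(..., replace=False)` raises a ValueError when asked for more unique samples than the population contains, and `gr.Warning` surfaces a non-fatal toast in the Gradio UI instead. A self-contained sketch of the guarded, seeded sampling (the function name is illustrative, not from the commit):

    import numpy as np

    def sample_chunk(idx, chunk_size, seed=0):
        if chunk_size < len(idx):
            # seeded RandomState => the same subset on every rerun
            return np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
        # too few images: take them all instead of raising ValueError
        return idx[:chunk_size]

    print(sample_chunk(np.arange(100), 5))  # identical 5 indices on every run
    print(sample_chunk(np.arange(3), 5))    # falls back to [0 1 2]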
@@ -3303,21 +3314,18 @@ with demo:
             # convert to PIL
             mask = positive_mask.cpu().numpy()
             mask = mask.astype(np.uint8) * 255
-            mask = [Image.fromarray(mask[i]) for i in range(len(mask))]

             import cv2
-            def get_bboxes_and_clean_mask(pil_mask, min_area=500):
+            def get_bboxes_and_clean_mask(mask, min_area=500):
                 """
                 Args:
-                - pil_mask: A Pillow image of a binary mask with 255 for the object and 0 for the background.
+                - mask: A numpy image of a binary mask with 255 for the object and 0 for the background.
                 - min_area: Minimum area for a connected component to be considered valid (default 500).

                 Returns:
                 - bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
                 - cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
                 """
-                # Convert the Pillow image to a NumPy array
-                mask = np.array(pil_mask)

                 # Ensure the mask is binary (0 or 255)
                 mask = np.where(mask > 127, 255, 0).astype(np.uint8)
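
The last hunk drops the PIL round-trip and passes the NumPy mask straight into the helper. The diff cuts off after the binarization step; a hedged sketch of a completion consistent with the docstring, using OpenCV's connected-components API (everything past the `np.where` line is an assumption, not the commit's code):

    import cv2
    import numpy as np
    from PIL import Image

    def get_bboxes_and_clean_mask(mask, min_area=500):
        # Ensure the mask is binary (0 or 255)
        mask = np.where(mask > 127, 255, 0).astype(np.uint8)
        # Label 8-connected components; each stats row is [x, y, w, h, area]
        n, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        bounding_boxes = []
        cleaned = np.zeros_like(mask)
        for i in range(1, n):  # label 0 is the background
            x, y, w, h, area = stats[i]
            if area >= min_area:
                bounding_boxes.append((x, y, w, h))
                cleaned[labels == i] = 255
        # the docstring promises a Pillow image for the cleaned mask
        return bounding_boxes, Image.fromarray(cleaned)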
 
requirements.txt CHANGED
@@ -21,4 +21,4 @@ timm==0.9.2
 open-clip-torch==2.20.0
 pytorch_lightning==1.9.4
 gradio_image_prompter==0.1.0
-ncut-pytorch>=1.4.2
+ncut-pytorch>=1.5.2
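
The only requirements change raises the `ncut-pytorch` floor from 1.4.2 to 1.5.2, presumably the release carrying the save-QKV support used above (an assumption; the commit message only says "update directed"). To match it in an existing environment:

    pip install --upgrade "ncut-pytorch>=1.5.2"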