Spaces:
Sleeping
Sleeping
update directed
Browse files- app.py +15 -7
- requirements.txt +1 -1
app.py
CHANGED
@@ -1301,6 +1301,9 @@ def run_fn(
|
|
1301 |
model = load_alignedthreemodel()
|
1302 |
else:
|
1303 |
model = load_model(model_name)
|
|
|
|
|
|
|
1304 |
|
1305 |
if "stable" in model_name.lower() and "diffusion" in model_name.lower():
|
1306 |
model.timestep = layer
|
@@ -1611,13 +1614,21 @@ def load_dataset_images(is_advanced, dataset_name, num_images=10,
|
|
1611 |
for i in valid_classes:
|
1612 |
idx = np.where(labels == i)[0]
|
1613 |
if is_random:
|
1614 |
-
|
|
|
|
|
|
|
|
|
1615 |
else:
|
1616 |
idx = idx[:chunk_size]
|
1617 |
image_idx.extend(idx.tolist())
|
1618 |
if not is_filter:
|
1619 |
if is_random:
|
1620 |
-
|
|
|
|
|
|
|
|
|
1621 |
else:
|
1622 |
image_idx = list(range(num_images))
|
1623 |
key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
|
@@ -3303,21 +3314,18 @@ with demo:
|
|
3303 |
# convert to PIL
|
3304 |
mask = positive_mask.cpu().numpy()
|
3305 |
mask = mask.astype(np.uint8) * 255
|
3306 |
-
mask = [Image.fromarray(mask[i]) for i in range(len(mask))]
|
3307 |
|
3308 |
import cv2
|
3309 |
-
def get_bboxes_and_clean_mask(
|
3310 |
"""
|
3311 |
Args:
|
3312 |
-
-
|
3313 |
- min_area: Minimum area for a connected component to be considered valid (default 500).
|
3314 |
|
3315 |
Returns:
|
3316 |
- bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
|
3317 |
- cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
|
3318 |
"""
|
3319 |
-
# Convert the Pillow image to a NumPy array
|
3320 |
-
mask = np.array(pil_mask)
|
3321 |
|
3322 |
# Ensure the mask is binary (0 or 255)
|
3323 |
mask = np.where(mask > 127, 255, 0).astype(np.uint8)
|
|
|
1301 |
model = load_alignedthreemodel()
|
1302 |
else:
|
1303 |
model = load_model(model_name)
|
1304 |
+
|
1305 |
+
if directed: # save qkv for directed, need more memory
|
1306 |
+
model.enable_save_qkv()
|
1307 |
|
1308 |
if "stable" in model_name.lower() and "diffusion" in model_name.lower():
|
1309 |
model.timestep = layer
|
|
|
1614 |
for i in valid_classes:
|
1615 |
idx = np.where(labels == i)[0]
|
1616 |
if is_random:
|
1617 |
+
if chunk_size < len(idx):
|
1618 |
+
idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
|
1619 |
+
else:
|
1620 |
+
gr.Warning(f"Class {i} has less than {chunk_size} images.")
|
1621 |
+
idx = idx[:chunk_size]
|
1622 |
else:
|
1623 |
idx = idx[:chunk_size]
|
1624 |
image_idx.extend(idx.tolist())
|
1625 |
if not is_filter:
|
1626 |
if is_random:
|
1627 |
+
if num_images < len(dataset):
|
1628 |
+
image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
|
1629 |
+
else:
|
1630 |
+
gr.Warning(f"Dataset has less than {num_images} images.")
|
1631 |
+
image_idx = list(range(num_images))
|
1632 |
else:
|
1633 |
image_idx = list(range(num_images))
|
1634 |
key = 'image' if 'image' in dataset[0] else list(dataset[0].keys())[0]
|
|
|
3314 |
# convert to PIL
|
3315 |
mask = positive_mask.cpu().numpy()
|
3316 |
mask = mask.astype(np.uint8) * 255
|
|
|
3317 |
|
3318 |
import cv2
|
3319 |
+
def get_bboxes_and_clean_mask(mask, min_area=500):
|
3320 |
"""
|
3321 |
Args:
|
3322 |
+
- mask: A numpy image of a binary mask with 255 for the object and 0 for the background.
|
3323 |
- min_area: Minimum area for a connected component to be considered valid (default 500).
|
3324 |
|
3325 |
Returns:
|
3326 |
- bounding_boxes: List of bounding boxes for valid objects (x, y, width, height).
|
3327 |
- cleaned_pil_mask: A Pillow image of the cleaned mask, with small components removed.
|
3328 |
"""
|
|
|
|
|
3329 |
|
3330 |
# Ensure the mask is binary (0 or 255)
|
3331 |
mask = np.where(mask > 127, 255, 0).astype(np.uint8)
|
requirements.txt
CHANGED
@@ -21,4 +21,4 @@ timm==0.9.2
|
|
21 |
open-clip-torch==2.20.0
|
22 |
pytorch_lightning==1.9.4
|
23 |
gradio_image_prompter==0.1.0
|
24 |
-
ncut-pytorch>=1.
|
|
|
21 |
open-clip-torch==2.20.0
|
22 |
pytorch_lightning==1.9.4
|
23 |
gradio_image_prompter==0.1.0
|
24 |
+
ncut-pytorch>=1.5.2
|