# Spaces: Paused
'''
Self-prompting strategy.

INPUT:
    predictor: the initialized SAM predictor
    rendered_mask_score: rendered mask score, shape H*W*1
    num_prompts: maximum number of point prompts to generate
    index_matrix: matrix containing the 3D index of each pixel of the rendered view, shape H*W*3
OUTPUT: a list of prompts
'''
import functools
import math
import os

import numpy as np
import torch
def to8b(x):
    """Convert values expected in [0, 1] to uint8 in [0, 255].

    Inputs are clipped to [0, 1] before scaling. The cast to uint8 truncates
    toward zero, so 0.5 maps to 127.
    """
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)
def _max_score_in_dilated_mask(mask, rendered_mask_score):
    """Return the best score inside ``mask`` dilated by a 25x25 max-pool, floored at 0.

    ``mask`` is an H x W mask (e.g. ``masks[0]`` from a SAM-style ``predict``).
    Returns 0.0 when the dilated mask selects no pixel, because ``torch.max`` of an
    empty tensor would raise.
    """
    dilated = torch.as_tensor(mask).float()[None, None]  # (1, 1, H, W)
    dilated = torch.nn.functional.max_pool2d(dilated, 25, stride=1, padding=12)
    # Back to (H, W, 1) on the score's device so it can index rendered_mask_score.
    dilated = dilated.squeeze(0).permute(1, 2, 0).to(rendered_mask_score.device)
    inside = rendered_mask_score[dilated > 0]
    if inside.numel() == 0:
        return 0.0
    return max(float(inside.max()), 0.0)


def _normalized_distance(index_matrix, px, py, h, w):
    """Min-max scaled distance of every pixel from the prompt (px, py) in index space.

    The anchor is ``(py / h, px / w, index_matrix[py, px, 2])``; the result has shape
    (H, W, 1) with values in [0, 1]. When all distances are equal, zeros are
    returned instead of computing 0 / 0 (NaN).
    """
    anchor = torch.zeros_like(index_matrix)
    anchor[:, :, 0] = py / h
    anchor[:, :, 1] = px / w
    anchor[:, :, 2] = index_matrix[py, px, 2]
    distance = torch.sqrt(((index_matrix - anchor) ** 2).sum(-1)).unsqueeze(-1)
    lo, hi = distance.min(), distance.max()
    if hi - lo > 0:
        return (distance - lo) / (hi - lo)
    return torch.zeros_like(distance)


def mask_to_prompt(predictor, rendered_mask_score, index_matrix, num_prompts=3):
    """Main function for self prompting: pick up to ``num_prompts`` positive point prompts.

    The first prompt is the highest-scoring pixel. For each further prompt, the
    predictor is queried with all prompts chosen so far. A square window around
    the latest prompt is then excluded, and every remaining score is lowered by
    its normalised distance from the latest prompt times the best score covered
    by the predictor's dilated mask. Each pixel keeps its best candidate score
    across iterations. Selection stops early once no candidate score is positive.

    Args:
        predictor: object with ``predict(point_coords=..., point_labels=...,
            multimask_output=False)`` returning ``(masks, scores, logits)`` where
            ``masks[0]`` is an H x W mask (used with a SAM predictor).
        rendered_mask_score: tensor of shape (H, W, 1) with per-pixel scores.
        index_matrix: tensor of shape (H, W, 3) with the 3D index of each pixel.
        num_prompts: maximum number of prompts to return.

    Returns:
        ``(points, labels)``: an integer array of shape (K, 2) with ``(x, y)``
        pixel coordinates, and a float array of K ones. If no pixel has a
        positive score, arrays of shape (0, 2) and (0,) are returned.
    """
    suppressed_value = -1e5  # score written into excluded windows
    h, w, _ = rendered_mask_score.shape
    # reshape (unlike view) also accepts non-contiguous tensors.
    flat_scores = rendered_mask_score.reshape(-1)
    if flat_scores.numel() == 0:
        print("No prompt is available")
        return np.zeros((0, 2)), np.ones((0))
    print("tmp min:", flat_scores.min(), "tmp max:", flat_scores.max())
    top_value, top_index = torch.topk(flat_scores, k=1)
    if top_value.cpu() <= 0:
        print("No prompt is available")
        return np.zeros((0, 2)), np.ones((0))
    first = int(top_index.cpu()[0])
    # Store plain ints so np.array(prompt_points) is integer-typed.
    prompt_points = [[first % w, first // w]]
    print(first % w, first // w, h, w)

    scores = rendered_mask_score.clone().detach()
    # Treat the 8-bit quantised mask mass as a disc; half its radius (min 2) sizes the window.
    area = to8b(scores.cpu().numpy()).sum() / 255
    masked_r = max(int(np.sqrt(area / math.pi)) // 2, 2)

    suppressed = torch.zeros_like(scores, dtype=torch.bool)
    best_scores = None
    for _ in range(num_prompts - 1):
        masks, _scores, _logits = predictor.predict(
            point_coords=np.array(prompt_points),
            point_labels=np.ones(len(prompt_points)),
            multimask_output=False,
        )
        px, py = prompt_points[-1]
        # Exclude a square window around the latest prompt, clamped to the image.
        left, right = max(px - masked_r, 0), min(px + masked_r, w - 1)
        top, bottom = max(py - masked_r, 0), min(py + masked_r, h - 1)
        scores[top:bottom + 1, left:right + 1, :] = suppressed_value
        suppressed[top:bottom + 1, left:right + 1, :] = True

        penalty_weight = _max_score_in_dilated_mask(masks[0], rendered_mask_score)
        distance = _normalized_distance(index_matrix, px, py, h, w)
        candidate = scores - distance * penalty_weight

        # Keep each pixel's best candidate so far; excluded windows stay excluded.
        best_scores = candidate if best_scores is None else torch.maximum(best_scores, candidate)
        best_scores[suppressed] = suppressed_value

        best_value, best_index = best_scores.reshape(-1).max(dim=0)
        if best_value <= 0:
            print("There are", len(prompt_points), "prompts")
            break
        best_index = int(best_index.cpu())
        prompt_points.append([best_index % w, best_index // w])

    prompt_points = np.array(prompt_points)
    return prompt_points, np.ones(len(prompt_points))
from groundingdino.util.inference import load_model, load_image, predict, annotate | |
import cv2 | |
import groundingdino.datasets.transforms as T | |
from torchvision.ops import box_convert | |
from PIL import Image | |
def image_transform(image) -> torch.Tensor:
    """Run ``image`` through the GroundingDINO preprocessing pipeline.

    The pipeline is ``RandomResize([800], max_size=1333)``, then ``ToTensor``, then
    ``Normalize`` with fixed per-channel mean/std. ``grounding_dino_prompt`` passes
    a PIL image here. The transform is called with a ``None`` target, and only the
    transformed image is returned.
    """
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)
    tensor, _target = pipeline(image, None)
    return tensor
@functools.lru_cache(maxsize=4)
def _load_grounding_dino(config_path, weights_path):
    """Build a GroundingDINO model from ``config_path`` and ``weights_path``, cached per pair.

    Caching avoids rebuilding the network and re-reading the checkpoint on every
    prompt. The same model object is shared by all callers that use the same paths.
    """
    return load_model(config_path, weights_path)


def grounding_dino_prompt(image, text, box_threshold=0.35, text_threshold=0.25,
                          model_root='./dependencies/GroundingDINO'):
    """Detect boxes matching the caption ``text`` in ``image`` with GroundingDINO.

    Args:
        image: H x W x C array accepted by ``PIL.Image.fromarray``.
        text: caption passed to GroundingDINO's ``predict``.
        box_threshold: forwarded to ``predict`` as ``box_threshold`` (default 0.35).
        text_threshold: forwarded to ``predict`` as ``text_threshold`` (default 0.25).
        model_root: directory containing the GroundingDINO config and weights.

    Returns:
        NumPy array of boxes in (x1, y1, x2, y2) format, multiplied by the image
        width/height. This presumably yields pixel coordinates, since ``predict``
        appears to return normalised cxcywh boxes (confirm).
    """
    image_tensor = image_transform(Image.fromarray(image))
    model = _load_grounding_dino(
        os.path.join(model_root, "groundingdino/config/GroundingDINO_SwinT_OGC.py"),
        os.path.join(model_root, "weights/groundingdino_swint_ogc.pth"),
    )
    boxes, _logits, _phrases = predict(
        model=model,
        image=image_tensor,
        caption=text,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    h, w, _ = image.shape
    print("boxes device", boxes.device)
    boxes = boxes * torch.Tensor([w, h, w, h]).to(boxes.device)
    # .cpu() makes the NumPy conversion safe even if boxes live on a GPU.
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy()
    print(xyxy)
    return xyxy
'''
# Disabled alternative, kept for reference inside a string (not executed).
# new prompt strategy: bbox based
# cannot be applied to 360
try:
    prompt = seg_m_for_prompt[:,:,no]
    prompt = prompt > 0
    box_prompt = masks_to_boxes(prompt.unsqueeze(0))
    width = box_prompt[0,2] - box_prompt[0,0]
    height = box_prompt[0,3] - box_prompt[0,1]
    box_prompt[0,0] -= 0.05*width
    box_prompt[0,2] += 0.05*width
    box_prompt[0,1] -= 0.05*height
    box_prompt[0,3] += 0.05*height
    # print(box_prompt)
    transformed_boxes = predictor.transform.apply_boxes_torch(box_prompt, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks.float()
except:
    continue
'''
'''
# Disabled alternative, kept for reference inside a string (not executed).
# mask based
H,W,_ = prompt.shape
target_size = RLS.get_preprocess_shape(H,W, 256)
prompt = torch.nn.functional.interpolate(torch.tensor(prompt).float().unsqueeze(0).permute([0,3,1,2]), target_size, mode = 'bilinear')
h,w = prompt.shape[-2:]
padh = 256 - h
padw = 256 - w
prompt = F.pad(prompt, (0, padw, 0, padh))
prompt = (prompt / 255) * 40 - 20
print(prompt.shape)
prompt = prompt.squeeze(1).cpu().numpy()
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    mask_input=prompt,
    multimask_output=False,
)
'''