'''
Self-prompting strategy.
INPUT:
    predictor: the initialized SAM predictor
    rendered_mask_score: rendered mask confidence scores : H*W*1
    num_prompts: number of prompt points to generate
    index_matrix: the matrix containing the 3D index of the rendered view : H*W*3
OUTPUT: a list of prompt points and their foreground labels
'''
import os
import torch
import math
import numpy as np
# map a float array in [0, 1] to uint8 in [0, 255]
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
@torch.no_grad()
def mask_to_prompt(predictor, rendered_mask_score, index_matrix, num_prompts = 3):
    '''main function for self prompting'''
    h, w, _ = rendered_mask_score.shape
    tmp = rendered_mask_score.view(-1)
    print("tmp min:", tmp.min(), "tmp max:", tmp.max())
    rand = torch.ones_like(tmp)
    # pick the pixel with the highest rendered mask score as the first prompt
    topk_v, topk_p = torch.topk(tmp * rand, k = 1)
    topk_v, topk_p = topk_v.cpu(), topk_p.cpu()
    if topk_v <= 0:
        print("No prompt is available")
        return np.zeros((0, 2)), np.ones((0))

    prompt_points = []
    # prompts are stored as (x, y) pixel coordinates
    prompt_points.append([topk_p[0] % w, topk_p[0] // w])
    print((topk_p[0] % w).item(), (topk_p[0] // w).item(), h, w)
    tmp_mask = rendered_mask_score.clone().detach()
    # estimate the mask area, then derive a suppression radius from the
    # equivalent circle so nearby pixels are not re-selected as prompts
    area = to8b(tmp_mask.cpu().numpy()).sum() / 255
    r = np.sqrt(area / math.pi)
    masked_r = max(int(r) // 2, 2)
    # masked_r = max(int(r) // 3, 2)

    pre_tmp_mask_score = None
    for _ in range(num_prompts - 1):
        # query SAM with the prompts collected so far
        input_label = np.ones(len(prompt_points))
        previous_masks, previous_scores, previous_logits = predictor.predict(
            point_coords=np.array(prompt_points),
            point_labels=input_label,
            multimask_output=False,
        )

        # mask out a square region around the last prompt point
        l = max(prompt_points[-1][0] - masked_r, 0)
        r = min(prompt_points[-1][0] + masked_r, w - 1)
        t = max(prompt_points[-1][1] - masked_r, 0)
        b = min(prompt_points[-1][1] + masked_r, h - 1)
        tmp_mask[t:b+1, l:r+1, :] = -1e5
        # dilate the predicted mask (bool, H x W) with a 25x25 max pool so the
        # peak score is taken over a slightly enlarged region
        previous_mask_tensor = torch.tensor(previous_masks[0])
        previous_mask_tensor = previous_mask_tensor.unsqueeze(0).unsqueeze(0).float()
        previous_mask_tensor = torch.nn.functional.max_pool2d(previous_mask_tensor, 25, stride = 1, padding = 12)
        previous_mask_tensor = previous_mask_tensor.squeeze(0).permute([1, 2, 0])
        # tmp_mask[previous_mask_tensor > 0] = -1e5
        previous_max_score = torch.max(rendered_mask_score[previous_mask_tensor > 0])
        # build a normalized 3D distance to the last prompt point, then use it
        # to penalize pixels near regions SAM already covers
        previous_point_index = torch.zeros_like(index_matrix)
        previous_point_index[:, :, 0] = prompt_points[-1][1] / h
        previous_point_index[:, :, 1] = prompt_points[-1][0] / w
        previous_point_index[:, :, 2] = index_matrix[int(prompt_points[-1][1]), int(prompt_points[-1][0]), 2]
        distance_matrix = torch.sqrt(((index_matrix - previous_point_index) ** 2).sum(-1))
        distance_matrix = (distance_matrix.unsqueeze(-1) - distance_matrix.min()) / (distance_matrix.max() - distance_matrix.min())
        cur_tmp_mask = tmp_mask - distance_matrix * max(previous_max_score, 0)
        # keep, per pixel, the best score seen over all iterations
        if pre_tmp_mask_score is None:
            pre_tmp_mask_score = cur_tmp_mask
        else:
            pre_tmp_mask_score[pre_tmp_mask_score < cur_tmp_mask] = cur_tmp_mask[pre_tmp_mask_score < cur_tmp_mask]
        pre_tmp_mask_score[tmp_mask == -1e5] = -1e5

        # the next prompt is the remaining pixel with the highest score
        tmp_val_point = pre_tmp_mask_score.view(-1).max(dim = 0)
        if tmp_val_point[0] <= 0:
            print("There are", len(prompt_points), "prompts")
            break
        prompt_points.append([int(tmp_val_point[1].cpu() % w), int(tmp_val_point[1].cpu() // w)])
    prompt_points = np.array(prompt_points)
    input_label = np.ones(len(prompt_points))
    return prompt_points, input_label
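
# --- Usage sketch (illustrative, not part of the original pipeline) ---
# A minimal driver showing how mask_to_prompt is meant to be called with
# segment-anything's SamPredictor. The checkpoint path and the input names
# (rendered_rgb, score_map, index_mat) are assumptions for illustration.
def _example_self_prompting(rendered_rgb, score_map, index_mat,
                            sam_checkpoint='./dependencies/sam_vit_h_4b8939.pth'):
    from segment_anything import sam_model_registry, SamPredictor
    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
    predictor = SamPredictor(sam)
    predictor.set_image(rendered_rgb)  # H*W*3 uint8 RGB rendering
    points, labels = mask_to_prompt(predictor, score_map, index_mat, num_prompts=3)
    if len(points) == 0:
        return None  # no positive region in the rendered scores
    masks, scores, logits = predictor.predict(
        point_coords=points,
        point_labels=labels,
        multimask_output=False,
    )
    return masks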
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
import groundingdino.datasets.transforms as T
from torchvision.ops import box_convert
from PIL import Image
def image_transform(image) -> torch.Tensor:
    # standard GroundingDINO preprocessing: resize, tensorize, normalize
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_transformed, _ = transform(image, None)
    return image_transformed
def grounding_dino_prompt(image, text):
    image_tensor = image_transform(Image.fromarray(image))
    model_root = './dependencies/GroundingDINO'
    model = load_model(
        os.path.join(model_root, "groundingdino/config/GroundingDINO_SwinT_OGC.py"),
        os.path.join(model_root, "weights/groundingdino_swint_ogc.pth"),
    )

    BOX_THRESHOLD = 0.35
    TEXT_THRESHOLD = 0.25

    boxes, logits, phrases = predict(
        model=model,
        image=image_tensor,
        caption=text,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD,
    )

    # scale the normalized cxcywh boxes to pixel coordinates and convert to xyxy
    h, w, _ = image.shape
    print("boxes device", boxes.device)
    boxes = boxes * torch.Tensor([w, h, w, h]).to(boxes.device)
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    print(xyxy)
    return xyxy
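
# --- Usage sketch (illustrative) ---
# Feeding the text-grounded boxes into SAM as box prompts, mirroring the
# commented-out bbox-based strategy below. `predictor` is a SamPredictor on
# which set_image(image) has already been called; all names are assumptions.
def _example_text_to_mask(predictor, image, text):
    xyxy = grounding_dino_prompt(image, text)  # (N, 4) pixel-space xyxy boxes
    boxes = torch.tensor(xyxy, dtype=torch.float, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return masks.float()  # (N, 1, H, W)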
'''
# Alternative prompt strategy: bbox based (cannot be applied to 360 scenes).
# Requires `from torchvision.ops import masks_to_boxes`; `seg_m_for_prompt`,
# `no`, and `image` come from the surrounding (omitted) loop.
try:
    prompt = seg_m_for_prompt[:, :, no]
    prompt = prompt > 0
    box_prompt = masks_to_boxes(prompt.unsqueeze(0))
    # enlarge the box by 5% on each side
    width = box_prompt[0, 2] - box_prompt[0, 0]
    height = box_prompt[0, 3] - box_prompt[0, 1]
    box_prompt[0, 0] -= 0.05 * width
    box_prompt[0, 2] += 0.05 * width
    box_prompt[0, 1] -= 0.05 * height
    box_prompt[0, 3] += 0.05 * height
    # print(box_prompt)
    transformed_boxes = predictor.transform.apply_boxes_torch(box_prompt, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks.float()
except:
    continue
'''
'''
# Alternative prompt strategy: mask based.
# RLS is presumably segment_anything's ResizeLongestSide and F is
# torch.nn.functional; `prompt` is an H*W*1 uint8 mask from the outer scope.
H, W, _ = prompt.shape
target_size = RLS.get_preprocess_shape(H, W, 256)
prompt = torch.nn.functional.interpolate(torch.tensor(prompt).float().unsqueeze(0).permute([0, 3, 1, 2]), target_size, mode = 'bilinear')
h, w = prompt.shape[-2:]
padh = 256 - h
padw = 256 - w
prompt = F.pad(prompt, (0, padw, 0, padh))
# map [0, 255] to SAM mask logits in [-20, 20]
prompt = (prompt / 255) * 40 - 20
print(prompt.shape)
prompt = prompt.squeeze(1).cpu().numpy()
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    mask_input=prompt,
    multimask_output=False,
)
'''
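
# --- Self-contained sketch of the mask-based strategy above (illustrative) ---
# Assumes RLS is segment-anything's ResizeLongestSide and that `prompt_mask`
# is an H*W*1 uint8 mask; `predictor` must already have an image set.
def _example_mask_prompt(predictor, prompt_mask):
    import torch.nn.functional as F
    from segment_anything.utils.transforms import ResizeLongestSide as RLS
    H, W, _ = prompt_mask.shape
    # resize so the long side is 256, as SAM's low-res mask input expects
    target_size = RLS.get_preprocess_shape(H, W, 256)
    prompt = torch.tensor(prompt_mask).float().unsqueeze(0).permute([0, 3, 1, 2])
    prompt = F.interpolate(prompt, target_size, mode='bilinear')
    h, w = prompt.shape[-2:]
    prompt = F.pad(prompt, (0, 256 - w, 0, 256 - h))
    prompt = (prompt / 255) * 40 - 20  # map [0, 255] to logits in [-20, 20]
    masks, scores, logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        mask_input=prompt.squeeze(1).cpu().numpy(),
        multimask_output=False,
    )
    return masks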