import argparse
import copy
import gc
import io
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFont

# OwlViT Detection
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# segment anything
from segment_anything import build_sam, SamPredictor


def show_mask(mask, ax, random_color=False):
    """Overlay a single mask on a matplotlib axes as a semi-transparent colour.

    Args:
        mask: Array whose last two dimensions are (H, W) and which holds exactly
            H * W values (e.g. a (1, H, W) mask).
        ax: Matplotlib axes to draw on.
        random_color: Use a random RGB colour instead of the default blue. The
            overlay alpha is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the RGBA colour -> (H, W, 4) image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` = (x0, y0, x1, y1) on ``ax``."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def plot_boxes_to_image(image_pil, tgt):
    """Draw labelled boxes onto ``image_pil`` and build a matching box mask.

    ``image_pil`` is modified in place.

    Args:
        image_pil: PIL image to annotate.
        tgt: Dict with ``"boxes"`` (iterable of (x0, y0, x1, y1) pixel
            coordinates) and ``"labels"`` (one label per box).

    Returns:
        ``(image_pil, mask)`` where ``mask`` is an "L"-mode image of the same
        size, 255 inside every box and 0 elsewhere.

    Raises:
        ValueError: if ``boxes`` and ``labels`` differ in length.
    """
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    if len(boxes) != len(labels):
        raise ValueError("boxes and labels must have same length")

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)
    # The font does not depend on the box, so load it once.
    font = ImageFont.load_default()

    for box, label in zip(boxes, labels):
        # random color per box
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        x0, y0, x1, y1 = (int(v) for v in box)
        text = str(label)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        # Filled background behind the label, sized to the rendered text.
        # ``textsize`` is only the fallback for Pillow versions lacking getbbox.
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), text, font)
        else:
            w, h = draw.textsize(text, font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), text, fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask


# Use a GPU if one is available. A specific device (e.g. "cuda:4" on a
# multi-GPU host) can be selected with the OWLVIT_DEVICE environment variable;
# a hard-coded ordinal would fail on machines with fewer GPUs.
device = torch.device(
    os.environ.get("OWLVIT_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
)

# load OWL-ViT model
owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
owlvit_model.eval()
owlvit_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

# load segment anything (SAM) on the same device as OWL-ViT
sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth").to(device))


def _detect_best_boxes(pil_img, queries):
    """Return OWL-ViT's highest-scoring box for each text query.

    Returns a float tensor of shape (len(queries), 4) with (x0, y0, x1, y1)
    pixel coordinates; row j belongs to ``queries[j]``.
    """
    box_threshold = 0.0
    with torch.no_grad():
        inputs = owlvit_processor(text=[queries], images=pil_img, return_tensors="pt").to(device)
        outputs = owlvit_model(**inputs)

        # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        target_sizes = torch.Tensor([pil_img.size[::-1]]).to(device)
        # Convert outputs to pixel-space xyxy boxes. With a 0.0 threshold on the
        # sigmoid scores every patch is (in practice) kept, so the rows of
        # results[0]["boxes"] stay aligned with patch indices.
        results = owlvit_processor.post_process_object_detection(
            outputs=outputs, threshold=box_threshold, target_sizes=target_sizes
        )

        # logits: (batch, num_patches, num_queries). Keep the best patch per query.
        scores = torch.sigmoid(outputs.logits)
        _, topk_idxs = torch.topk(scores, k=1, dim=1)  # (batch, 1, num_queries)
        best_patch_per_query = topk_idxs[0, 0]  # (num_queries,)

    return results[0]["boxes"][best_patch_per_query]


def _render_segmentation(image, masks, boxes):
    """Render ``image`` with SAM ``masks`` and ``boxes`` overlaid, as a PIL RGB image."""
    fig = plt.figure(figsize=(10, 10))
    buf = io.BytesIO()
    try:
        ax = fig.gca()
        ax.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), ax, random_color=True)
        for box in boxes:
            # .cpu() is required: boxes may live on a CUDA device.
            show_box(box.cpu().numpy(), ax)
        ax.axis("off")
        fig.savefig(buf)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


def query_image(img, text_prompt):
    """Detect and segment each comma-separated query of ``text_prompt`` in ``img``.

    For each query, OWL-ViT's single best-scoring box is kept and passed to SAM
    as a box prompt to obtain a mask.

    Args:
        img: A PIL image, or an array that ``Image.fromarray`` accepts after
            conversion to uint8 (e.g. an (H, W, 3) RGB array).
        text_prompt: Comma-separated text queries, e.g. ``"cat, remote control"``.
            Surrounding whitespace and empty entries are ignored.

    Returns:
        ``(owlvit_segment_image, image_with_box)``: a matplotlib rendering of the
        image with SAM masks and boxes, and a copy of the input with labelled
        boxes drawn on it.

    Raises:
        ValueError: if ``text_prompt`` contains no non-empty query.
    """
    # load image; convert() returns a copy, so the caller's image is not drawn on
    if isinstance(img, PIL.Image.Image):
        pil_img = img.convert("RGB")
    else:
        pil_img = Image.fromarray(np.uint8(img)).convert("RGB")

    queries = [q.strip() for q in text_prompt.split(",") if q.strip()]
    if not queries:
        raise ValueError("text_prompt must contain at least one non-empty query")

    # run object detection model
    boxes = _detect_best_boxes(pil_img, queries)

    width, height = pil_img.size
    pred_dict = {
        "boxes": copy.deepcopy(boxes.cpu().numpy()),
        "size": [height, width],  # H, W
        # Row j of ``boxes`` is the best box for query j, so label it with query j.
        "labels": list(queries),
    }

    # Free memory cached during detection before running SAM.
    gc.collect()
    torch.cuda.empty_cache()

    # run segment anything (SAM). np.array of an RGB PIL image is already RGB,
    # which is SamPredictor.set_image's default input format.
    image = np.array(pil_img)
    sam_predictor.set_image(image)
    sam_boxes = boxes.to(sam_predictor.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(sam_boxes, image.shape[:2])
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    owlvit_segment_image = _render_segmentation(image, masks, sam_boxes)

    # grounded results
    image_with_box = plot_boxes_to_image(pil_img, pred_dict)[0]
    return owlvit_segment_image, image_with_box