"""ASAM Gradio demo: compare SAM and ASAM masks for the same prompt.

At import time the script installs the local ``asam`` and ``GroundingDINO``
packages plus the demo's runtime dependencies. When run as a script it serves
a Gradio UI. Every request runs the same box or point prompt through SAM
(``sam_vit_b_01ec64.pth``) and through ASAM (``asam_vit_b.pth``) and returns
both results so they can be compared side by side.
"""
import argparse
import copy
import os
import random
import subprocess
import sys
import warnings


def _pip_install(*packages):
    """Best-effort ``pip install`` into the interpreter running this script."""
    # sys.executable (instead of whatever "python"/"pip" is first on PATH)
    # guarantees the packages land in the environment that imports them
    # below. Failures are ignored, as with the os.system calls this replaces.
    subprocess.run([sys.executable, "-m", "pip", "install", *packages], check=False)


_pip_install("-e", "asam")
_pip_install("-e", "GroundingDINO")
_pip_install("gradio==3.38.0")
_pip_install("opencv-python", "pycocotools", "matplotlib", "onnxruntime", "onnx", "ipykernel")
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "asam"))
warnings.filterwarnings("ignore")

# These imports must come after the installs and sys.path changes above.
import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from scipy import ndimage

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam_vit_b, SamPredictor

# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration


def generate_caption(processor, blip_model, raw_image):
    """Return an unconditional BLIP caption for ``raw_image``.

    Args:
        processor: BLIP processor used to encode the image and decode tokens.
        blip_model: BLIP captioning model, already moved to ``device``.
        raw_image: PIL image to caption.
    """
    # ``device`` is the module-level global defined further down this file.
    inputs = processor(raw_image, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)


def transform_image(image_pil):
    """Resize and normalise ``image_pil`` into a (3, h, w) tensor for Grounding DINO."""
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image


def load_model(model_config_path, model_checkpoint_path, device):
    """Build Grounding DINO from its config and load checkpoint weights.

    The state dict is loaded with ``strict=False``; the load report is
    printed. The model is returned in eval mode.
    """
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    """Run Grounding DINO on ``image`` and keep the boxes that match ``caption``.

    Args:
        model: Grounding DINO model from :func:`load_model`.
        image: (3, h, w) tensor produced by :func:`transform_image`.
        caption: text prompt. It is lower-cased, stripped and terminated
            with "." before use.
        box_threshold: a query is kept only if its best token score exceeds this.
        text_threshold: minimum score for a token to be part of a phrase.
        with_logits: if true, append the box score to each phrase, e.g. ``"dog(0.54)"``.

    Returns:
        ``(boxes, scores, phrases)``:
        - the kept boxes as predicted by the model (``run_grounded_sam``
          treats them as normalised cx, cy, w, h);
        - a 1-D tensor with each box's best token score;
        - one phrase per box.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean indexing copies the data, so callers may modify the returned boxes.
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    scores = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        score = logit.max().item()
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(score)[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(score)
    return boxes_filt, torch.Tensor(scores), pred_phrases


def draw_mask(mask, draw, random_color=False):
    """Paint every non-zero pixel of the 2-D ``mask`` onto ``draw``.

    The colour is a random RGB when ``random_color`` is true and
    (30, 144, 255) otherwise, always with alpha 153.
    """
    if random_color:
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
    else:
        color = (30, 144, 255, 153)
    # One C-level stencil fill instead of a draw.point() call per pixel.
    # Pixels where the stencil is 255 are set to ``color`` exactly.
    stencil = Image.fromarray((np.asarray(mask) != 0).astype(np.uint8) * 255, mode="L")
    draw.bitmap((0, 0), stencil, fill=color)


def draw_box(box, draw, label):
    """Outline ``box`` (x_min, y_min, x_max, y_max) in a random colour.

    If ``label`` is truthy, it is written in white on a filled background of
    the same colour at the box's top-left corner.
    """
    color = tuple(np.random.randint(0, 255, size=3).tolist())
    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)
    if not label:
        return
    font = ImageFont.load_default()
    origin = (box[0], box[1])
    text = str(label)
    if hasattr(font, "getbbox"):
        bbox = draw.textbbox(origin, text, font)
    else:
        # Older Pillow without getbbox/textbbox: measure with textsize instead.
        w, h = draw.textsize(text, font)
        bbox = (box[0], box[1], w + box[0], box[1] + h)
    draw.rectangle(bbox, fill=color)
    # Draw the label once, in white, with the same font that was measured above.
    draw.text(origin, text, fill="white", font=font)


def draw_point(point, draw, r=10):
    """Draw a green filled circle of radius ``r`` at every (x, y) in ``point``."""
    for x, y in point:
        draw.ellipse((x - r, y - r, x + r, y + r), fill='green')


config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
ckpt_filenmae = "groundingdino_swint_ogc.pth"
sam_checkpoint = 'sam_vit_b_01ec64.pth'
asam_checkpoint = 'asam_vit_b.pth'
output_dir = "outputs"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prompt modes offered by the UI dropdown and accepted by run_grounded_sam.
TASK_TYPES = ("default_box", "automatic", "scribble_point", "scribble_box")

# Fixed box prompts for the bundled example images, keyed by the
# integer-truncated sum of the transformed image tensor (used as a cheap
# fingerprint). Values are (x_min, y_min, x_max, y_max).
_EXAMPLE_BOXES = {
    -1683627: (204, 213, 813, 1023),  # example 1
    1137390: (125, 168, 842, 904),  # example 2
    1145309: (0, 486, 992, 899),  # example 3
    1091779: (2, 73, 981, 968),  # example 4
    -1335352: (201, 195, 811, 1023),  # example 5
    -1479645: (428, 0, 992, 799),  # example 6
    -544197: (106, 419, 312, 783),  # example 7
    -23873: (250, 25, 774, 803),  # example 8
    -1572157: (15, 88, 1006, 977),  # example 9
}

# Models are loaded lazily on the first request that needs them.
blip_processor = None
blip_model = None
groundingdino_model = None
sam_predictor = None


def _caption_and_detect(image_pil, transformed_image, box_threshold, text_threshold, iou_threshold):
    """Caption the image with BLIP, then ground that caption with Grounding DINO.

    Returns:
        ``(caption, boxes, phrases)``:
        - the BLIP caption;
        - the NMS-filtered boxes in absolute (x_min, y_min, x_max, y_max)
          pixel coordinates;
        - one label per box.
    """
    global blip_processor, blip_model, groundingdino_model
    if groundingdino_model is None:
        groundingdino_model = load_model(config_file, ckpt_filenmae, device=device)
    # Tag2Text (https://huggingface.co/spaces/xinyu1205/Tag2Text) can generate
    # better captions, but the original author noted it had some bugs.
    if blip_processor is None:
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    if blip_model is None:
        blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large").to(device)
    caption = generate_caption(blip_processor, blip_model, image_pil)
    print(f"Caption: {caption}")

    boxes, scores, phrases = get_grounding_output(
        groundingdino_model, transformed_image, caption, box_threshold, text_threshold)

    # Scale to pixels, then convert (cx, cy, w, h) -> (x_min, y_min, x_max, y_max).
    width, height = image_pil.size
    boxes = boxes * torch.Tensor([width, height, width, height])
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    boxes = boxes.cpu()

    print(f"Before NMS: {boxes.shape[0]} boxes")
    keep = torchvision.ops.nms(boxes, scores, iou_threshold).numpy().tolist()
    boxes = boxes[keep]
    phrases = [phrases[idx] for idx in keep]
    print(f"After NMS: {boxes.shape[0]} boxes")
    return caption, boxes, phrases


def _scribble_centers(scribble):
    """Return an (N, 2) array with the (x, y) centre of mass of each scribble stroke.

    ``scribble`` is presumably the (H, W, C) array of the Gradio sketch layer.
    Only channel 0 is used, and pixels >= 255 count as drawn.

    Raises:
        gr.Error: if nothing was drawn.
    """
    # Transpose to (C, W, H) and take channel 0, so the first axis is x.
    channel = scribble.transpose(2, 1, 0)[0]
    labeled, num_features = ndimage.label(channel >= 255)
    if num_features == 0:
        raise gr.Error("Please draw at least one scribble on the image.")
    centers = ndimage.center_of_mass(channel, labeled, range(1, num_features + 1))
    return np.array(centers)


def _load_weights(predictor, checkpoint_path):
    """Load ``checkpoint_path`` into ``predictor.model`` and print its weight sum."""
    predictor.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    total_weights = sum(param.data.sum() for param in predictor.model.parameters())
    print("Total sum of model weights:", float(total_weights))


def _render(base_image, masks, annotate):
    """Return ``(annotated image, mask overlay)`` for one model's masks.

    Args:
        base_image: RGB PIL image. It is copied, never modified.
        masks: iterable of 2-D arrays, one per predicted instance.
        annotate: callable that receives an ``ImageDraw`` for the copy and
            draws the prompt (boxes, labels or points) on it.
    """
    mask_image = Image.new('RGBA', base_image.size, color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    for mask in masks:
        draw_mask(mask, mask_draw, random_color=True)
    annotated = base_image.copy()
    annotate(ImageDraw.Draw(annotated))
    annotated = annotated.convert('RGBA')
    annotated.alpha_composite(mask_image)
    return annotated, mask_image


def run_grounded_sam(input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold):
    """Segment the prompted instance with both SAM and ASAM.

    Args:
        input_image: dict from the Gradio sketch tool, with an ``"image"``
            (PIL image) entry and a ``"mask"`` (the user's scribbles) entry.
        text_prompt: text from the hidden textbox. In ``automatic`` mode it
            is replaced by a BLIP caption; other modes do not use it.
        task_type: one of :data:`TASK_TYPES`.
        box_threshold, text_threshold: Grounding DINO filtering thresholds
            (``automatic`` mode only).
        iou_threshold: NMS IoU threshold for the detected boxes
            (``automatic`` mode only).

    Returns:
        ``[[sam_image, sam_mask], [asam_image, asam_mask]]``. For each model
        this is the annotated RGBA image and the RGBA mask overlay.

    Raises:
        ValueError: for an unknown ``task_type``.
        NotImplementedError: in ``default_box`` mode, for an image that is
            not one of the bundled examples.
        gr.Error: in the scribble modes, when nothing was drawn.
    """
    global sam_predictor
    print(text_prompt, type(text_prompt))
    if task_type not in TASK_TYPES:
        print("task_type:{} error!".format(task_type))
        raise ValueError("task_type:{} error!".format(task_type))

    os.makedirs(output_dir, exist_ok=True)

    scribble = np.array(input_image["mask"])
    image_pil = input_image["image"].convert("RGB")
    transformed_image = transform_image(image_pil)
    image_sum = torch.sum(transformed_image).to(torch.int).item()
    print('img sum:', image_sum)

    # Grounding DINO boxes are only used in automatic mode, so skip it otherwise.
    if task_type == 'automatic':
        text_prompt, boxes_filt, pred_phrases = _caption_and_detect(
            image_pil, transformed_image, box_threshold, text_threshold, iou_threshold)

    if sam_predictor is None:
        if not os.path.isfile(sam_checkpoint):
            raise FileNotFoundError('sam_checkpoint is not found!')
        sam = build_sam_vit_b(checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)

    image = np.array(image_pil)

    if task_type == 'scribble_point':
        point_coords = _scribble_centers(scribble)
        point_labels = np.ones(point_coords.shape[0])

        def predict():
            masks, _, _ = sam_predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                box=None,
                multimask_output=False,
            )
            return list(masks)

        def annotate(draw):
            draw_point(point_coords, draw)
    else:
        if task_type == 'automatic':
            boxes, labels = boxes_filt, pred_phrases
        elif task_type == 'default_box':
            if image_sum not in _EXAMPLE_BOXES:
                print("not defined")
                raise NotImplementedError(
                    "default_box is only defined for the bundled example images")
            boxes = torch.tensor(np.array(_EXAMPLE_BOXES[image_sum])).unsqueeze(0)
            labels = [None] * len(boxes)
        else:  # scribble_box: the bounding box of all stroke centres
            centers = _scribble_centers(scribble)
            x_min, y_min = centers.min(axis=0)
            x_max, y_max = centers.max(axis=0)
            boxes = torch.tensor(np.array([x_min, y_min, x_max, y_max])).unsqueeze(0)
            labels = [None] * len(boxes)
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(
            boxes, image.shape[:2]).to(device)

        def predict():
            masks, _, _ = sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            return [mask[0].cpu().numpy() for mask in masks]

        def annotate(draw):
            for box, label in zip(boxes, labels):
                draw_box(box, draw, label)
            if task_type == 'automatic':
                draw.text((10, 10), text_prompt, fill='black')

    results = []
    for checkpoint_path in (sam_checkpoint, asam_checkpoint):
        _load_weights(sam_predictor, checkpoint_path)
        # set_image caches the image-encoder features, and load_state_dict does
        # not invalidate that cache. Re-embed with the weights just loaded so
        # each model's masks come from its own encoder.
        sam_predictor.set_image(image)
        results.append(list(_render(image_pil, predict(), annotate)))
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument('--no-gradio-queue', action="store_true",
                        help='disable the gradio request queue')
    args = parser.parse_args()
    print(args)

    block = gr.Blocks()
    if not args.no_gradio_queue:
        block = block.queue()

    with block:
        gr.Markdown(
            """
            # ASAM
            Welcome to the ASAM demo

            You may select different prompt types to get the output mask of the target instance.

            ## Usage
            You may check the instructions below, or check our GitHub page for more details.

            ## Mode
            You may select an example image or upload your own image to start. We support 4 prompt types:

            **default_box**: Automatically generate the default box prompt from the mask label; only available for the example images.

            **automatic**: Automatically generate a text prompt and the corresponding box input with BLIP and Grounding-DINO.

            **scribble_point**: Click a point on the target instance.

            **scribble_box**: Click two points, the top-left point and the bottom-right point, to represent a bounding box of the target instance.
            """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    source='upload', type="pil", value="example9.jpg", tool="sketch", brush_radius=20)
                task_type = gr.Dropdown(list(TASK_TYPES), value="default_box", label="task_type")
                text_prompt = gr.Textbox(label="Text Prompt", placeholder="bench .", visible=False)
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.001)
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001)
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001)

            with gr.Column():
                gr.Markdown(
                    """
                    # SAM's output
                    """)
                gallery1 = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery_sam"
                ).style(preview=True, grid=2, object_fit="scale-down")
                gr.Markdown(
                    """
                    # ASAM's output
                    """)
                gallery2 = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery_asam"
                ).style(preview=True, grid=2, object_fit="scale-down")

        with gr.Row():
            for example_idx in range(1, 10):
                with gr.Column():
                    gr.Examples([f"example{example_idx}.jpg"], inputs=input_image)

        run_button.click(
            fn=run_grounded_sam,
            inputs=[input_image, text_prompt, task_type, box_threshold, text_threshold, iou_threshold],
            outputs=[gallery1, gallery2],
        )

    block.launch(debug=args.debug, share=args.share, show_error=True)