"""Run a demo of the CaR model on a single image.""" import numpy as np import os import argparse from functools import reduce import PIL.Image as Image import torch from modeling.model import CaR from utils.utils import Config, load_yaml import matplotlib.pyplot as plt import colorsys from modeling.post_process.post_process import ( match_masks, generate_masks_from_sam, ) from sam.sam import SAMPipeline from sam.utils import build_sam_config import random import time def generate_distinct_colors(n): colors = [] # generate a random number from 0 to 1 random_color_bias = random.random() for i in range(n): hue = float(i) / n hue += random_color_bias hue = hue % 1.0 rgb = colorsys.hsv_to_rgb(hue, 1.0, 1.0) # Convert RGB values from [0, 1] range to [0, 255] colors.append(tuple(int(val * 255) for val in rgb)) return colors def overlap_masks(masks): """ Overlap masks to generate a single mask for visualization. Parameters: - masks: list of np.arrays of shape (H, W) representing binary masks for each class. Returns: - overlap_mask: list of np.array of shape (H, W) that have no overlaps """ overlap_mask = torch.zeros_like(masks[0]) for mask_idx, mask in enumerate(masks): overlap_mask[mask > 0] = mask_idx + 1 clean_masks = [ overlap_mask == mask_idx + 1 for mask_idx in range(len(masks)) ] clean_masks = torch.stack(clean_masks, dim=0) return clean_masks def visualize_segmentation( image, masks, class_names, alpha=0.45, y_list=None, x_list=None ): """ Visualize segmentation masks on an image. Parameters: - image: np.array of shape (H, W, 3) representing the RGB image - masks: list of np.arrays of shape (H, W) representing binary masks for each class. 
- class_names: list of strings representing names of each class - alpha: float, transparency level of masks on the image Returns: - visualization: plt.figure object """ # Create a figure and axis fig, ax = plt.subplots(1, figsize=(12, 9)) # Display the image # ax.imshow(image) # Generate distinct colors for each mask final_mask = np.zeros( (masks.shape[1], masks.shape[2], 3), dtype=np.float32 ) colors = generate_distinct_colors(len(class_names)) idx = 0 for mask, color, class_name in zip(masks, colors, class_names): # Overlay the mask final_mask += np.dstack([mask * c for c in color]) # Find a representative point (e.g., centroid) for placing the label if y_list is None or x_list is None: y, x = np.argwhere(mask).mean(axis=0) else: y, x = y_list[idx], x_list[idx] ax.text( x, y, class_name, color="white", fontsize=36, va="center", ha="center", bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"), ) idx += 1 final_image = image * (1 - alpha) + final_mask * alpha final_image = final_image.astype(np.uint8) ax.imshow(final_image) # Remove axis ticks and labels ax.axis("off") return fig def get_sam_masks(config, image_path, masks, img_sam=None, pipeline=None): print("generating sam masks online") mask_tensor, mask_list = generate_masks_from_sam( image_path, save_path="./", pipeline=pipeline, img_sam=img_sam, visualize=False, ) mask_tensor = mask_tensor.to(masks.device) # only conduct sam on masks that is not all zero attn_map, mask_ids = [], [] for mask_id, mask in enumerate(masks): if torch.sum(mask) > 0: attn_map.append(mask.unsqueeze(0)) mask_ids.append(mask_id) matched_masks = [ match_masks( mask_tensor, attn, mask_list, iom_thres=config.car.iom_thres, min_pred_threshold=config.sam.min_pred_threshold, ) for attn in attn_map ] for matched_mask, mask_id in zip(matched_masks, mask_ids): sam_masks = np.array([item["segmentation"] for item in matched_mask]) sam_mask = np.any(sam_masks, axis=0) masks[mask_id] = torch.from_numpy(sam_mask).to(masks.device) return masks 
def load_sam(config, sam_device):
    """Build a SAMPipeline from the config's checkpoint and thresholds.

    Parameters:
    - config: config object with a ``sam`` section (points_per_side,
      pred_iou_thresh, stability_score_thresh, box_nms_thresh).
    - sam_device: device string for the SAM model, e.g. "cuda" or "cpu".

    Returns:
    - SAMPipeline instance ready to generate mask proposals.
    """
    sam_checkpoint, model_type = build_sam_config(config)
    pipelines = SAMPipeline(
        sam_checkpoint,
        model_type,
        device=sam_device,
        points_per_side=config.sam.points_per_side,
        pred_iou_thresh=config.sam.pred_iou_thresh,
        stability_score_thresh=config.sam.stability_score_thresh,
        box_nms_thresh=config.sam.box_nms_thresh,
    )
    return pipelines


if __name__ == "__main__":
    parser = argparse.ArgumentParser("CaR")
    # default arguments
    # additional arguments
    parser.add_argument(
        "--output_path", type=str, default="", help="path to save outputs"
    )
    parser.add_argument(
        "--cfg-path",
        default="configs/voc_test.yaml",
        help="path to configuration file.",
    )
    args = parser.parse_args()

    cfg = Config(**load_yaml(args.cfg_path))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Output folder name: caption words joined by underscores, capped at
    # 20 characters to keep paths short.
    folder_name = reduce(
        lambda x, y: x.replace(" ", "_") + "_" + y, cfg.image_caption
    )
    if len(folder_name) > 20:
        folder_name = folder_name[:20]

    car_model = CaR(
        cfg, visualize=True, seg_mode=cfg.test.seg_mode, device=device
    )
    sam_pipeline = load_sam(cfg, device)

    img = Image.open(cfg.image_path).convert("RGB")
    # Downscale large images (width > 1000 px) by a factor of 3 to keep
    # inference cost bounded.
    if img.size[0] > 1000:
        img = img.resize((img.size[0] // 3, img.size[1] // 3))

    label_space = cfg.image_caption
    pseudo_masks, scores, _ = car_model(img, label_space)

    # Optionally refine the pseudo masks with SAM proposals, then resolve
    # any overlaps between class masks.
    if not cfg.test.use_pseudo:
        t1 = time.time()
        pseudo_masks = get_sam_masks(
            cfg,
            cfg.image_path,
            pseudo_masks,
            img_sam=np.array(img),
            pipeline=sam_pipeline,
        )
        pseudo_masks = overlap_masks(pseudo_masks)
        t2 = time.time()
        print(f"sam time: {t2 - t1}")

    # visualize segmentation masks
    demo_fig = visualize_segmentation(
        np.array(img),
        pseudo_masks.detach().cpu().numpy(),
        label_space,
    )

    save_path = f"vis_results/{folder_name}"
    os.makedirs(save_path, exist_ok=True)
    demo_fig.savefig(os.path.join(save_path, "demo.png"), bbox_inches="tight")
    print(f"results saved to {save_path}.")