# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of CaR."""
import os
import clip
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# pylint: disable=g-importing-member
# pylint: disable=g-bad-import-order
from modeling.model.clip_wrapper import CLIPWrapper
from modeling.model.clip_wrapper import forward_clip
from modeling.model.clipcam import CLIPCAM
from modeling.model.crf import PostProcess
from modeling.model.utils import apply_visual_prompts
from utils.visualize import viz_attn
class CaR(nn.Module):
"""CaR module."""
def __init__(
self,
cfg,
device="cpu",
visualize=False,
confidence_threshold=0.45,
save_path="save_path",
seg_mode="refer",
semantic_clip_model_name=None,
semantic_pretrained_data=None,
semantic_templates=None,
text_template=None,
visual_prompt_type="circle",
clipes_threshold=0.4,
cam_text_template="a clean origami {}.",
bg_cls=None,
iom_thres=0.6,
min_pred_threshold=0.01,
bg_factor=1.0,
mask_threshold=0.5,
):
"""CaR model for image segmentation.
Args:
cfg: the config file.
device: the device to run the model on.
visualize: whether to visualize the intermediate results.
confidence_threshold: the confidence threshold for semantic
segmentation. If the confidence score is lower than the threshold, the
mask will be discarded.
save_path: the path to save the intermediate results.
seg_mode: the segmentation mode, either 'refer' or 'semantic'.
semantic_clip_model_name: the name of the semantic segmentation model.
semantic_pretrained_data: the path to the pretrained semantic
segmentation model.
semantic_templates: the templates for semantic segmentation.
text_template: the template for visual prompting.
visual_prompt_type: the type of visual prompting.
clipes_threshold: the threshold for CLIPES.
cam_text_template: the template for CAM.
bg_cls: the background classes.
iom_thres: the IoM threshold.
min_pred_threshold: the minimum prediction threshold.
bg_factor: the background factor.
mask_threshold: the threshold used to binarize masks.
"""
super(CaR, self).__init__()
# CLIP parameters
self.confidence_threshold = confidence_threshold
self.device = device
self.visualize = visualize
self.save_path = save_path
self.seg_mode = seg_mode
self.semantic_clip_model_name = semantic_clip_model_name
self.semantic_pretrained_data = semantic_pretrained_data
self.semantic_templates = semantic_templates
self.text_template = text_template
self.visual_prompt_type = visual_prompt_type
self.clipes_threshold = clipes_threshold
self.cam_text_template = cam_text_template
self.iom_thres = iom_thres
self.min_pred_threshold = min_pred_threshold
self.bg_cls = bg_cls
self.bg_factor = bg_factor
self.mask_threshold = mask_threshold
if not hasattr(cfg, "clip"):
raise ValueError("The config file should contain the CLIP parameters.")
if not hasattr(cfg, "car"):
raise ValueError("The config file should contain the car parameters.")
if hasattr(cfg, "cam"):
raise ValueError("cfg.cam is deprecated, please use cfg.car ")
for k, v in vars(cfg.clip).items():
setattr(self, k, v)
for k, v in vars(cfg.car).items():
setattr(self, k, v)
if hasattr(cfg, "sam"):
for k, v in vars(cfg.sam).items():
setattr(self, k, v)
if not self.bg_cls:
self.bg_cls = None
print(f"The model is running on {self.device}")
self.clip_model, self.preprocess = clip.load(
self.clip_model_name, device=self.device
)
self.clip_model = CLIPWrapper(self.clip_model)
self.post_process = PostProcess(device=self.device)
self.mask_generator = CLIPCAM(
self.clip_model,
device=self.device,
text_template=self.text_template,
threshold=self.clipes_threshold,
bg_cls=self.bg_cls,
)
self.semantic_clip_model, self.semantic_preprocess = clip.load(
self.semantic_clip_model_name, device=self.device
)
self.semantic_clip_model = CLIPWrapper(self.semantic_clip_model)
def get_confidence(self, cam_map, binary_cam_map):
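"""Average CAM activation inside the binarized mask, used as a per-mask confidence."""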
confidence_map = torch.sum(cam_map * binary_cam_map[None], dim=[2, 3])
confidence_map = confidence_map / torch.sum(binary_cam_map, dim=[1, 2])
confidence_score = confidence_map.squeeze()
return confidence_score
def set_visual_prompt_type(self, visual_prompt_type):
self.visual_prompt_type = visual_prompt_type
def set_bg_factor(self, bg_factor):
self.bg_factor = bg_factor
def set_confidence_threshold(self, confidence_threshold):
self.confidence_threshold = confidence_threshold
def set_mask_threshold(self, mask_threshold):
self.mask_threshold = mask_threshold
def apply_visual_prompts(self, image, mask):
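"""Overlay the visual prompt (e.g. a circle) on the image using the mask.

Near-empty masks leave the image unchanged.
"""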
if torch.sum(mask).item() <= 1:
return image
image_array = np.array(image)
img_h = image_array.shape[0]
img_w = image_array.shape[1]
mask = (
F.interpolate(mask[None][None], size=(img_h, img_w), mode="nearest")
.squeeze()
.detach()
.cpu()
.numpy()
)
mask = (mask > self.mask_threshold).astype(np.uint8)
prompted_image = apply_visual_prompts(
image_array, mask, self.visual_prompt_type, self.visualize
)
return prompted_image
def get_mask_confidence(self, prompted_images, prompt_text):
"""Get the confidene for each mask with visual prompting."""
# get the center, width and height of the mask
prompted_tensor = torch.stack(
[self.semantic_preprocess(img) for img in prompted_images], dim=0
)
prompted_tensor = prompted_tensor.to(self.device)
h, w = prompted_tensor.shape[-2:]
text_prediction = forward_clip(
self.semantic_clip_model, prompted_tensor, prompt_text, h, w
)
return text_prediction
def _filter_texts(self, ori_mask_id, sem_scores, prompt_text):
"""Remove false positive masks by score filtering and recall the backbone to get the CAM maps for the filtered texts."""
if not ori_mask_id:
max_id = np.argmax(sem_scores)
ori_mask_id.append(max_id)
filtered_text = [prompt_text[i] for i in ori_mask_id]
return filtered_text
def _forward_stage(self, ori_img, cam_text, clip_text, semantic_prompt_text):
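"""Run one recurrent stage: propose masks, visually prompt, and score them.

Returns the texts whose masks pass the confidence threshold, the per-text
scores, and the raw mask proposals.
"""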
mask_proposals = self.get_mask_proposals(ori_img, cam_text)
num_texts = len(cam_text)
ori_mask_id = []
sem_scores = torch.zeros((num_texts,), device=self.device).float()
prompted_imgs = [
self.apply_visual_prompts(ori_img, cam_map)
for cam_map in mask_proposals
]
text_scores = self.get_mask_confidence(prompted_imgs, semantic_prompt_text)
mask_scores = torch.diagonal(text_scores)
for mask_idx, mask_score in enumerate(mask_scores):
# record mask idx
if mask_score > self.confidence_threshold:
ori_mask_id.append(mask_idx)
sem_scores[mask_idx] = mask_score
sem_scores = sem_scores.cpu().detach().numpy()
filtered_texts = self._filter_texts(ori_mask_id, sem_scores, clip_text)
# if isinstance(ori_img, list):
# ori_img = [ori_img[i] for i in ori_mask_id]
all_scores = torch.zeros((num_texts,), device=self.device).float()
sem_scores = torch.from_numpy(sem_scores).to(self.device)
for new_id, ori_id in enumerate(ori_mask_id):
if new_id >= len(mask_proposals):
# the mask is filtered out.
continue
all_scores[ori_id] = sem_scores[ori_id]
return filtered_texts, all_scores, mask_proposals
def _get_save_path(self, text):
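"""Build the output folder (and per-text subfolders) for visualization."""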
folder_name = "_".join([t.replace(" ", "_") for t in text])
if len(folder_name) > 20:
folder_name = folder_name[:20]
output_path = os.path.join(self.save_path, folder_name)
sub_output_path = [
os.path.join(output_path, t.replace(" ", "_")) for t in text
]
return output_path, sub_output_path
def get_mask_proposals(self, img, text):
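"""Generate CAM-based mask proposals, one per text query."""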
if self.seg_mode == "refer":
if isinstance(img, list):
cam_map_list = [self.mask_generator(i, t)[0] for i, t in zip(img, text)]
else:
cam_map_list = [self.mask_generator(img, t)[0] for t in text]
return torch.cat(cam_map_list, dim=0)
elif self.seg_mode == "semantic":
return self.mask_generator(img, text)[0]
else:
raise ValueError(
"Unknown segmentation mode. Only refer and semantic segmentation are"
" supported."
)
def _forward_car(self, ori_img, text):
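"""Recurrently filter the text queries and refine their masks.

The stage is repeated until the number of positive masks stops changing;
the surviving proposals are then refined with dense CRF.
"""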
if isinstance(text, str):
text = [text]
_, sub_output_path = self._get_save_path(text)
image_array = np.array(ori_img)
clip_text = [self.cam_text_template.format(t) for t in text]
cam_text = text
init_clip_text = clip_text  # keep the initial CLIP prompts to recover class indices later.
semantic_prompt_text = clip_text
# Apply semantic prompting augmentation.
if self.semantic_templates is not None:
semantic_prompt_text = []
for template in self.semantic_templates:
templated_text = [template.format(t) for t in text]
semantic_prompt_text.append(templated_text)
num_positive_last = 0
run = 0
while True:
run += 1
cur_texts, all_scores, mask_proposals = self._forward_stage(
ori_img, cam_text, clip_text, semantic_prompt_text
)
if cur_texts: # if there is no text, skip the refinement
cam_text = cur_texts
clip_text = cur_texts
num_positive = (all_scores > 0).sum().item()
if num_positive == num_positive_last:
# stop the refinement if the number of positive masks
# does not change.
break
num_positive_last = num_positive
# Apply densecrf for refinement.
# SAM is optional and is applied outside the model.
refined_masks = self.post_process(
ori_img,
mask_proposals,
separate=self.seg_mode == "refer",
bg_factor=self.bg_factor,
)
predicted_class_idx = [init_clip_text.index(t) for t in cur_texts]
if self.visualize:
_ = [
viz_attn(
image_array,
attn,
prefix=sub_output_path[aid],
img_name="semantic_mask",
)
for aid, attn in enumerate(refined_masks)
]
final_predicted_masks = torch.zeros(len(text), *refined_masks[0].shape)
final_all_scores = torch.zeros(len(text))
for idx, mask, score in zip(predicted_class_idx, refined_masks, all_scores):
final_predicted_masks[idx] = mask
final_all_scores[idx] = score
return final_predicted_masks, final_all_scores
def forward(self, im_ori, text):
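"""Segment `im_ori` according to each query in `text`.

Returns:
  pseudo_masks: one mask per query, in the order of the input texts.
  conf_scores: the corresponding confidence scores.
"""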
# im_ori is the (padded) input image, e.g. with shape (512, 512, 3).
pseudo_masks, conf_scores = self._forward_car(im_ori, text)
return pseudo_masks, conf_scores
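

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The exact config fields are
# assumptions inferred from the attributes read in `__init__` above
# (`cfg.clip.clip_model_name` and the `cfg.car` namespace); the image path,
# class names, and CLIP model choices are placeholders.
if __name__ == "__main__":
  from types import SimpleNamespace
  from PIL import Image

  # Assumed config: `cfg.clip` supplies the CAM backbone name; `cfg.car` is
  # left empty so the constructor defaults above are used.
  cfg = SimpleNamespace(
      clip=SimpleNamespace(clip_model_name="ViT-B/16"),
      car=SimpleNamespace(),
  )
  model = CaR(
      cfg,
      device="cuda" if torch.cuda.is_available() else "cpu",
      seg_mode="semantic",
      semantic_clip_model_name="ViT-L/14",  # scoring model for visual prompts
      text_template="a clean origami {}.",
  )
  image = Image.open("example.jpg").convert("RGB")  # placeholder path
  masks, scores = model(image, ["cat", "dog"])
  print(masks.shape, scores)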