# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Post processing."""

import torch
import torch.nn.functional as F

# pylint: disable=g-bad-import-order
# pylint: disable=g-importing-member
from modeling.post_process.object_discovery import get_instances
from utils.metrics import IoM


# This should be an abstract function that generates masks for the input
# image. For now it is hard-wired to the SAM pipeline due to time constraints.
def generate_masks_from_sam(
    image_path, save_path, pipeline, img_sam=None, visualize=True
):
  """Generate masks from SAM.

  Args:
    image_path: Image path, forwarded to `pipeline.segment_automask`.
    save_path: Output path, forwarded to `pipeline.segment_automask`.
    pipeline: An object exposing `segment_automask(image_path=...,
      visualize=..., save_path=..., image=...)` that returns a 3-tuple whose
      first element is a NumPy array of masks.
    img_sam: Optional image, forwarded as the `image` argument.
    visualize: Forwarded to `pipeline.segment_automask`.

  Returns:
    A tuple `(mask_tensor, mask_list)`: the pipeline's first return value
    converted to a float32 torch.Tensor, and its third return value unchanged.
  """
  masks, _, mask_list = pipeline.segment_automask(
      image_path=image_path,
      visualize=visualize,
      save_path=save_path,
      image=img_sam,
  )
  mask_tensor = torch.from_numpy(masks).float()
  return mask_tensor, mask_list


def match_masks(
    mask_tensor, attn_map, mask_list, iom_thres=0.0, min_pred_threshold=0.2
):
  """Select the masks that overlap the attention map, scored with IoM.

  Args:
    mask_tensor: A torch.Tensor for the masks with shape
      [num_masks, height, width].
    attn_map: A torch.Tensor for the attention map with shape
      [1, 1, height, width].
    mask_list: A list with one entry per mask in `mask_tensor`, in the same
      order.
    iom_thres: Masks whose IoM score is strictly greater than this value are
      kept.
    min_pred_threshold: Forwarded to `IoM` as its `min_pred_threshold`.

  Returns:
    A list with the entries of `mask_list` whose IoM score exceeds
    `iom_thres`. If no entry passes, a one-element list holding the entry with
    the highest IoM score is returned instead.
  """
  predictions = attn_map.squeeze(1).detach()
  # NOTE(review): assumes IoM returns one score per mask (shape [num_masks]);
  # confirm against utils.metrics.IoM.
  iom = IoM(predictions, mask_tensor, min_pred_threshold=min_pred_threshold)
  keep_mask = iom > iom_thres
  matched = [m for mid, m in enumerate(mask_list) if keep_mask[mid]]
  if not matched:
    # Fall back to the single best-matching mask so the result is never empty.
    matched.append(mask_list[torch.argmax(iom)])
  return matched


def post_process_mask(attn_masks, pad=None, min_area_ratio=0.15):
  """Crop masks to the padding window and drop masks with too small an area.

  Args:
    attn_masks: A torch.Tensor of masks with shape [num_masks, height, width].
    pad: Optional [pad_left, pad_top, width, height]. When given, the last two
      dims of `attn_masks` are cropped to this window.
    min_area_ratio: A mask is kept if its area divided by the summed area of
      all masks is strictly greater than this value.

  Returns:
    A torch.Tensor [num_kept, H, W] where H, W are the (cropped) spatial dims.
    If no mask passes the ratio test, the mask with the largest area is kept.
    If `attn_masks` contains no masks at all, a single all-zero mask
    [1, H, W] is returned.
  """
  if pad is not None:
    left, top, width, height = pad
    attn_masks = attn_masks[Ellipsis, top : top + height, left : left + width]

  if attn_masks.shape[0] == 0:
    # Size the placeholder from the (cropped) input itself. The dims derived
    # from `pad` are unavailable when pad is None, which previously made
    # torch.zeros((1, None, None)) raise a TypeError.
    return torch.zeros(
        (1, *attn_masks.shape[-2:]),
        device=attn_masks.device,
        dtype=attn_masks.dtype,
    )

  mask_area = attn_masks.sum(dim=(1, 2))
  total_area = mask_area.sum()
  keep_mask = mask_area / total_area > min_area_ratio
  if torch.sum(keep_mask) == 0:
    # Nothing passed the ratio test: keep the largest mask.
    keep_mask[torch.argmax(mask_area)] = True
  return attn_masks[keep_mask]


def filter_masks(
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    min_area_ratio=0.15,
    return_largest=False,
    device=None,
    return_instances=False,
):
  """Zero attention values below the threshold and split masks into instances.

  NOTE: values of `attn_masks` below `mask_threshold` are set to 0 in place,
  so the caller's tensor is modified.

  Args:
    attn_masks: Attention masks, passed to `get_instances` after thresholding.
    pad: Optional [pad_left, pad_top, width, height] crop window, forwarded to
      `post_process_mask`.
    mask_threshold: Values strictly below this are set to 0.
    min_area_ratio: Forwarded to `post_process_mask`.
    return_largest: Forwarded to `get_instances`.
    device: Device that every resulting mask is moved to.
    return_instances: If False, the instances of each entry are merged.

  Returns:
    One tensor per entry produced by `get_instances`. If `return_instances`
    is False, each tensor is [1, H, W] and holds the union (torch.any) of
    that entry's instances, cast back to their dtype. Otherwise each tensor
    holds the individual instance masks [num_instances, H, W].
  """
  attn_masks[attn_masks < mask_threshold] = 0
  # get_instances will be operated on cpu.
  ins_masks = get_instances(attn_masks, return_largest=return_largest)
  ins_masks = [post_process_mask(m, pad, min_area_ratio) for m in ins_masks]
  # post_process_mask does not return None; the guard is kept defensively.
  ins_masks = [m.to(device) for m in ins_masks if m is not None]
  if not return_instances:
    return [torch.any(m, dim=0, keepdim=True).to(m.dtype) for m in ins_masks]
  return ins_masks


def post_process(
    input_array,
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    return_largest=False,
    min_area_ratio=0.15,
    return_instances=False,
):
  """Post process the input array with the attention masks.

  Args:
    input_array: A np.ndarray input array to be post processed, documented
      upstream as [width, height, 3, batch_size] (see the NOTE in the body
      about axis order).
    attn_masks: A torch.Tensor for the attention masks with shape
      [1, num_texts, width, height], or a 3-D tensor without the leading
      batch dim (one is added).
    pad: A list of padding: [pad_left, pad_top, width, height], where
      pad_left, pad_top and width, height are int values.
    mask_threshold: The threshold to binarize the mask.
    return_largest: If true, return the largest connected component.
    min_area_ratio: Keep the mask if its area ratio is larger than this
      threshold.
    return_instances: Whether to return instances or not.

  Returns:
    A tuple `(input_array, attn_masks)`:
      input_array: The input array, cropped to `pad` when it is given.
      attn_masks: A list of tensors with shape [num_instances, height, width],
        where len(attn_masks) = num_texts. NOTE: num_instances may vary per
        text (class). When `return_instances` is False, num_instances is 1.
  """
  if len(attn_masks.shape) == 3:
    attn_masks = attn_masks[None]
  # NOTE(review): shape[:2] is unpacked as (width, height), yet the crop at
  # the end indexes axis 0 with `top`/`height`, i.e. treats axis 0 as height.
  # The two agree only for square inputs; confirm the intended layout.
  img_width, img_height = input_array.shape[:2]
  attn_masks = F.interpolate(
      attn_masks, size=(img_height, img_width), mode='bicubic'
  ).squeeze(0)
  device = attn_masks.device
  output_masks = filter_masks(
      attn_masks,
      pad=pad,
      mask_threshold=mask_threshold,
      min_area_ratio=min_area_ratio,
      return_largest=return_largest,
      device=device,
      return_instances=return_instances,
  )
  if pad is not None:
    left, top, width, height = pad
    input_array = input_array[top : top + height, left : left + width]
  return input_array, output_masks