File size: 5,670 Bytes
6cd90b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Post processing."""

import torch
import torch.nn.functional as F

# pylint: disable=g-bad-import-order
# pylint: disable=g-importing-member
from modeling.post_process.object_discovery import get_instances
from utils.metrics import IoM


# This should be an abstract interface for producing masks for an input
# image; for now it is hard-wired to the SAM automask pipeline.
def generate_masks_from_sam(
    image_path, save_path, pipeline, img_sam=None, visualize=True
):
  """Run the SAM automask pipeline and return its masks as a float tensor.

  Args:
      image_path: Path to the input image on disk.
      save_path: Directory/path where the pipeline writes visualizations.
      pipeline: Object exposing a `segment_automask(...)` method.
      img_sam: Optional pre-loaded image array forwarded to the pipeline.
      visualize: Whether the pipeline should save visualization output.

  Returns:
      A tuple `(mask_tensor, mask_list)`: the stacked masks converted to a
      float torch.Tensor, and the pipeline's raw per-mask metadata list.
  """
  sam_masks, _, sam_mask_list = pipeline.segment_automask(
      image_path=image_path,
      visualize=visualize,
      save_path=save_path,
      image=img_sam,
  )
  return torch.from_numpy(sam_masks).float(), sam_mask_list


def match_masks(
    mask_tensor, attn_map, mask_list, iom_thres=0.0, min_pred_threshold=0.2
):
  """Match masks with the attention map according to the IoM.

  Args:
      mask_tensor: A torch.Tensor for the masks with shape [num_masks, height,
        width].
      attn_map: A torch.Tensor for the attention map with shape [1, 1, height,
        width].
      mask_list: A list of masks with shape [num_masks, height, width]
      iom_thres: A float for the threshold to apply to the attention map.
      min_pred_threshold: The prediction score threshold.

  Returns:
      A list of matched_masks with shape [num_masks, height, width],
      len(matched_masks) = number of captions
  """
  predictions = attn_map.squeeze(1).detach()
  iom = IoM(predictions, mask_tensor, min_pred_threshold=min_pred_threshold)
  keep_mask = iom > iom_thres
  # Keep only the mask entries whose IoM clears the threshold.
  new_list = [m_dict for m_dict, keep in zip(mask_list, keep_mask) if keep]
  if not new_list:
    # Fall back to the single best-matching mask so the result is never empty.
    new_list.append(mask_list[torch.argmax(iom)])
  return new_list


def post_process_mask(attn_masks, pad=None, min_area_ratio=0.15):
  """Crop away padding and drop masks with too small a relative area.

  Args:
      attn_masks: A torch.Tensor of masks with shape [num_masks, H, W].
      pad: Optional [left, top, width, height]; when given, the masks are
        cropped to that region before filtering.
      min_area_ratio: A mask is kept only if its area exceeds this fraction
        of the summed area of all masks.

  Returns:
      A torch.Tensor [num_kept, height, width]. If no mask passes the ratio
      test, the largest mask is kept; if the input has zero masks, a single
      all-zero mask of the output size is returned.
  """
  if pad is not None:
    left, top, width, height = pad
    attn_masks = attn_masks[Ellipsis, top : top + height, left : left + width]
  else:
    # Bug fix: previously height/width stayed None here, so the empty-input
    # fallback below crashed with torch.zeros((1, None, None)). Derive the
    # output size from the tensor itself instead.
    height, width = attn_masks.shape[-2:]
  mask_area = attn_masks.sum(dim=(1, 2))
  total_area = mask_area.sum()
  keep_mask = mask_area / total_area > min_area_ratio
  if torch.sum(keep_mask) == 0:
    if keep_mask.shape[0] == 0:
      # No masks at all: return one all-zero placeholder mask.
      return torch.zeros(
          (1, height, width), device=attn_masks.device, dtype=attn_masks.dtype
      )
    # Nothing cleared the ratio threshold: keep the largest mask.
    keep_mask[torch.argmax(mask_area)] = True
  attn_masks = attn_masks[keep_mask]
  return attn_masks


def filter_masks(
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    min_area_ratio=0.15,
    return_largest=False,
    device=None,
    return_instances=False,
):
  """Filter attention mask below the threshold.

  Args:
      attn_masks: A torch.Tensor of attention masks [num_texts, H, W].
      pad: Optional [left, top, width, height] crop forwarded to
        post_process_mask.
      mask_threshold: Values below this are zeroed before instance grouping.
      min_area_ratio: Minimum relative area forwarded to post_process_mask.
      return_largest: If true, keep only the largest connected component.
      device: Device the returned masks are moved to.
      return_instances: If true, return per-instance masks; otherwise merge
        each text's instances into one binary mask.

  Returns:
      A list of tensors, one entry per text/caption.
  """
  # Bug fix: threshold on a copy so the caller's tensor is not mutated
  # in place (the original zeroed entries of the argument directly).
  attn_masks = attn_masks.clone()
  attn_masks[attn_masks < mask_threshold] = 0
  # get_instances will be operated on cpu
  ins_masks = get_instances(attn_masks, return_largest=return_largest)
  ins_masks = [post_process_mask(m, pad, min_area_ratio) for m in ins_masks]
  ins_masks = list(filter(lambda x: x is not None, ins_masks))
  ins_masks = [m.to(device) for m in ins_masks]
  if not return_instances:
    return [torch.any(m, dim=0, keepdim=True).to(m.dtype) for m in ins_masks]
  return ins_masks


def post_process(
    input_array,
    attn_masks,
    pad=None,
    mask_threshold=0.3,
    return_largest=False,
    min_area_ratio=0.15,
    return_instances=False,
):
  """Resize, threshold, and instance-split attention masks for an image.

  Args:
      input_array: A np.ndarray to be post processed, shape
        [width, height, 3, batch_size].
      attn_masks: A torch.Tensor of attention masks, shape
        [1, num_texts, width, height] (a 3-D tensor is promoted to 4-D).
      pad: Optional padding [pad_left, pad_top, width, height]; when given,
        both the image and the masks are cropped to that region.
      mask_threshold: Threshold used to binarize the masks.
      return_largest: If true, keep only the largest connected component.
      min_area_ratio: Keep a mask only if its area exceeds this ratio.
      return_instances: Whether to return per-instance masks.

  Returns:
      A tuple of (possibly cropped) input_array and a list of binary mask
      tensors, one entry per text; the number of instances per text may vary.
  """
  if attn_masks.dim() == 3:
    attn_masks = attn_masks.unsqueeze(0)
  img_width, img_height = input_array.shape[:2]
  # Upsample the attention maps to image resolution before filtering.
  resized = F.interpolate(
      attn_masks, size=(img_height, img_width), mode='bicubic'
  ).squeeze(0)
  output_masks = filter_masks(
      resized,
      pad=pad,
      mask_threshold=mask_threshold,
      min_area_ratio=min_area_ratio,
      return_largest=return_largest,
      device=resized.device,
      return_instances=return_instances,
  )
  if pad is not None:
    left, top, width, height = pad
    input_array = input_array[top : top + height, left : left + width]
  return input_array, output_masks