# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Find objects."""

# pylint: disable=g-importing-member
import numpy as np
import scipy
from scipy import ndimage
from scipy.linalg import eigh
from scipy.ndimage import label
import torch
import torch.nn.functional as F


def ncut(
    feats,
    dims,
    scales,
    init_image_size,
    tau=0,
    eps=1e-5,
    no_binary_graph=False,
):
  """Implementation of the NCut method.

  Args:
    feats: the pixel/patch features of an image, shape
      [1, 1 + num_patches, C] (the leading token is [CLS]).
    dims: dimensions of the feature map from which the features are taken.
    scales: scale factors from the feature map to the image.
    init_image_size: size of the image.
    tau: threshold for graph construction.
    eps: edge weight assigned to below-threshold entries of the graph.
    no_binary_graph: ablation that keeps the raw similarity score as the
      graph edge weight instead of binarizing it.

  Returns:
    A tuple of (pred, objects, mask, seed, None, eigenvec):
      pred: predicted box in image coordinates, [xmin, ymin, xmax, ymax].
      objects: map of labeled connected components over the feature map.
      mask: binary mask of the selected component.
      seed: index of the seed patch.
      None: placeholder kept for interface compatibility.
      eigenvec: second smallest eigenvector reshaped to the feature map.
  """
  feats = feats[0, 1:, :]
  feats = F.normalize(feats, p=2)
  a = feats @ feats.transpose(1, 0)
  a = a.cpu().numpy()
  if no_binary_graph:
    a[a < tau] = eps
  else:
    a = a > tau
    a = np.where(a.astype(float) == 0, eps, a)
  d_i = np.sum(a, axis=1)
  d = np.diag(d_i)

  # Compute the second and third smallest eigenvectors of the generalized
  # eigenproblem (D - A) x = lambda * D * x.
  _, eigenvectors = eigh(d - a, d, subset_by_index=[1, 2])
  eigenvec = np.copy(eigenvectors[:, 0])

  # Use the average value of the second smallest eigenvector to bipartition.
  second_smallest_vec = eigenvectors[:, 0]
  avg = np.sum(second_smallest_vec) / len(second_smallest_vec)
  bipartition = second_smallest_vec > avg

  seed = np.argmax(np.abs(second_smallest_vec))

  if bipartition[seed] != 1:
    eigenvec = eigenvec * -1
    bipartition = np.logical_not(bipartition)
  bipartition = bipartition.reshape(dims).astype(float)

  # Predict BBox
  # We only extract the principal object BBox
  pred, _, objects, cc = detect_box(
      bipartition,
      seed,
      dims,
      scales=scales,
      initial_im_size=init_image_size[1:],
  )
  mask = np.zeros(dims)
  mask[cc[0], cc[1]] = 1
  return np.asarray(pred), objects, mask, seed, None, eigenvec.reshape(dims)
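

# --- Illustrative usage sketch (not part of the original pipeline). --------
# A minimal, hedged example of calling `ncut` above. The feature-map size,
# feature dimension, patch size, and the random tensor standing in for real
# backbone features are assumptions made purely for illustration; in practice
# `feats` would come from a ViT backbone (e.g. DINO), with shape
# [1, 1 + num_patches, C].
def _example_ncut_usage():
  """Runs `ncut` on dummy ViT-style features (illustration only)."""
  h_featmap, w_featmap = 14, 14  # assumed feature-map size
  feat_dim = 384  # assumed feature dimension
  patch_size = 16  # assumed patch size
  feats = torch.randn(1, 1 + h_featmap * w_featmap, feat_dim)
  scales = (patch_size, patch_size)  # feature-map cell -> image pixels
  # Assumed image size layout (channels, height, width); `ncut` only reads
  # the trailing (height, width) via init_image_size[1:].
  init_image_size = (3, h_featmap * patch_size, w_featmap * patch_size)
  pred, _, mask, seed, _, _ = ncut(
      feats, (h_featmap, w_featmap), scales, init_image_size, tau=0.2
  )
  return pred, mask, seed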


def grad_obj_discover_on_attn(attn, gradcam, dims, topk=1, threshold=0.6):
  """Finds object points from Grad-CAM seeds and the ViT attention map.

  Gets the gradcam and attention maps, finds the seed patches, then uses the
  LOST-style expansion to find the potential object points.

  Args:
    attn: attention map from ViT averaged across all heads, shape:
      [1, (1+num_patches), (1+num_patches)].
    gradcam: gradcam map from ViT, shape: [1, 1, H, W].
    dims: dimensions (height, width) of the feature map.
    topk: number of seed patches taken from the gradcam map.
    threshold: fraction of the attention mass kept per seed.

  Returns:
    th_attn: thresholded attention mask over the feature map, shape
      [1, 1, dims[0], dims[1]].
  """
  w_featmap, h_featmap = dims
  # nh = attn.shape[1]
  attn = attn.squeeze()
  seeds = torch.argsort(gradcam.flatten(), descending=True)[:topk]
  # We keep only the patch-to-patch attention (the [CLS] token is dropped).
  patch_attn = attn[1:, 1:]
  topk_attn = patch_attn[seeds]
  nh = topk_attn.shape[0]
  # attentions = attn[0, :, 0, 1:].reshape(nh, -1)
  # We keep only a certain percentage of the mass.
  val, idx = torch.sort(topk_attn)
  val /= torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  th_attn = cumval > (1 - threshold)
  idx2 = torch.argsort(idx)
  for h in range(nh):
    th_attn[h] = th_attn[h][idx2[h]]
  th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
  th_attn = th_attn.sum(0)
  th_attn[th_attn > 1] = 1
  return th_attn[None, None]


def grad_obj_discover(feats, gradcam, dims):
  """Finds object points from a Grad-CAM seed.

  Uses the gradient heatmap to find the seed, then uses the LOST algorithm to
  find the potential object points.

  Args:
    feats: the pixel/patch features of an image, shape [1, HW, C].
    gradcam: the Grad-CAM map.
    dims: dimensions of the feature map from which the features are taken.

  Returns:
    mask: similarity of every patch to the seed patch, shape [1, 1, *dims].
  """
  # Compute the similarity
  a = (feats @ feats.transpose(1, 2)).squeeze()

  # Compute the inverse degree centrality measure per patch
  # sorted_patches, scores = patch_scoring(a)

  # Select the initial seed
  # seed = sorted_patches[0]
  seed = gradcam.argmax()
  mask = a[seed]
  mask = mask.view(1, 1, *dims)
  return mask


def lost(feats, dims, scales, init_image_size, k_patches=100):
  """Implementation of the LOST method.

  Args:
    feats: the pixel/patch features of an image, shape [1, C, H, W].
    dims: dimensions of the feature map from which the features are taken.
    scales: scale factors from the feature map to the image.
    init_image_size: size of the image.
    k_patches: number of patches retrieved and compared to the seed during
      seed expansion.

  Returns:
    mask: indices of the connected component covering the discovered object,
      as returned by `detect_box`.
  """
  # Compute the similarity
  feats = feats.flatten(2).transpose(1, 2)
  a = (feats @ feats.transpose(1, 2)).squeeze()

  # Compute the inverse degree centrality measure per patch
  sorted_patches, _ = patch_scoring(a)

  # Select the initial seed
  seed = sorted_patches[0]

  # Seed expansion
  potentials = sorted_patches[:k_patches]
  similars = potentials[a[seed, potentials] > 0.0]
  m = torch.sum(a[similars, :], dim=0)

  # Box extraction. `detect_box` labels a 2-D map, so reshape the expanded
  # seed correlations back to the feature-map grid and keep only the
  # positively correlated patches.
  _, _, _, mask = detect_box(
      (m.reshape(dims) > 0.0),
      seed,
      dims,
      scales=scales,
      initial_im_size=init_image_size[1:],
  )
  return mask
  # return np.asarray(bbox), A, scores, seed


def patch_scoring(m, threshold=0.0):
  """Patch scoring based on the inverse degree."""
  # Cloning important
  a = m.clone()

  # Zero diagonal
  a.fill_diagonal_(0)

  # Make sure the affinity is non-negative
  a[a < 0] = 0
  # C = A + A.t()

  # Sort pixels by inverse degree
  cent = -torch.sum(a > threshold, dim=1).type(torch.float32)
  sel = torch.argsort(cent, descending=True)

  return sel, cent
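

# --- Illustrative usage sketch (not part of the original pipeline). --------
# A minimal, hedged example of the seed-selection step that `lost` builds on:
# `patch_scoring` ranks patches by how few positive correlations they have
# with the rest of the image, and the top-ranked patch is used as the seed.
# The patch count, feature dimension, and random features are assumptions for
# illustration only; real features would typically be ViT keys (see
# `get_feats` below).
def _example_patch_scoring_usage():
  """Ranks dummy patches by inverse degree (illustration only)."""
  num_patches, feat_dim = 196, 384  # assumed sizes
  feats = F.normalize(torch.randn(num_patches, feat_dim), dim=-1)
  affinity = feats @ feats.t()  # patch-to-patch cosine similarity
  sorted_patches, scores = patch_scoring(affinity)
  seed = sorted_patches[0]  # patch with the fewest positive correlations
  return seed, scores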


def detect_box(
    bipartition,
    seed,
    dims,
    initial_im_size=None,
    scales=None,
    principle_object=True,
):
  """Extracts a box corresponding to the seed patch."""
  # Among the connected components extracted from the affinity matrix, select
  # the one corresponding to the seed patch.
  # w_featmap, h_featmap = dims
  objects, _ = ndimage.label(bipartition)
  cc = objects[np.unravel_index(seed, dims)]

  if principle_object:
    mask = np.where(objects == cc)

    # Add +1 because excluded max
    ymin, ymax = min(mask[0]), max(mask[0]) + 1
    xmin, xmax = min(mask[1]), max(mask[1]) + 1

    # Rescale to image size
    r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
    r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax

    pred = [r_xmin, r_ymin, r_xmax, r_ymax]

    # Check not out of image size (used when padding)
    if initial_im_size:
      pred[2] = min(pred[2], initial_im_size[1])
      pred[3] = min(pred[3], initial_im_size[0])

    # Coordinate predictions for the feature space
    # Axis different than in image space
    pred_feats = [ymin, xmin, ymax, xmax]

    return pred, pred_feats, objects, mask
  else:
    raise NotImplementedError


# This function is modified from
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
# Ref: https://github.com/facebookresearch/dino.
def dino_seg(attn, dims, patch_size, head=0):
  """Extraction of boxes based on the attention segmentation method from DINO."""
  w_featmap, h_featmap = dims
  nh = attn.shape[1]
  official_th = 0.6

  # We keep only the output patch attention
  # Get the attentions corresponding to the [CLS] token
  attentions = attn[0, :, 0, 1:].reshape(nh, -1)

  # We keep only a certain percentage of the mass
  val, idx = torch.sort(attentions)
  val /= torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  th_attn = cumval > (1 - official_th)
  idx2 = torch.argsort(idx)
  for h in range(nh):
    th_attn[h] = th_attn[h][idx2[h]]
  th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()

  # Connected components
  labeled_array, _ = scipy.ndimage.label(th_attn[head].cpu().numpy())

  # Find the biggest component (+ 1 so the highest label is included)
  size_components = [
      np.sum(labeled_array == c) for c in range(np.max(labeled_array) + 1)
  ]

  if len(size_components) > 1:
    # Select the biggest component avoiding component 0 corresponding
    # to background
    biggest_component = np.argmax(size_components[1:]) + 1
  else:
    # Cases of a single component
    biggest_component = 0

  # Mask corresponding to connected component
  mask = np.where(labeled_array == biggest_component)

  # Add +1 because excluded max
  ymin, ymax = min(mask[0]), max(mask[0]) + 1
  xmin, xmax = min(mask[1]), max(mask[1]) + 1

  # Rescale to image
  r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
  r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
  pred = [r_xmin, r_ymin, r_xmax, r_ymax]

  return pred


def get_feats(feat_out, shape):
  """Recovers the key features from a hooked qkv projection.

  Args:
    feat_out: dict holding the raw qkv projection under the key 'qkv', shape
      [batch, num_tokens, 3 * C].
    shape: sequence whose first three entries are (batch, num_heads,
      num_tokens), e.g. the shape of the attention tensor.

  Returns:
    Key features of shape [batch, num_tokens, C].
  """
  # Batch size, Number of heads, Number of tokens
  nb_im, nh, nb_tokens = shape[0:3]
  qkv = (
      feat_out["qkv"]
      .reshape(nb_im, nb_tokens, 3, nh, -1)  # head dimension is inferred
      .permute(2, 0, 3, 1, 4)
  )
  k = qkv[1]
  k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
  return k


def get_instances(masks, return_largest=False):
  """Splits each mask in a batch into its connected-component instances."""
  return [
      get_instances_single(m[None], return_largest=return_largest)
      for m in masks
  ]


def get_instances_single(mask, return_largest=False):
  """Gets the per-instance (connected-component) masks of a single mask."""
  labeled_array, _ = label(mask.cpu().numpy())
  instances = np.concatenate(
      [labeled_array == c for c in range(np.max(labeled_array) + 1)], axis=0
  )
  if return_largest:
    size_components = np.sum(instances, axis=(1, 2))
    if len(size_components) > 1:
      # Select the biggest component avoiding component 0 corresponding
      # to background
      biggest_component = np.argmax(size_components[1:]) + 1
    else:
      # Cases of a single component
      biggest_component = 0
    # Mask corresponding to connected component
    return torch.from_numpy(labeled_array == biggest_component).float()
  return torch.from_numpy(instances[1:]).float()
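

# --- Illustrative usage sketch (not part of the original pipeline). --------
# A hedged sketch of how `get_feats` might be wired up: a forward hook stores
# the raw qkv projection of the last attention block in a dict, and
# `get_feats` then recovers the key features from it. The attribute path
# `model.blocks[-1].attn.qkv` follows DINO-style ViT implementations and is
# an assumption about the backbone, as is `num_heads`; neither is defined in
# this file.
def _example_get_feats_hook(model, images, num_heads=6):
  """Captures the qkv projection with a forward hook (illustration only)."""
  feat_out = {}

  def hook(module, inputs, output):
    # The hooked module is assumed to be the qkv linear layer of the last
    # attention block; its output has shape [batch, 1 + num_patches, 3 * C].
    del module, inputs  # unused
    feat_out["qkv"] = output

  handle = model.blocks[-1].attn.qkv.register_forward_hook(hook)
  with torch.no_grad():
    model(images)
  handle.remove()

  nb_im, nb_tokens = feat_out["qkv"].shape[:2]
  # `get_feats` only reads the first three entries of `shape`.
  return get_feats(feat_out, (nb_im, num_heads, nb_tokens))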