CLIP_as_RNN/modeling/post_process/object_discovery.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Find objects."""
# pylint: disable=g-importing-member
import numpy as np
import scipy
from scipy import ndimage
from scipy.linalg import eigh
from scipy.ndimage import label
import torch
import torch.nn.functional as F


def ncut(
    feats,
    dims,
    scales,
    init_image_size,
    tau=0,
    eps=1e-5,
    no_binary_graph=False,
):
  """Implementation of the NCut method.

  Args:
    feats: the pixel/patch features of an image.
    dims: dimensions of the feature map the features are taken from.
    scales: scale factors from the feature map to the image.
    init_image_size: size of the image.
    tau: threshold for graph construction.
    eps: graph edge weight assigned to edges below the threshold.
    no_binary_graph: ablation study for using the similarity score as the
      graph edge weight instead of a binary graph.

  Returns:
    A tuple of (box prediction, labeled connected components, seed-component
    mask, seed index, None, second smallest eigenvector reshaped to `dims`).
  """
  feats = feats[0, 1:, :]
  feats = F.normalize(feats, p=2)
  a = feats @ feats.transpose(1, 0)
  a = a.cpu().numpy()
  if no_binary_graph:
    a[a < tau] = eps
  else:
    a = a > tau
    a = np.where(a.astype(float) == 0, eps, a)
  d_i = np.sum(a, axis=1)
  d = np.diag(d_i)

  # Compute the second and third smallest generalized eigenvectors.
  _, eigenvectors = eigh(d - a, d, subset_by_index=[1, 2])
  eigenvec = np.copy(eigenvectors[:, 0])

  # Use the average value to compute the bipartition.
  second_smallest_vec = eigenvectors[:, 0]
  avg = np.sum(second_smallest_vec) / len(second_smallest_vec)
  bipartition = second_smallest_vec > avg

  seed = np.argmax(np.abs(second_smallest_vec))

  if bipartition[seed] != 1:
    eigenvec = eigenvec * -1
    bipartition = np.logical_not(bipartition)
  bipartition = bipartition.reshape(dims).astype(float)

  # Predict the bounding box.
  # We only extract the principal object's bounding box.
  pred, _, objects, cc = detect_box(
      bipartition,
      seed,
      dims,
      scales=scales,
      initial_im_size=init_image_size[1:],
  )
  mask = np.zeros(dims)
  mask[cc[0], cc[1]] = 1

  return np.asarray(pred), objects, mask, seed, None, eigenvec.reshape(dims)
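
# Illustrative usage of `ncut` (a hedged sketch: the feature tensor, patch
# size, and image size below are hypothetical, not values used elsewhere in
# this repo). In practice `feats` would be ViT patch features with a leading
# [CLS] token, which `ncut` drops before building the affinity graph.
#
#   h_map, w_map = 32, 32
#   vit_feats = torch.randn(1, 1 + h_map * w_map, 768)
#   box, objects, mask, seed, _, eigvec = ncut(
#       feats=vit_feats,
#       dims=(h_map, w_map),
#       scales=[16, 16],                 # patch stride along y and x
#       init_image_size=(3, 512, 512),   # (C, H, W); only H and W are used
#       tau=0.2,
#   )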


def grad_obj_discover_on_attn(attn, gradcam, dims, topk=1, threshold=0.6):
  """Use the Grad-CAM map to find seed patches, then threshold their attention to find potential object points.

  Args:
    attn: attention map from the ViT averaged across all heads, shape:
      [1, (1+num_patches), (1+num_patches)].
    gradcam: Grad-CAM map from the ViT, shape: [1, 1, H, W].
    dims: (width, height) of the feature map.
    topk: number of seed patches selected from the Grad-CAM map.
    threshold: fraction of the attention mass to keep per seed.

  Returns:
    th_attn: binary attention mask of shape [1, 1, width, height].
  """
  w_featmap, h_featmap = dims
  # nh = attn.shape[1]
  attn = attn.squeeze()
  seeds = torch.argsort(gradcam.flatten(), descending=True)[:topk]
  # Keep only the patch-to-patch attention (drop the [CLS] token).
  patch_attn = attn[1:, 1:]
  topk_attn = patch_attn[seeds]
  nh = topk_attn.shape[0]
  # attentions = attn[0, :, 0, 1:].reshape(nh, -1)
  # Keep only a certain percentage of the attention mass.
  val, idx = torch.sort(topk_attn)
  val /= torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  th_attn = cumval > (1 - threshold)
  idx2 = torch.argsort(idx)
  for h in range(nh):
    th_attn[h] = th_attn[h][idx2[h]]
  th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
  th_attn = th_attn.sum(0)
  th_attn[th_attn > 1] = 1
  return th_attn[None, None]
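
# Illustrative call (a sketch; the tensors and shapes below are made up).
# The top-`topk` Grad-CAM locations pick the seed patches, and each seed's
# attention row is thresholded to keep `threshold` of its mass.
#
#   w_map, h_map = 16, 16
#   mean_attn = torch.rand(1, 1 + w_map * h_map, 1 + w_map * h_map)
#   cam = torch.rand(1, 1, w_map, h_map)
#   seed_mask = grad_obj_discover_on_attn(
#       mean_attn, cam, dims=(w_map, h_map), topk=3, threshold=0.6
#   )  # -> [1, 1, w_map, h_map]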


def grad_obj_discover(feats, gradcam, dims):
  """Use the gradient heatmap to find a seed, then return its similarity map.

  Args:
    feats: the pixel/patch features of an image. Shape: [1, HW, C]
    gradcam: the Grad-CAM map.
    dims: dimensions of the feature map the features are taken from.

  Returns:
    mask: the seed patch's row of the patch affinity matrix, reshaped to
      [1, 1, *dims].
  """
  # Compute the similarity.
  a = (feats @ feats.transpose(1, 2)).squeeze()
  # Compute the inverse degree centrality measure per patch
  # sorted_patches, scores = patch_scoring(a)
  # Select the initial seed
  # seed = sorted_patches[0]
  seed = gradcam.argmax()
  mask = a[seed]
  mask = mask.view(1, 1, *dims)
  return mask
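
# Illustrative call (a sketch with assumed shapes). The seed is the argmax of
# the Grad-CAM map, and the returned map is that seed's similarity to every
# other patch, laid out on the feature grid.
#
#   h_map, w_map = 16, 16
#   patch_feats = torch.randn(1, h_map * w_map, 512)
#   cam = torch.rand(1, 1, h_map, w_map)
#   sim_map = grad_obj_discover(patch_feats, cam, dims=(h_map, w_map))
#   # sim_map: [1, 1, h_map, w_map]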


def lost(feats, dims, scales, init_image_size, k_patches=100):
  """Implementation of the LOST method.

  Args:
    feats: the pixel/patch features of an image. Shape: [1, C, H, W]
    dims: dimensions of the feature map the features are taken from.
    scales: scale factors from the feature map to the image.
    init_image_size: size of the image.
    k_patches: number of patches retrieved that are compared to the seed
      at seed expansion.

  Returns:
    mask: indices of the connected component containing the seed patch.
  """
  # Compute the similarity.
  feats = feats.flatten(2).transpose(1, 2)
  a = (feats @ feats.transpose(1, 2)).squeeze()
  # Compute the inverse degree centrality measure per patch.
  sorted_patches, _ = patch_scoring(a)
  # Select the initial seed.
  seed = sorted_patches[0]
  # Seed expansion.
  potentials = sorted_patches[:k_patches]
  similars = potentials[a[seed, potentials] > 0.0]
  m = torch.sum(a[similars, :], dim=0)
  # Box extraction: binarize the expanded-seed similarity map and lay it out
  # on the feature grid before extracting connected components.
  _, _, _, mask = detect_box(
      (m > 0.0).reshape(dims).cpu().numpy(),
      seed,
      dims,
      scales=scales,
      initial_im_size=init_image_size[1:],
  )
  return mask
  # return np.asarray(bbox), A, scores, seed
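
# Illustrative call (a sketch with assumed shapes; `conv_feats` is made up).
# LOST seeds from the patch with the lowest degree in the similarity graph,
# expands it over the `k_patches` most similar patches, and keeps the
# connected component containing the seed.
#
#   h_map, w_map = 14, 14
#   conv_feats = torch.randn(1, 384, h_map, w_map)   # [1, C, H, W]
#   cc_indices = lost(
#       conv_feats,
#       dims=(h_map, w_map),
#       scales=[16, 16],
#       init_image_size=(3, 224, 224),
#   )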


def patch_scoring(m, threshold=0.0):
  """Patch scoring based on the inverse degree."""
  # Cloning is important: the affinity matrix is modified in place below.
  a = m.clone()
  # Zero the diagonal.
  a.fill_diagonal_(0)
  # Keep only non-negative similarities.
  a[a < 0] = 0
  # C = A + A.t()
  # Sort patches by inverse degree.
  cent = -torch.sum(a > threshold, dim=1).type(torch.float32)
  sel = torch.argsort(cent, descending=True)
  return sel, cent


def detect_box(
    bipartition,
    seed,
    dims,
    initial_im_size=None,
    scales=None,
    principle_object=True,
):
  """Extract a box corresponding to the seed patch."""
  # Among the connected components extracted from the affinity map, select
  # the one containing the seed patch.
  # w_featmap, h_featmap = dims
  objects, _ = ndimage.label(bipartition)
  cc = objects[np.unravel_index(seed, dims)]

  if principle_object:
    mask = np.where(objects == cc)
    # Add +1 because the max is excluded.
    ymin, ymax = min(mask[0]), max(mask[0]) + 1
    xmin, xmax = min(mask[1]), max(mask[1]) + 1
    # Rescale to image size.
    r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
    r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
    pred = [r_xmin, r_ymin, r_xmax, r_ymax]
    # Clip to the image size (used when padding).
    if initial_im_size:
      pred[2] = min(pred[2], initial_im_size[1])
      pred[3] = min(pred[3], initial_im_size[0])
    # Coordinate predictions in feature space.
    # Axes differ from image space.
    pred_feats = [ymin, xmin, ymax, xmax]
    return pred, pred_feats, objects, mask
  else:
    raise NotImplementedError
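
# Illustrative call (a sketch; the toy foreground map below is made up).
# `bipartition` is a 2-D foreground map on the feature grid, `seed` a flat
# index into that grid, and `scales` maps grid coordinates back to pixels.
#
#   fg = np.zeros((14, 14))
#   fg[4:9, 5:11] = 1
#   box, box_in_feat_space, components, cc_mask = detect_box(
#       fg, seed=int(np.argmax(fg)), dims=(14, 14),
#       initial_im_size=(224, 224), scales=[16, 16],
#   )
#   # box is [xmin, ymin, xmax, ymax] in pixel coordinates.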


# This function is modified from
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
# Ref: https://github.com/facebookresearch/dino.
def dino_seg(attn, dims, patch_size, head=0):
  """Extraction of boxes based on the attention segmentation method proposed in DINO."""
  w_featmap, h_featmap = dims
  nh = attn.shape[1]
  official_th = 0.6

  # We keep only the output patch attention.
  # Get the attentions corresponding to the [CLS] token.
  attentions = attn[0, :, 0, 1:].reshape(nh, -1)

  # Keep only a certain percentage of the mass.
  val, idx = torch.sort(attentions)
  val /= torch.sum(val, dim=1, keepdim=True)
  cumval = torch.cumsum(val, dim=1)
  th_attn = cumval > (1 - official_th)
  idx2 = torch.argsort(idx)
  for h in range(nh):
    th_attn[h] = th_attn[h][idx2[h]]
  th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()

  # Connected components.
  labeled_array, _ = scipy.ndimage.label(th_attn[head].cpu().numpy())

  # Find the biggest component.
  size_components = [
      np.sum(labeled_array == c) for c in range(np.max(labeled_array))
  ]

  if len(size_components) > 1:
    # Select the biggest component, avoiding component 0 which corresponds
    # to the background.
    biggest_component = np.argmax(size_components[1:]) + 1
  else:
    # Case of a single component.
    biggest_component = 0

  # Mask corresponding to the connected component.
  mask = np.where(labeled_array == biggest_component)

  # Add +1 because the max is excluded.
  ymin, ymax = min(mask[0]), max(mask[0]) + 1
  xmin, xmax = min(mask[1]), max(mask[1]) + 1

  # Rescale to the image.
  r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
  r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
  pred = [r_xmin, r_ymin, r_xmax, r_ymax]

  return pred
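
# Illustrative call (a sketch; the attention tensor is random and the shapes
# are assumptions). `attn` is the raw multi-head attention of the last ViT
# block, shape [1, num_heads, 1 + num_patches, 1 + num_patches]; the [CLS]
# row of one head is thresholded to keep 60% of its mass, and the largest
# connected component is turned into a box in pixel coordinates.
#
#   w_map, h_map = 14, 14
#   last_attn = torch.rand(1, 6, 1 + w_map * h_map, 1 + w_map * h_map)
#   box = dino_seg(last_attn, dims=(w_map, h_map), patch_size=16, head=0)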


def get_feats(feat_out, shape):
  """Recover the key vectors from a stored qkv projection."""
  # Batch size, number of heads, number of tokens.
  nb_im, nh, nb_tokens = shape[0:3]
  qkv = (
      feat_out["qkv"]
      .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
      .permute(2, 0, 3, 1, 4)
  )
  k = qkv[1]
  k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
  return k
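
# Illustrative hook-based usage (a hedged sketch; `vit_model`, the qkv module
# path, and the shapes depend on the ViT implementation and are assumptions).
# A forward hook stores the flattened qkv projection of the last block, from
# which `get_feats` recovers the per-token key vectors.
#
#   feat_out = {}
#   def _store_qkv(module, inputs, output):
#     feat_out["qkv"] = output
#   vit_model.blocks[-1].attn.qkv.register_forward_hook(_store_qkv)
#   out = vit_model(images)
#   keys = get_feats(feat_out, shape=(images.shape[0], num_heads, num_tokens))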


def get_instances(masks, return_largest=False):
  """Split each binary mask in a batch into per-instance masks."""
  return [
      get_instances_single(m[None], return_largest=return_largest)
      for m in masks
  ]


def get_instances_single(mask, return_largest=False):
  """Split a single binary mask into per-instance masks via connected components."""
  labeled_array, _ = label(mask.cpu().numpy())
  instances = np.concatenate(
      [labeled_array == c for c in range(np.max(labeled_array) + 1)], axis=0
  )
  if return_largest:
    size_components = np.sum(instances, axis=(1, 2))
    if len(size_components) > 1:
      # Select the biggest component, avoiding component 0 which corresponds
      # to the background.
      biggest_component = np.argmax(size_components[1:]) + 1
    else:
      # Case of a single component.
      biggest_component = 0
    # Mask corresponding to the connected component.
    return torch.from_numpy(labeled_array == biggest_component).float()
  return torch.from_numpy(instances[1:]).float()
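
# Illustrative call (a sketch; the toy masks below are made up). Each binary
# mask is split into per-instance masks via connected components; with
# `return_largest=True` only the biggest non-background component is kept.
#
#   m = torch.zeros(2, 32, 32)
#   m[0, 2:8, 2:8] = 1
#   m[0, 20:30, 20:30] = 1
#   m[1, 5:15, 5:15] = 1
#   per_instance = get_instances(m)                    # list of 2 tensors
#   largest = get_instances(m, return_largest=True)    # largest blob of each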