#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : metauas.py
@Time : 2025/03/26 23:46:12
@Author : Bin-Bin Gao
@Email : [email protected]
@Homepage: https://csgaobb.github.io/
@Version : 1.0
@Desc : some classes and functions for MetaUAS
'''
import random

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from PIL import Image
from einops import rearrange
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor
from segmentation_models_pytorch.unet.model import UnetDecoder
from segmentation_models_pytorch.fpn.decoder import FPNDecoder
from segmentation_models_pytorch.encoders import get_encoder, get_preprocessing_params

def set_random_seed(seed=233, reproduce=False):
    # Seed every RNG in play; distinct powers of the seed keep the streams decorrelated.
    np.random.seed(seed)
    torch.manual_seed(seed ** 2)
    torch.cuda.manual_seed(seed ** 3)
    random.seed(seed ** 4)

    if reproduce:
        # Trade cuDNN autotuning speed for determinism.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True
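
# Illustrative usage (the seed value is arbitrary): fix all seeds before building
# the model so runs are repeatable.
#   >>> set_random_seed(seed=233, reproduce=True)
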
def normalize(pred, max_value=None, min_value=None):
if max_value is None or min_value is None:
return (pred - pred.min()) / (pred.max() - pred.min())
else:
return (pred - min_value) / (max_value - min_value)
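
# Example sketch: min-max rescale a raw anomaly map to [0, 1] for visualization,
# or rescale against externally supplied bounds. The tensor below is hypothetical.
#   >>> amap = torch.rand(1, 1, 256, 256)                       # stand-in model output
#   >>> vis = normalize(amap)                                   # per-map min-max scaling
#   >>> vis_fixed = normalize(amap, max_value=1.0, min_value=0.0)  # fixed range
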
def apply_ad_scoremap(image, scoremap, alpha=0.5):
np_image = np.asarray(image, dtype=np.float32)
scoremap = (scoremap * 255).astype(np.uint8)
scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
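
# Example sketch (inputs are hypothetical): blend a [0, 1] score map over an RGB
# image. `image` is HxWx3 and `scoremap` must match its spatial size.
#   >>> img = np.zeros((256, 256, 3), dtype=np.uint8)
#   >>> smap = np.random.rand(256, 256).astype(np.float32)
#   >>> overlay = apply_ad_scoremap(img, smap, alpha=0.5)  # uint8 HxWx3 heatmap blend
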
def read_image_as_tensor(path_to_image):
pil_image = Image.open(path_to_image).convert("RGB")
image_as_tensor = pil_to_tensor(pil_image).float() / 255.0
return image_as_tensor
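
# Example sketch (the path is hypothetical): load an image as a CxHxW float
# tensor scaled to [0, 1].
#   >>> query = read_image_as_tensor("path/to/query.png")
#   >>> query.shape, query.dtype  # (torch.Size([3, H, W]), torch.float32)
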
def safely_load_state_dict(model, checkpoint):
    # strict=True raises if the checkpoint and the model architecture disagree.
    model.load_state_dict(torch.load(checkpoint), strict=True)
    return model
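
# Example sketch (the checkpoint path is hypothetical); torch.load expects the
# stored tensors to be loadable on the current default device.
#   >>> model = safely_load_state_dict(model, "weights/metauas.ckpt")
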
class AlignmentModule(nn.Module):
def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa", fusion_policy='cat'):
super().__init__()
self.fusion_policy = fusion_policy
self.alignment_layer = AlignmentLayer(input_channels, hidden_channels, alignment_type=alignment_type)
def forward(self, query_features, prompt_features):
        if isinstance(prompt_features, list):
            # Align each prompt against the query, then average the aligned features.
            aligned_prompt = []
            for i in range(len(prompt_features)):
                aligned_prompt.append(self.alignment_layer(query_features, prompt_features[i]))
            aligned_prompt = torch.mean(torch.stack(aligned_prompt), 0)
else:
aligned_prompt = self.alignment_layer(query_features, prompt_features)
if self.fusion_policy == 'cat':
query_features = rearrange(
[query_features, aligned_prompt], "two b c h w -> b (two c) h w"
)
elif self.fusion_policy == 'add':
query_features = query_features + aligned_prompt
elif self.fusion_policy == 'absdiff':
query_features = (query_features - aligned_prompt).abs()
return query_features
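
# Example sketch (shapes are illustrative): with 'cat' fusion the aligned prompt
# is concatenated along channels, so a 448-channel input yields 896 channels out.
#   >>> align = AlignmentModule(input_channels=448, hidden_channels=224,
#   ...                         alignment_type="sa", fusion_policy="cat")
#   >>> q = torch.rand(1, 448, 8, 8)
#   >>> p = torch.rand(1, 448, 8, 8)
#   >>> align(q, p).shape  # torch.Size([1, 896, 8, 8])
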
class AlignmentLayer(nn.Module):
def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa"):
super().__init__()
self.alignment_type = alignment_type
if alignment_type != "na":
self.dimensionality_reduction = nn.Conv2d(
input_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=True
)
def forward(self, query_features, prompt_features):
# no-alignment
if self.alignment_type == 'na':
return prompt_features
        else:
            # Project query and prompt features into a shared low-dimensional space.
            Q = self.dimensionality_reduction(query_features)
            K = self.dimensionality_reduction(prompt_features)
            V = rearrange(prompt_features, "b c h w -> b c (h w)")
            # Similarity between every query location and every prompt location.
            soft_attention_map = torch.einsum("bcij,bckl->bijkl", Q, K)
            soft_attention_map = rearrange(soft_attention_map, "b h1 w1 h2 w2 -> b h1 w1 (h2 w2)")
            soft_attention_map = nn.Softmax(dim=3)(soft_attention_map)

            # soft-alignment: attention-weighted average over prompt locations
            if self.alignment_type == 'sa':
                aligned_features = torch.einsum("bijp,bcp->bcij", soft_attention_map, V)
            # hard-alignment: each query location keeps only its best-matching prompt location
            elif self.alignment_type == 'ha':
                max_v, _ = soft_attention_map.max(dim=-1, keepdim=True)
                hard_attention_map = (soft_attention_map == max_v).float()
                aligned_features = torch.einsum("bijp,bcp->bcij", hard_attention_map, V)
            return aligned_features
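
# Shape sketch (illustrative sizes): for 8x8 feature maps, the soft attention map
# relates each of the 64 query locations to all 64 prompt locations, and the
# output re-samples prompt features to the query's spatial layout.
#   >>> layer = AlignmentLayer(input_channels=448, hidden_channels=224, alignment_type="sa")
#   >>> out = layer(torch.rand(2, 448, 8, 8), torch.rand(2, 448, 8, 8))
#   >>> out.shape  # torch.Size([2, 448, 8, 8])
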
class MetaUAS(pl.LightningModule):
def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
super().__init__()
self.encoder_name = encoder_name
self.decoder_name = decoder_name
self.encoder_depth = encoder_depth
self.decoder_depth = decoder_depth
self.num_alignment_layers = num_alignment_layers
self.alignment_type = alignment_type
self.fusion_policy = fusion_policy
        # Channel widths for the alignment layers (deepest encoder stage first)
        align_input_channels = [448, 160, 56]
        align_hidden_channels = [224, 80, 28]
        # Per-stage output channels of the encoder and of the U-Net decoder
        encoder_channels = [3, 48, 32, 56, 160, 448]
        decoder_channels = [256, 128, 64, 64, 48]
self.encoder = get_encoder(
self.encoder_name,
in_channels=3,
depth=self.encoder_depth,
weights="imagenet",)
preparams = get_preprocessing_params(
self.encoder_name,
pretrained="imagenet"
)
self.preprocess = transforms.Normalize(preparams['mean'], preparams['std'])
        # Freeze the pretrained encoder; only alignment, decoder, and mask head train.
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        if self.decoder_name == "unet":
            encoder_out_channels = encoder_channels[self.encoder_depth - self.decoder_depth:]
            # With 'cat' fusion the decoder must know how many stages carry doubled channels.
            if self.fusion_policy == 'cat':
                num_alignment_layers = self.num_alignment_layers
            elif self.fusion_policy == 'add' or self.fusion_policy == 'absdiff':
                num_alignment_layers = 0
            self.decoder = UnetDecoder(
                encoder_channels=encoder_out_channels,
                decoder_channels=decoder_channels,
                n_blocks=self.decoder_depth,
                attention_type="scse",
                num_coam_layers=num_alignment_layers,
            )
elif self.decoder_name == "fpn":
encoder_out_channels = encoder_channels
if self.fusion_policy == 'cat':
for i in range(self.num_alignment_layers):
encoder_out_channels[-(i+1)] = 2 * encoder_out_channels[-(i+1)]
self.decoder = FPNDecoder(
encoder_channels= encoder_out_channels,
encoder_depth=self.encoder_depth,
pyramid_channels=256,
segmentation_channels=decoder_channels[-1],
dropout=0.2,
merge_policy="add",
)
elif self.decoder_name == "fpnadd":
segmentation_channels = 256 #128
encoder_out_channels = encoder_channels
if self.fusion_policy == 'cat':
for i in range(self.num_alignment_layers):
encoder_out_channels[-(i+1)] = 2 * encoder_out_channels[-(i+1)]
self.decoder = FPNDecoder(
encoder_channels= encoder_out_channels,
encoder_depth=self.encoder_depth,
pyramid_channels=256,
segmentation_channels=segmentation_channels,
dropout=0.2,
merge_policy="add",
)
elif self.decoder_name == "fpncat":
encoder_out_channels = encoder_channels
segmentation_channels = 256 #128
if self.fusion_policy == 'cat':
for i in range(self.num_alignment_layers):
encoder_out_channels[-(i+1)] = 2 * encoder_out_channels[-(i+1)]
self.decoder = FPNDecoder(
encoder_channels= encoder_out_channels,
encoder_depth=self.encoder_depth,
pyramid_channels=256,
segmentation_channels=segmentation_channels,
dropout=0.2,
merge_policy="cat",
)
if self.alignment_type == "sa" or self.alignment_type == "na" or self.alignment_type == "ha" :
self.alignment = nn.ModuleList(
[
AlignmentModule(
input_channels=align_input_channels[i],
hidden_channels=align_hidden_channels[i],
alignment_type=self.alignment_type,
fusion_policy=self.fusion_policy,
)
for i in range(self.num_alignment_layers)
]
)
if self.decoder_name == "fpncat":
self.mask_head = nn.Conv2d(
segmentation_channels*4,
1,
kernel_size=1,
stride=1,
padding=0,
)
elif self.decoder_name == "fpnadd":
self.mask_head = nn.Conv2d(
segmentation_channels,
1,
kernel_size=1,
stride=1,
padding=0,
)
else:
self.mask_head = nn.Conv2d(
decoder_channels[-1],
1,
kernel_size=1,
stride=1,
padding=0,
)
    def forward(self, batch):
        query_input = self.preprocess(batch["query_image"])
        prompt_input = self.preprocess(batch["prompt_image"])
        # The encoder is frozen, so feature extraction runs without gradients.
        with torch.no_grad():
            query_encoded_features = self.encoder(query_input)
            prompt_encoded_features = self.encoder(prompt_input)
        # Fuse aligned prompt features into the deepest query feature maps.
        for i in range(len(self.alignment)):
            query_encoded_features[-(i + 1)] = self.alignment[i](query_encoded_features[-(i + 1)], prompt_encoded_features[-(i + 1)])
        query_decoded_features = self.decoder(*query_encoded_features[self.encoder_depth - self.decoder_depth:])
if self.decoder_name == "fpn" or self.decoder_name == "fpncat" or self.decoder_name == "fpnadd":
output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=4, mode="bilinear", align_corners=False)
elif self.decoder_name == "unet":
if self.decoder_depth == 4:
output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=2, mode="bilinear", align_corners=False)
if self.decoder_depth == 5:
if not self.training:
output = self.mask_head(query_decoded_features)
return output.sigmoid()
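
# Minimal end-to-end sketch, guarded so importing this module stays side-effect free.
# The encoder name, depths, and file paths below are assumptions for illustration
# (the channel constants in __init__ match an EfficientNet-B4-style encoder); swap in
# the configuration and checkpoint that ship with the actual release.
if __name__ == "__main__":
    set_random_seed(seed=233, reproduce=True)
    model = MetaUAS(
        encoder_name="efficientnet-b4",   # assumption, inferred from encoder_channels
        decoder_name="unet",
        encoder_depth=5,
        decoder_depth=5,
        num_alignment_layers=3,
        alignment_type="sa",
        fusion_policy="cat",
    )
    model.eval()
    # model = safely_load_state_dict(model, "weights/metauas.ckpt")  # hypothetical path
    batch = {
        "query_image": torch.rand(1, 3, 256, 256),   # stand-in for read_image_as_tensor(...)
        "prompt_image": torch.rand(1, 3, 256, 256),
    }
    with torch.no_grad():
        anomaly_map = model(batch)                   # (1, 1, 256, 256), values in [0, 1]
    print(anomaly_map.shape, anomaly_map.max().item())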