#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File     : metauas.py
@Time     : 2025/03/26 23:46:12
@Author   : Bin-Bin Gao
@Email    : csgaobb@gmail.com
@Homepage : https://csgaobb.github.io/
@Version  : 1.0
@Desc     : some classes and functions for MetaUAS
'''

import os
import random
import kornia as K
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import tqdm
import time
import cv2

from PIL import Image
from einops import rearrange
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor
from segmentation_models_pytorch.unet.model import UnetDecoder
from segmentation_models_pytorch.fpn.decoder import FPNDecoder
from segmentation_models_pytorch.encoders import get_encoder, get_preprocessing_params


def set_random_seed(seed=233, reproduce=False):
    np.random.seed(seed)
    torch.manual_seed(seed ** 2)
    torch.cuda.manual_seed(seed ** 3)
    random.seed(seed ** 4)

    if reproduce:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True


def normalize(pred, max_value=None, min_value=None):
    # Min-max normalize to [0, 1], either with the tensor's own range or a given range.
    if max_value is None or min_value is None:
        return (pred - pred.min()) / (pred.max() - pred.min())
    else:
        return (pred - min_value) / (max_value - min_value)


def apply_ad_scoremap(image, scoremap, alpha=0.5):
    # Blend a [0, 1] anomaly score map (as a JET colormap) onto an RGB image.
    np_image = np.asarray(image, dtype=np.float32)
    scoremap = (scoremap * 255).astype(np.uint8)
    scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
    scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
    return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)


def read_image_as_tensor(path_to_image):
    pil_image = Image.open(path_to_image).convert("RGB")
    image_as_tensor = pil_to_tensor(pil_image).float() / 255.0
    return image_as_tensor


def safely_load_state_dict(model, checkpoint):
    model.load_state_dict(torch.load(checkpoint), strict=True)
    return model


class AlignmentModule(nn.Module):
    """Align prompt features to query features, then fuse them (cat / add / absdiff)."""

    def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa", fusion_policy='cat'):
        super().__init__()
        self.fusion_policy = fusion_policy
        self.alignment_layer = AlignmentLayer(input_channels, hidden_channels, alignment_type=alignment_type)

    def forward(self, query_features, prompt_features):
        if isinstance(prompt_features, list):
            # Multiple prompts: align each one to the query, then average.
            aligned_prompt = []
            for i in range(len(prompt_features)):
                aligned_prompt.append(self.alignment_layer(query_features, prompt_features[i]))
            aligned_prompt = torch.mean(torch.stack(aligned_prompt), 0)
        else:
            aligned_prompt = self.alignment_layer(query_features, prompt_features)

        if self.fusion_policy == 'cat':
            query_features = rearrange(
                [query_features, aligned_prompt], "two b c h w -> b (two c) h w"
            )
        elif self.fusion_policy == 'add':
            query_features = query_features + aligned_prompt
        elif self.fusion_policy == 'absdiff':
            query_features = (query_features - aligned_prompt).abs()

        return query_features


class AlignmentLayer(nn.Module):
    """Cross-attention alignment of prompt features to query features (soft, hard, or none)."""

    def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa"):
        super().__init__()
        self.alignment_type = alignment_type
        if alignment_type != "na":
            self.dimensionality_reduction = nn.Conv2d(
                input_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=True
            )

    def forward(self, query_features, prompt_features):
        # no-alignment
        if self.alignment_type == 'na':
            return prompt_features
        else:
            Q = self.dimensionality_reduction(query_features)
            K = self.dimensionality_reduction(prompt_features)
            V = rearrange(prompt_features, "b c h w -> b c (h w)")
            # Pairwise similarity between every query location and every prompt location.
            soft_attention_map = torch.einsum("bcij,bckl->bijkl", Q, K)
            soft_attention_map = rearrange(soft_attention_map, "b h1 w1 h2 w2 -> b h1 w1 (h2 w2)")
            soft_attention_map = nn.Softmax(dim=3)(soft_attention_map)

            # soft-alignment
            if self.alignment_type == 'sa':
                aligned_features = torch.einsum("bijp,bcp->bcij", soft_attention_map, V)

            # hard-alignment
            if self.alignment_type == 'ha':
                max_v, _ = soft_attention_map.max(dim=-1, keepdim=True)
                hard_attention_map = (soft_attention_map == max_v).float()
                aligned_features = torch.einsum("bijp,bcp->bcij", hard_attention_map, V)

            return aligned_features


class MetaUAS(pl.LightningModule):
    """MetaUAS model: a frozen pretrained encoder, query-prompt feature alignment, and a decoder with a mask head."""

    def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth,
                 num_alignment_layers, alignment_type, fusion_policy):
        super().__init__()
        self.encoder_name = encoder_name
        self.decoder_name = decoder_name
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.num_alignment_layers = num_alignment_layers
        self.alignment_type = alignment_type
        self.fusion_policy = fusion_policy

        align_input_channels = [448, 160, 56]
        align_hidden_channels = [224, 80, 28]
        encoder_channels = [3, 48, 32, 56, 160, 448]
        decoder_channels = [256, 128, 64, 64, 48]

        self.encoder = get_encoder(
            self.encoder_name,
            in_channels=3,
            depth=self.encoder_depth,
            weights="imagenet",
        )

        preparams = get_preprocessing_params(
            self.encoder_name,
            pretrained="imagenet",
        )
        self.preprocess = transforms.Normalize(preparams['mean'], preparams['std'])

        # The encoder is frozen; only alignment modules, decoder, and mask head are trained.
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

        if self.decoder_name == "unet":
            encoder_out_channels = encoder_channels[self.encoder_depth - self.decoder_depth:]
            if self.fusion_policy == 'cat':
                num_alignment_layers = self.num_alignment_layers
            elif self.fusion_policy == 'add' or self.fusion_policy == 'absdiff':
                num_alignment_layers = 0

            self.decoder = UnetDecoder(
                encoder_channels=encoder_out_channels,
                decoder_channels=decoder_channels,
                n_blocks=self.decoder_depth,
                attention_type="scse",
                num_coam_layers=num_alignment_layers,
            )
        elif self.decoder_name == "fpn":
            encoder_out_channels = encoder_channels
            if self.fusion_policy == 'cat':
                # Concatenation doubles the channels of the aligned (top) feature levels.
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]

            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=decoder_channels[-1],
                dropout=0.2,
                merge_policy="add",
            )
        elif self.decoder_name == "fpnadd":
            segmentation_channels = 256  # 128
            encoder_out_channels = encoder_channels
            if self.fusion_policy == 'cat':
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]

            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=segmentation_channels,
                dropout=0.2,
                merge_policy="add",
            )
        elif self.decoder_name == "fpncat":
            encoder_out_channels = encoder_channels
            segmentation_channels = 256  # 128
            if self.fusion_policy == 'cat':
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]

            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=segmentation_channels,
                dropout=0.2,
                merge_policy="cat",
            )

        if self.alignment_type in ["sa", "na", "ha"]:
            self.alignment = nn.ModuleList(
                [
                    AlignmentModule(
                        input_channels=align_input_channels[i],
                        hidden_channels=align_hidden_channels[i],
                        alignment_type=self.alignment_type,
                        fusion_policy=self.fusion_policy,
                    )
                    for i in range(self.num_alignment_layers)
                ]
            )

        if self.decoder_name == "fpncat":
            self.mask_head = nn.Conv2d(segmentation_channels * 4, 1, kernel_size=1, stride=1, padding=0)
        elif self.decoder_name == "fpnadd":
            self.mask_head = nn.Conv2d(segmentation_channels, 1, kernel_size=1, stride=1, padding=0)
        else:
            self.mask_head = nn.Conv2d(decoder_channels[-1], 1, kernel_size=1, stride=1, padding=0)

    def forward(self, batch):
        query_input = self.preprocess(batch["query_image"])
        prompt_input = self.preprocess(batch["prompt_image"])

        # Frozen encoder: extract multi-scale features for query and prompt.
        with torch.no_grad():
            query_encoded_features = self.encoder(query_input)
            prompt_encoded_features = self.encoder(prompt_input)

        # Align prompt features to the query at the top feature levels and fuse them in place.
        for i in range(len(self.alignment)):
            query_encoded_features[-(i + 1)] = self.alignment[i](
                query_encoded_features[-(i + 1)], prompt_encoded_features[-(i + 1)]
            )

        query_decoded_features = self.decoder(*query_encoded_features[self.encoder_depth - self.decoder_depth:])

        if self.decoder_name == "fpn" or self.decoder_name == "fpncat" or self.decoder_name == "fpnadd":
            output = F.interpolate(self.mask_head(query_decoded_features),
                                   scale_factor=4, mode="bilinear", align_corners=False)
        elif self.decoder_name == "unet":
            if self.decoder_depth == 4:
                output = F.interpolate(self.mask_head(query_decoded_features),
                                       scale_factor=2, mode="bilinear", align_corners=False)
            if self.decoder_depth == 5:
                # Full-depth U-Net decoder already reaches input resolution; no upsampling needed.
                output = self.mask_head(query_decoded_features)

        return output.sigmoid()
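

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: the encoder name, depths, checkpoint path, and image paths below
# are placeholders; the channel lists above match an EfficientNet-B4 encoder,
# so that is used here, but any compatible smp encoder name works. The "unet"
# decoder also assumes this repo's patched UnetDecoder that accepts the
# num_coam_layers keyword. Images are assumed to already be at the model's
# working resolution.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    set_random_seed(233, reproduce=True)

    model = MetaUAS(
        encoder_name="efficientnet-b4",   # assumed encoder matching encoder_channels above
        decoder_name="unet",
        encoder_depth=5,
        decoder_depth=5,
        num_alignment_layers=3,
        alignment_type="sa",
        fusion_policy="cat",
    )
    # model = safely_load_state_dict(model, "weights/metauas.ckpt")  # hypothetical checkpoint path
    model.eval()

    batch = {
        "query_image": read_image_as_tensor("query.png").unsqueeze(0),    # hypothetical query image
        "prompt_image": read_image_as_tensor("prompt.png").unsqueeze(0),  # hypothetical normal prompt
    }

    with torch.no_grad():
        anomaly_map = model(batch)  # (1, 1, H, W), values in [0, 1]

    # Overlay the score map on the query image and save the visualization.
    score_map = normalize(anomaly_map[0, 0].cpu().numpy())
    query_rgb = (batch["query_image"][0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    overlay = apply_ad_scoremap(query_rgb, score_map)
    Image.fromarray(overlay).save("anomaly_overlay.png")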