Spaces:
Running
Running
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : app.py
@Time : 2025/03/26 23:48:24
@Author : Bin-Bin Gao
@Email : csgaobb@gmail.com
@Homepage: https://csgaobb.github.io/
@Version : 1.0
@Desc : MetaUAS Demo with Gradio
'''
import os | |
import cv2 | |
import torch | |
import json | |
import shutil | |
import kornia as K | |
import numpy as np | |
import gradio as gr | |
from easydict import EasyDict | |
from argparse import ArgumentParser | |
from torchvision.transforms.functional import pil_to_tensor | |
from metauas import MetaUAS, set_random_seed, normalize, apply_ad_scoremap, safely_load_state_dict | |
# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hyper-parameters handed positionally to the MetaUAS constructor below.
random_seed = 1
encoder_name = 'efficientnet-b4'
decoder_name = 'unet'
encoder_depth = 5
decoder_depth = 5
num_alignment_layers = 3
alignment_type = 'sa'
fusion_policy = 'cat'

# Seed before construction; presumably this makes the initial weights
# reproducible -- confirm against metauas.set_random_seed.
set_random_seed(random_seed)

# One shared model instance; checkpoints are loaded into it per request.
metauas_model = MetaUAS(
    encoder_name, decoder_name,
    encoder_depth, decoder_depth,
    num_alignment_layers, alignment_type, fusion_policy,
)
def process_image(prompt_img, query_img, options):
    """Run MetaUAS on one prompt/query pair and build the demo outputs.

    Args:
        prompt_img: Prompt image; must support ``.convert('RGB')`` (the UI
            passes PIL images).
        query_img: Query image to inspect; same requirements as
            ``prompt_img``.
        options: Model choice coming from the UI. A value containing
            ``'model-512'`` selects the 512 checkpoint; anything else,
            including ``None`` (nothing selected), selects the 256 one.

    Returns:
        A 3-tuple ``(anomaly_map_vis, anomaly_map, score)``: the result of
        ``apply_ad_scoremap`` on the resized query image, the JET
        colour-mapped anomaly map converted to RGB, and the maximum
        predicted value formatted with three decimals.

    Raises:
        gr.Error: If the prompt or the query image is missing.
    """
    # Gradio passes None for an empty image component; fail with a readable
    # message instead of an AttributeError on `.convert`.
    if prompt_img is None or query_img is None:
        raise gr.Error("Please provide both a prompt image and a query image.")

    # `options` is None when no radio entry is selected; the membership test
    # would raise TypeError, so fall back to the 256 model in that case.
    if options is not None and 'model-512' in options:
        ckt_path = "weights/metauas-512.ckpt"
        img_size = 512
    else:
        ckt_path = 'weights/metauas-256.ckpt'
        img_size = 256

    # NOTE(review): the checkpoint is loaded into the shared module-level
    # model on every call -- confirm this is acceptable for concurrent use.
    model = safely_load_state_dict(metauas_model, ckt_path)
    model.to(device)
    model.eval()

    # Ensure both images are RGB, then scale pixel values to [0, 1] floats.
    prompt_img = prompt_img.convert('RGB')
    query_img = query_img.convert('RGB')
    query_img = pil_to_tensor(query_img).float() / 255.0
    prompt_img = pil_to_tensor(prompt_img).float() / 255.0

    # Resize both images to the square resolution of the chosen checkpoint.
    # NOTE(review): `return_transform` appears to be accepted only by older
    # kornia releases; `[0]` presumably keeps the resized tensor of the
    # returned pair -- confirm against the pinned kornia version.
    resize_trans = K.augmentation.Resize([img_size, img_size], return_transform=True)
    query_img = resize_trans(query_img)[0]
    prompt_img = resize_trans(prompt_img)[0]

    test_data = {
        "query_image": query_img.to(device),
        "prompt_image": prompt_img.to(device),
    }

    # Forward pass without gradient tracking.
    with torch.no_grad():
        predicted_masks = model(test_data)

    # The image-level score is the maximum over the whole prediction.
    anomaly_score = predicted_masks.max().item()

    # Back to 0-255 values in height x width x channel order for display.
    # `.cpu()` is a no-op for tensors already on the CPU, so a single code
    # path serves both devices.
    query_img = (test_data["query_image"][0] * 255).permute(1, 2, 0).cpu()

    # Replicate the single-channel map across three channels.
    anomaly_map = predicted_masks.squeeze().detach()[:, :, None].cpu().numpy().repeat(3, 2)
    anomaly_map_vis = apply_ad_scoremap(query_img, normalize(anomaly_map))

    # Assumes the raw scores lie in [0, 1] -- TODO confirm.
    anomaly_map = (anomaly_map * 255).astype(np.uint8)
    anomaly_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert to RGB for display.
    anomaly_map = cv2.cvtColor(anomaly_map, cv2.COLOR_BGR2RGB)

    return anomaly_map_vis, anomaly_map, f'{anomaly_score:.3f}'
# Example rows for the demo: [prompt image, query image, model choice].
# Every bundled example uses the 256 model.
examples = [
    [prompt_path, query_path, "model-256"]
    for prompt_path, query_path in (
        ("images/134.png", "images/000.png"),
        ("images/036.png", "images/024.png"),
        ("images/178.png", "images/003.png"),
    )
]
# Gradio interface layout.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center" style='margin-top: 30px;'>MetaUAS: Universal Anomaly Segmentation</h1>""")
    # A tag may carry only one `style` attribute (a duplicate is ignored by
    # HTML parsers), so both declarations live in a single attribute.
    gr.HTML("""<h1 align="center" style="font-size: 15px; margin-top: 40px;">just given ONE normal image prompt</h1>""")

    with gr.Row():
        # Left column: the two input images and the model choice.
        with gr.Column():
            with gr.Row():
                prompt_image = gr.Image(type="pil", label="Prompt Image")
                query_image = gr.Image(type="pil", label="Query Image")
            # Preselect a model so Submit never sends an empty choice.
            model_selector = gr.Radio(["model-256", "model-512"], value="model-256", label="Pre-models")
        # Right column: the visual results and the scalar score.
        with gr.Column():
            with gr.Row():
                anomaly_map_vis = gr.Image(type="pil", label="Anomaly Results")
                anomaly_map = gr.Image(type="pil", label="Anomaly Maps")
            anomaly_score = gr.Textbox(label="Anomaly Score")

    with gr.Row():
        submit_button = gr.Button("Submit", elem_id="submit-button")
        clear_button = gr.Button("Clear")

    # Set up the event handlers.
    submit_button.click(
        process_image,
        inputs=[prompt_image, query_image, model_selector],
        outputs=[anomaly_map_vis, anomaly_map, anomaly_score],
    )
    # Clear resets the three outputs only; the inputs are left untouched.
    clear_button.click(
        lambda: (None, None, None),
        outputs=[anomaly_map_vis, anomaly_map, anomaly_score],
    )

    # Add examples directly to the Blocks interface.
    gr.Examples(examples, inputs=[prompt_image, query_image, model_selector])

# Custom CSS for the submit button (targeted through its elem_id).
demo.css = """
#submit-button {
color: red !important; /* Font color */
background-color: LightSalmon !important; /* Background color */
border: none !important; /* Remove border */
padding: 10px 20px !important; /* Add padding */
border-radius: 5px !important; /* Rounded corners */
font-size: 16px !important; /* Font size */
cursor: pointer !important; /* Pointer cursor on hover */
}
#submit-button:hover {
background-color: Coral !important; /* Darker orange on hover */
}
"""

# Launch the demo.
demo.launch()