# Residue from the repository file viewer (not part of the program):
#   alkzar90's picture
#   add upscale function
#   d1c29b6
#   raw
#   history blame
#   1.36 kB
import gradio as gr
import torch
from torch import nn
from transformers import (SegformerFeatureExtractor,
SegformerForSemanticSegmentation)
# Local directory the segmentation model weights are loaded from.
MODEL_PATH = "./best_model_test/"

# Inference device; fixed to CPU.
device = torch.device("cpu")

# The preprocessing configuration is taken from the public
# "nvidia/segformer-b0-finetuned-ade-512-512" checkpoint, while the model
# weights themselves come from MODEL_PATH.
preprocessor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)

# ``eval()`` returns the module itself, so loading and switching to
# inference mode can be done in a single expression.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).eval()
def upscale_logits(logit_outputs, size):
    """Resize logits to the requested spatial size.

    Bilinear interpolation (``align_corners=False``) is used so that the
    lower-resolution model output can be brought back to the dimensions of
    the original input.

    Args:
        logit_outputs: Tensor of logits; bilinear mode requires a 4-D input
            (presumably batch x classes x height x width).
        size: Target output size, forwarded unchanged to ``interpolate``.

    Returns:
        The interpolated tensor.
    """
    resize = nn.functional.interpolate
    return resize(logit_outputs, size=size, mode="bilinear", align_corners=False)
def query_image(img):
    """Predict a per-pixel class mask at the resolution of the input image.

    The image is preprocessed, run through the model without gradient
    tracking, and the resulting logits are upscaled back to the input's
    height and width before taking the arg-max over the class dimension.

    Args:
        img: Input image exposing ``shape`` with height and width as the
            first two entries (presumably the H x W x C numpy array that
            ``gr.Image()`` delivers by default -- TODO confirm).

    Returns:
        A 2-D numpy array of predicted class indices for the single input
        image.
    """
    inputs = preprocessor(images=img, return_tensors="pt")

    # Spatial size to restore. The original code read ``image.shape[2]``,
    # which referenced an undefined name and, for an H x W x C array, would
    # have selected the channel count rather than (height, width).
    target_size = tuple(img.shape[:2])

    with torch.no_grad():
        preds = model(**inputs)["logits"]
        preds_upscale = upscale_logits(preds, target_size)

    # dim=1 is the class dimension of the logits; index 0 selects the only
    # image in the batch.
    predict_label = torch.argmax(preds_upscale, dim=1)
    return predict_label[0].cpu().numpy()
def visualize_instance_seg_mask(mask):
    """Return ``mask`` unchanged.

    Placeholder hook: no visualization is applied to the mask yet.
    """
    return mask
# Web UI: a single image input is passed to ``query_image`` and the value it
# returns is rendered through an image output component.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="SegFormer Model for rock glacier image segmentation",
)

# Start serving the interface.
demo.launch()