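"""Gradio demo: monocular surface normal estimation with the Omnidata DPT-Hybrid model.

The model is loaded via `torch.hub.load('alexsax/omnidata_models',
'surface_normal_dpt_hybrid_384')` and wrapped in a simple image-in / image-out
Gradio interface.
"""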
import gradio as gr
import torch
from torchvision import transforms
import PIL
from PIL import Image
from typing import Tuple


def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
    """Load the Omnidata DPT-Hybrid surface normal model from torch.hub."""
    image_size = 384
    model = torch.hub.load('alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384')
    model.to(device)
    model.eval()
    return model, image_size


def setup_transforms(image_size: int) -> transforms.Compose:
    """Resize, center-crop, and convert input images to tensors."""
    return transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])


# Load the model and preprocessing pipeline once at startup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, image_size = setup_model(device)
trans_totensor = setup_transforms(image_size)


def estimate_surface_normal(input_image: PIL.Image.Image) -> PIL.Image.Image:
    """Predict a surface normal map for a single RGB image."""
    with torch.no_grad():
        # Keep at most 3 channels (drops any alpha channel) and add a batch dimension.
        img_tensor = trans_totensor(input_image)[:3].unsqueeze(0).to(device)
        # Replicate grayscale inputs to 3 channels.
        if img_tensor.shape[1] == 1:
            img_tensor = img_tensor.repeat_interleave(3, 1)
        output = model(img_tensor).clamp(min=0, max=1)
        output_image = transforms.ToPILImage()(output[0])
    return output_image


iface = gr.Interface(
    fn=estimate_surface_normal,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Monocular Surface Normal Estimation: Omnidata DPT-Hybrid",
    description="Upload an image to estimate monocular surface normals. To run these models locally, use `torch.hub.load`. Code and examples are in our [GitHub](https://github.com/alexsax/omnidata_models) repository. More information and the paper are available on the project page: [Omnidata: A Scalable Pipeline for Making Multi-Task Mid-Level Vision Datasets from 3D Scans](https://omnidata.epfl.ch/).",
    examples=[
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test3.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test4.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test5.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test6.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test7.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test8.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test9.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test10.png?raw=true",
    ],
)


if __name__ == "__main__":
    iface.launch()
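
# For reference, a minimal sketch of local (non-Gradio) usage, assuming the same
# torch.hub entry point used above and a hypothetical input file 'example.png':
#
#   import torch
#   from PIL import Image
#   from torchvision import transforms
#
#   model = torch.hub.load('alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384').eval()
#   preprocess = transforms.Compose([
#       transforms.Resize(384),
#       transforms.CenterCrop(384),
#       transforms.ToTensor(),
#   ])
#   with torch.no_grad():
#       normals = model(preprocess(Image.open('example.png').convert('RGB')).unsqueeze(0))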