import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
import PIL
from PIL import Image
import os
from typing import Optional, Tuple


def setup_model(device: torch.device) -> Tuple[torch.nn.Module, int]:
    """Load the Omnidata DPT-Hybrid surface-normal model from torch.hub.

    Args:
        device: Device the model is moved to.

    Returns:
        ``(model, image_size)``: the model in eval mode on ``device``, and the
        side length (384, matching the ``_384`` checkpoint name) that input
        images are resized and center-cropped to before inference.
    """
    image_size = 384
    model = torch.hub.load('alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384')
    model.to(device)
    model.eval()
    return model, image_size


def setup_transforms(image_size: int) -> transforms.Compose:
    """Build the preprocessing pipeline applied to every input image.

    The shorter side is resized to ``image_size`` with bilinear interpolation,
    a centered ``image_size x image_size`` square is cropped, and the result
    is converted to a ``(C, H, W)`` float tensor (scaled to ``[0, 1]`` for
    8-bit images such as RGB).
    """
    return transforms.Compose([
        transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])


# Loaded once at import time so every request reuses the same model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, image_size = setup_model(device)
trans_totensor = setup_transforms(image_size)


def estimate_surface_normal(input_image: Optional[PIL.Image.Image]) -> Optional[PIL.Image.Image]:
    """Predict a surface-normal visualization for a single image.

    Args:
        input_image: PIL image in any mode. ``None`` (e.g. the Gradio form was
            submitted without an image) is passed straight through.

    Returns:
        The model prediction clamped to ``[0, 1]`` and converted to a PIL
        image, or ``None`` when no image was given. Assumes the model returns
        a ``(1, C, H, W)`` tensor with C in {1, 3} — TODO confirm against the
        hub model.
    """
    if input_image is None:
        return None
    # convert("RGB") maps every PIL mode (L, LA, P, RGBA, CMYK, ...) to exactly
    # three colour channels; slicing the tensor to [:3] and repeating a single
    # grayscale channel could not handle 2-channel, palette or CMYK inputs.
    rgb_image = input_image.convert("RGB")
    with torch.no_grad():
        img_tensor = trans_totensor(rgb_image).unsqueeze(0).to(device)
        output = model(img_tensor).clamp(min=0, max=1)
    return transforms.ToPILImage()(output[0])


iface = gr.Interface(
    fn=estimate_surface_normal,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Monocular Surface Normal Estimation: Omnidata DPT-Hybrid",
    description=(
        "Upload an image to estimate monocular surface normals. "
        "To use these models locally, you can use `torch.hub.load`. "
        "Code and examples in our [Github](https://github.com/alexsax/omnidata_models) repository. "
        "More information and the paper in the project page "
        "[Omnidata: A Scalable Pipeline for Making Multi-Task Mid-Level Vision Datasets from 3D Scans]"
        "(https://omnidata.epfl.ch/)."
    ),
    examples=[
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/test1_rgb.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test2.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test3.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test4.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test5.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test6.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test7.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test8.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test9.png?raw=true",
        "https://github.com/EPFL-VILAB/omnidata/blob/main/omnidata_tools/torch/assets/demo/test10.png?raw=true",
    ],
)

if __name__ == "__main__":
    iface.launch()