|
|
|
import numpy as np
|
|
import gradio as gr
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
|
|
|
|
|
|
|
|
# Single Hugging Face Hub checkpoint shared by the preprocessor and the depth model,
# so the two can never drift apart.
_CHECKPOINT = "Intel/dpt-beit-large-512"

# Preprocessor that turns a PIL image into model-ready tensors.
feature_extractor = DPTFeatureExtractor.from_pretrained(_CHECKPOINT)
# DPT depth-estimation network; loaded once at import time and reused per request.
model = DPTForDepthEstimation.from_pretrained(_CHECKPOINT)
|
|
|
|
|
|
def process_image(image):
    """
    Preprocesses an image, passes it through a model, and returns the formatted depth map as an image.

    The raw prediction is upsampled to the input resolution, scaled so that its
    maximum maps to 255, clamped to the 0-255 range and returned as an 8-bit
    grayscale image.

    Args:
        image (PIL.Image.Image | None): The input image. ``None`` (e.g. the user
            pressed submit without uploading anything) yields ``None``.

    Returns:
        PIL.Image.Image | None: The formatted depth map, with the same width and
        height as the input, or ``None`` when no image was given.
    """
    if image is None:
        # Nothing uploaded: leave the output empty instead of raising on image.size.
        return None

    # NOTE(review): the processor presumably expects 3-channel input; converting
    # guards against RGBA / palette / grayscale uploads and is a plain copy for RGB.
    image = image.convert("RGB")

    encoding = feature_extractor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**encoding)
    predicted_depth = outputs.predicted_depth

    # Resize the prediction to the input resolution. PIL's ``size`` is
    # (width, height) while ``interpolate`` wants (height, width); bicubic mode
    # needs a 4-D (N, C, H, W) tensor, hence the added channel axis.
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    # Select the single (batch 0, channel 0) map explicitly: ``squeeze()`` would
    # also drop a spatial axis of length 1 for 1-pixel-wide or -tall inputs.
    depth = prediction[0, 0].cpu().numpy()

    # Scale so the largest value maps to 255; skip scaling when the maximum is
    # not positive to avoid a division by zero (NaNs in the output).
    peak = float(np.max(depth))
    if peak > 0:
        depth = depth * 255 / peak
    # Bicubic upsampling can overshoot the input range (e.g. slightly below 0);
    # out-of-range float -> uint8 casts are platform-dependent (typically wrap),
    # so clamp to the representable range first.
    formatted = np.clip(depth, 0, 255).astype("uint8")

    return Image.fromarray(formatted)
|
|
|
|
|
|
|
|
# Input and output components; both exchange PIL images with process_image.
image = gr.Image(type="pil", label="Image")
answer = gr.Image(type="pil", label="Depth Map")

# One inner list per example, one value per input component.
# NOTE(review): these paths are assumed to ship alongside the app — confirm.
examples = [
    ["cat.jpg"],
    ["dog.jpg"],
    ["bird.jpg"],
]

title = "Zero Shot Depth Estimation"
description = "Gradio Demo for the Intel/DPT Beit-Large-512 Depth Estimation model. This model can estimate the depth of objects in images. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/pdf/2307.14460' target='_blank'>MiDaS v3.1 – A Model Zoo for Robust Monocular Relative Depth Estimation</a> | <a href='https://huggingface.co/Intel/dpt-beit-large-512' target='_blank'>Model Page</a></p>"

interface = gr.Interface(
    fn=process_image,
    inputs=[image],
    outputs=answer,
    examples=examples,
    title=title,
    description=description,
    article=article,
    theme="Soft",
    # NOTE(review): newer Gradio releases rename this to flagging_mode — confirm
    # against the installed version before upgrading.
    allow_flagging="never",
)

# Start the server only when run as a script, so importing this module (tests,
# tooling, reload mode) does not launch a web server as a side effect.
if __name__ == "__main__":
    interface.launch(debug=False)
|
|
|