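"""Gradio app comparing OWLv2 and Grounding DINO for zero-shot object detection.

The same image and comma-separated text queries are run through both models,
and their detections are shown side by side in AnnotatedImage components.
"""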
import spaces
import torch
import gradio as gr
from transformers import (
    Owlv2Processor,
    Owlv2ForObjectDetection,
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
)

# Fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)

dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    if model == "dino":
        # Grounding DINO expects a single prompt with queries separated by dots
        queries = ". ".join(text_queries) + "."
        height, width = img.shape[:2]
        target_sizes = [(height, width)]
        inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = dino_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = dino_processor.post_process_grounded_object_detection(
            outputs=outputs,
            input_ids=inputs.input_ids,
            box_threshold=score_threshold,
            target_sizes=target_sizes,
        )
    elif model == "owl":
        # The OWLv2 processor pads images to squares, so the target size
        # is the longer side in both dimensions
        size = max(img.shape[:2])
        target_sizes = torch.tensor([[size, size]])
        inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owl_model(**inputs)
        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)

    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        box = [int(i) for i in box.tolist()]
        if score < score_threshold:
            continue
        if model == "owl":
            # OWLv2 returns label indices into the query list
            label = text_queries[label.item()]
            result_labels.append((box, label))
        elif model == "dino":
            # Grounding DINO returns label strings; skip empty matches
            if label != "":
                result_labels.append((box, label))
    return result_labels
def query_image(img, text_queries, owl_threshold, dino_threshold):
    text_queries = [query.strip() for query in text_queries.split(",")]
    owl_output = infer(img, text_queries, owl_threshold, "owl")
    dino_output = infer(img, text_queries, dino_threshold, "dino")
    return (img, owl_output), (img, dino_output)
owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
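# AnnotatedImage renders (image, [(bounding_box, label), ...]) pairs,
# matching the (box, label) tuples returned by infer()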
owl_output = gr.AnnotatedImage(label="OWL Output")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
    outputs=[owl_output, dino_output],
    title="OWLv2 ⚔ Grounding DINO",
    description=(
        "Compare two state-of-the-art zero-shot object detection models, "
        "[OWLv2](https://huggingface.co/google/owlv2-base-patch16) and "
        "[Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base), in this Space. "
        "Enter an image and the objects you want to find, separated by commas, "
        "or try one of the examples. Adjust each model's threshold to filter out "
        "low-confidence predictions."
    ),
    examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]],
)
demo.launch(debug=True)