Spaces: Runtime error
Update app.py: remove DINO sources

This commit comments out everything Grounding DINO related (import, processor, model, threshold slider, and output panel), leaving an OWLv2-only demo.
app.py (CHANGED)

@@ -1,5 +1,6 @@
 import spaces
-from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
+# from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
+from transformers import Owlv2Processor, Owlv2ForObjectDetection
 import torch
 import gradio as gr
 
@@ -8,8 +9,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to("cuda")
 owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
 
-dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
-dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
+# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
+# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
 
 @spaces.GPU
 def infer(img, text_queries, score_threshold, model):
@@ -57,26 +58,30 @@ def infer(img, text_queries, score_threshold, model):
         result_labels.append((box, label))
     return result_labels
 
-def query_image(img, text_queries, owl_threshold, dino_threshold):
+# def query_image(img, text_queries, owl_threshold, dino_threshold):
+def query_image(img, text_queries, owl_threshold):
     text_queries = text_queries
     text_queries = text_queries.split(",")
     owl_output = infer(img, text_queries, owl_threshold, "owl")
-    dino_output = infer(img, text_queries, dino_threshold, "dino")
+    # dino_output = infer(img, text_queries, dino_threshold, "dino")
 
 
-    return (img, owl_output), (img, dino_output)
+    # return (img, owl_output), (img, dino_output)
+    return (img, owl_output)
 
 
 owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
-dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
+# dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
 owl_output = gr.AnnotatedImage(label="OWL Output")
-dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
+# dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
 demo = gr.Interface(
     query_image,
-    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
-    outputs=[owl_output, dino_output],
-    title=…,
-    description=…,
+    # inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
+    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold],
+    # outputs=[owl_output, dino_output],
+    outputs=[owl_output],
+    title="OWLv2 Demo",
+    description="Compare two state-of-the-art zero-shot object detection models [OWLv2](https://huggingface.co/google/owlv2-base-patch16) . Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in each model.",
     examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
 )
 demo.launch(debug=True)
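A loose end worth flagging: the commit trims the interface to three inputs, but `examples` still passes four values per row (image, labels, OWL threshold, DINO threshold), and the new description still speaks of comparing two models. Gradio validates example rows against the declared inputs, so the leftover DINO threshold plausibly explains the Space's "Runtime error" status. A minimal follow-up fix (a suggestion, not part of this commit) would drop the trailing 0.12 from each row:

    examples=[["./bee.jpg", "bee, flower", 0.16], ["./cats.png", "cat, fishnet", 0.16]]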
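The last hunk elides most of the body of `infer`. For orientation, here is a minimal sketch of what the OWLv2 branch of such a function typically looks like with the `transformers` API. Only `owl_processor`, `owl_model`, and the `(box, label)` result shape come from the file itself; the rest is an assumed reconstruction, not the Space's actual code:

    # Sketch of the elided infer() body, assuming the standard OWLv2 API in
    # transformers; NOT the Space's actual implementation. Only the OWLv2
    # branch is shown; before this commit the real file also had a "dino" branch.
    import torch

    def infer(img, text_queries, score_threshold, model):
        if model == "owl":
            inputs = owl_processor(text=[text_queries], images=img, return_tensors="pt").to("cuda")
            with torch.no_grad():
                outputs = owl_model(**inputs)
            # Rescale boxes to the original image size; gr.Image hands infer a
            # numpy array by default, so shape[:2] is (height, width).
            target_sizes = torch.tensor([img.shape[:2]])
            results = owl_processor.post_process_object_detection(
                outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
            )[0]
        result_labels = []
        for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
            box = [int(i) for i in box.tolist()]
            label = text_queries[label.item()]  # map the class index back to the query text
            result_labels.append((box, label))
        return result_labels

The `(box, label)` pairs are exactly what `gr.AnnotatedImage` expects when `query_image` returns `(img, owl_output)`.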