Gaejoon commited on
Commit
78d8c4b
·
verified ·
1 Parent(s): 3ca4fe4

Update app.py

Browse files

Remove Grounding DINO sources

Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import spaces
2
- from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
 
3
  import torch
4
  import gradio as gr
5
 
@@ -8,8 +9,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to("cuda")
9
  owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
10
 
11
- dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
12
- dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
13
 
14
  @spaces.GPU
15
  def infer(img, text_queries, score_threshold, model):
@@ -57,26 +58,30 @@ def infer(img, text_queries, score_threshold, model):
57
  result_labels.append((box, label))
58
  return result_labels
59
 
60
- def query_image(img, text_queries, owl_threshold, dino_threshold):
 
61
  text_queries = text_queries
62
  text_queries = text_queries.split(",")
63
  owl_output = infer(img, text_queries, owl_threshold, "owl")
64
- dino_output = infer(img, text_queries, dino_threshold, "dino")
65
 
66
 
67
- return (img, owl_output), (img, dino_output)
 
68
 
69
 
70
  owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
71
- dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
72
  owl_output = gr.AnnotatedImage(label="OWL Output")
73
- dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
74
  demo = gr.Interface(
75
  query_image,
76
- inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
77
- outputs=[owl_output, dino_output],
78
- title="OWLv2 ⚔ Grounding DINO",
79
- description="Compare two state-of-the-art zero-shot object detection models [OWLv2](https://huggingface.co/google/owlv2-base-patch16) and [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) in this Space. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in each model.",
 
 
80
  examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
81
  )
82
  demo.launch(debug=True)
 
1
  import spaces
2
+ # from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection
3
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
4
  import torch
5
  import gradio as gr
6
 
 
9
  owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to("cuda")
10
  owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
11
 
12
+ # dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
13
+ # dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
14
 
15
  @spaces.GPU
16
  def infer(img, text_queries, score_threshold, model):
 
58
  result_labels.append((box, label))
59
  return result_labels
60
 
61
def query_image(img, text_queries, owl_threshold):
    """Run OWLv2 zero-shot detection on *img* for comma-separated labels.

    Args:
        img: input image as provided by gr.Image.
        text_queries: comma-separated candidate labels, e.g. "bee, flower".
        owl_threshold: confidence threshold for filtering OWLv2 detections.

    Returns:
        An (image, annotations) pair for gr.AnnotatedImage, where
        annotations is the list of (box, label) tuples from infer().
    """
    # Split the single textbox string into individual label queries.
    labels = text_queries.split(",")
    owl_output = infer(img, labels, owl_threshold, "owl")
    return (img, owl_output)
71
 
72
 
73
# UI components: one confidence-threshold slider and one annotated-image
# panel for the OWLv2 detections.
owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")

demo = gr.Interface(
    query_image,
    inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold],
    outputs=[owl_output],
    title="OWLv2 Demo",
    description="Zero-shot object detection with [OWLv2](https://huggingface.co/google/owlv2-base-patch16). Enter an image and the objects you want to find separated by commas, or try one of the examples. Play with the threshold to filter out low-confidence predictions.",
    # Each example row must match the 3 inputs (image, labels, threshold);
    # the stale 4th value (the removed Grounding DINO threshold) is dropped.
    examples=[["./bee.jpg", "bee, flower", 0.16], ["./cats.png", "cat, fishnet", 0.16]],
)
demo.launch(debug=True)