pg56714 commited on
Commit
7415b0a
·
verified ·
1 Parent(s): 7b4703a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -16
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  import numpy as np
5
  import supervision as sv
6
  import torch
 
7
  from inference.models import YOLOWorld
8
 
9
  from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
@@ -42,15 +43,39 @@ MASK_ANNOTATOR = sv.MaskAnnotator()
42
  LABEL_ANNOTATOR = sv.LabelAnnotator()
43
 
44
 
45
- def detect(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  image: np.ndarray,
47
- query: str,
48
  confidence_threshold: float,
49
  nms_threshold: float,
50
  with_confidence: bool = True,
51
  ) -> np.ndarray:
52
  # Preparation.
53
- categories = [category.strip() for category in query.split(",")]
54
  yolo_world.set_classes(categories)
55
  # print("categories:", categories)
56
 
@@ -72,17 +97,13 @@ def detect(
72
 
73
  # Annotation
74
  output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
75
- labels = [
76
- (
77
- f"{categories[class_id]}: {confidence:.3f}"
78
- if with_confidence
79
- else f"{categories[class_id]}"
80
- )
81
- for class_id, confidence in zip(detections.class_id, detections.confidence)
82
- ]
83
- output_image = MASK_ANNOTATOR.annotate(output_image, detections)
84
- output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
85
- output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
86
  return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
87
 
88
 
@@ -140,7 +161,7 @@ with gr.Blocks() as demo:
140
  yolo_world_output_image_component = gr.Image(type="numpy", label="Output image")
141
  submit_button_component = gr.Button(value="Submit", scale=1, variant="primary")
142
  gr.Examples(
143
- fn=detect,
144
  examples=[
145
  [
146
  os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
@@ -165,7 +186,7 @@ with gr.Blocks() as demo:
165
  )
166
 
167
  submit_button_component.click(
168
- fn=detect,
169
  inputs=[
170
  input_image_component,
171
  image_categories_text_component,
 
4
  import numpy as np
5
  import supervision as sv
6
  import torch
7
+ from typing import List
8
  from inference.models import YOLOWorld
9
 
10
  from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
 
43
  LABEL_ANNOTATOR = sv.LabelAnnotator()
44
 
45
 
46
+ def process_categories(categories: str) -> List[str]:
47
+ return [category.strip() for category in categories.split(",")]
48
+
49
+
50
+ def annotate_image(
51
+ input_image: np.ndarray,
52
+ detections: sv.Detections,
53
+ categories: List[str],
54
+ with_confidence: bool = False,
55
+ ) -> np.ndarray:
56
+ labels = [
57
+ (
58
+ f"{categories[class_id]}: {confidence:.3f}"
59
+ if with_confidence
60
+ else f"{categories[class_id]}"
61
+ )
62
+ for class_id, confidence in zip(detections.class_id, detections.confidence)
63
+ ]
64
+ output_image = MASK_ANNOTATOR.annotate(input_image, detections)
65
+ output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
66
+ output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
67
+ return output_image
68
+
69
+
70
+ def process_image(
71
  image: np.ndarray,
72
+ categories: str,
73
  confidence_threshold: float,
74
  nms_threshold: float,
75
  with_confidence: bool = True,
76
  ) -> np.ndarray:
77
  # Preparation.
78
+ categories = process_categories(categories)
79
  yolo_world.set_classes(categories)
80
  # print("categories:", categories)
81
 
 
97
 
98
  # Annotation
99
  output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
100
+ output_image = annotate_image(
101
+ input_image=output_image,
102
+ detections=detections,
103
+ categories=categories,
104
+ with_confidence=with_confidence,
105
+ )
106
+
 
 
 
 
107
  return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
108
 
109
 
 
161
  yolo_world_output_image_component = gr.Image(type="numpy", label="Output image")
162
  submit_button_component = gr.Button(value="Submit", scale=1, variant="primary")
163
  gr.Examples(
164
+ # fn=process_image,
165
  examples=[
166
  [
167
  os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
 
186
  )
187
 
188
  submit_button_component.click(
189
+ fn=process_image,
190
  inputs=[
191
  input_image_component,
192
  image_categories_text_component,