Spaces:
Sleeping
Sleeping
ExStella
commited on
Commit
·
4255339
1
Parent(s):
1de5900
app.py added both detection
Browse files
app.py
CHANGED
@@ -19,17 +19,32 @@ def load_model(repo_id):
|
|
19 |
traffic_cones_model = load_model("ExStella/Traffic-cones")
|
20 |
license_plate_model = load_model("ExStella/License-plate")
|
21 |
|
22 |
-
# Function to process an image with the selected model
|
23 |
def process_image(img, model_type):
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
result = model.predict(img, conf=0.5, iou=0.6)
|
26 |
img_bgr = result[0].plot()
|
27 |
-
output_image = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB
|
28 |
return output_image
|
29 |
|
30 |
-
# Function to process a video with the selected model
|
31 |
def process_video(video_path, model_type):
|
32 |
-
model = traffic_cones_model if model_type == "Traffic Cones" else license_plate_model
|
33 |
temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
34 |
cap = cv2.VideoCapture(video_path)
|
35 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
@@ -45,9 +60,24 @@ def process_video(video_path, model_type):
|
|
45 |
if not ret:
|
46 |
break
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# Write the processed frame
|
53 |
out.write(frame)
|
@@ -59,6 +89,9 @@ def process_video(video_path, model_type):
|
|
59 |
|
60 |
# Gradio prediction function
|
61 |
def predict(input_file, model_type):
|
|
|
|
|
|
|
62 |
if input_file.name.endswith(('.jpg', '.jpeg', '.png')):
|
63 |
# Image input
|
64 |
img = Image.open(input_file.name)
|
@@ -76,12 +109,12 @@ gr.Interface(
|
|
76 |
fn=predict,
|
77 |
inputs=[
|
78 |
gr.File(label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)"),
|
79 |
-
gr.Radio(["Traffic Cones", "License Plate"], label="Choose Detection Type"),
|
80 |
],
|
81 |
outputs=[
|
82 |
gr.Image(type="pil", label="Processed Image"), # Output for images
|
83 |
gr.Video(label="Processed Video"), # Output for videos
|
84 |
],
|
85 |
title="Object Detection for Traffic Cones and License Plates",
|
86 |
-
description="Upload an image or video to perform object detection. Select between Traffic Cones
|
87 |
-
).launch(share=True)
|
|
|
19 |
traffic_cones_model = load_model("ExStella/Traffic-cones")
|
20 |
license_plate_model = load_model("ExStella/License-plate")
|
21 |
|
22 |
+
# Function to process an image with the selected model(s)
|
23 |
def process_image(img, model_type):
|
24 |
+
if model_type == "Traffic Cones":
|
25 |
+
model = traffic_cones_model
|
26 |
+
elif model_type == "License Plate":
|
27 |
+
model = license_plate_model
|
28 |
+
elif model_type == "Both":
|
29 |
+
# Process with both models
|
30 |
+
result1 = traffic_cones_model.predict(img, conf=0.5, iou=0.6)
|
31 |
+
result2 = license_plate_model.predict(img, conf=0.5, iou=0.6)
|
32 |
+
img_bgr1 = result1[0].plot()
|
33 |
+
img_bgr2 = result2[0].plot()
|
34 |
+
img_bgr_combined = cv2.addWeighted(img_bgr1, 0.5, img_bgr2, 0.5, 0) # Combine results
|
35 |
+
output_image = Image.fromarray(img_bgr_combined[..., ::-1]) # Convert BGR to RGB
|
36 |
+
return output_image
|
37 |
+
else:
|
38 |
+
raise ValueError("Invalid detection type.")
|
39 |
+
|
40 |
+
# Process for single model type
|
41 |
result = model.predict(img, conf=0.5, iou=0.6)
|
42 |
img_bgr = result[0].plot()
|
43 |
+
output_image = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB
|
44 |
return output_image
|
45 |
|
46 |
+
# Function to process a video with the selected model(s)
|
47 |
def process_video(video_path, model_type):
|
|
|
48 |
temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
49 |
cap = cv2.VideoCapture(video_path)
|
50 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
60 |
if not ret:
|
61 |
break
|
62 |
|
63 |
+
if model_type == "Traffic Cones":
|
64 |
+
model = traffic_cones_model
|
65 |
+
elif model_type == "License Plate":
|
66 |
+
model = license_plate_model
|
67 |
+
elif model_type == "Both":
|
68 |
+
# Process with both models
|
69 |
+
results1 = traffic_cones_model.predict(frame, conf=0.5, iou=0.6)
|
70 |
+
results2 = license_plate_model.predict(frame, conf=0.5, iou=0.6)
|
71 |
+
frame1 = results1[0].plot()
|
72 |
+
frame2 = results2[0].plot()
|
73 |
+
frame = cv2.addWeighted(frame1, 0.5, frame2, 0.5, 0) # Combine results
|
74 |
+
else:
|
75 |
+
raise ValueError("Invalid detection type.")
|
76 |
+
|
77 |
+
# Annotate frame for single model type
|
78 |
+
if model_type in ["Traffic Cones", "License Plate"]:
|
79 |
+
results = model.predict(frame, conf=0.5, iou=0.6)
|
80 |
+
frame = results[0].plot() # Annotated frame
|
81 |
|
82 |
# Write the processed frame
|
83 |
out.write(frame)
|
|
|
89 |
|
90 |
# Gradio prediction function
|
91 |
def predict(input_file, model_type):
|
92 |
+
if not model_type:
|
93 |
+
raise ValueError("Please select a detection type before submitting.")
|
94 |
+
|
95 |
if input_file.name.endswith(('.jpg', '.jpeg', '.png')):
|
96 |
# Image input
|
97 |
img = Image.open(input_file.name)
|
|
|
109 |
fn=predict,
|
110 |
inputs=[
|
111 |
gr.File(label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)"),
|
112 |
+
gr.Radio(["Traffic Cones", "License Plate", "Both"], label="Choose Detection Type"),
|
113 |
],
|
114 |
outputs=[
|
115 |
gr.Image(type="pil", label="Processed Image"), # Output for images
|
116 |
gr.Video(label="Processed Video"), # Output for videos
|
117 |
],
|
118 |
title="Object Detection for Traffic Cones and License Plates",
|
119 |
+
description="Upload an image or video to perform object detection. Select between Traffic Cones, License Plates, or Both detection types.",
|
120 |
+
).launch(share=True)
|