ExStella commited on
Commit
4255339
·
1 Parent(s): 1de5900

app.py added both detection

Browse files
Files changed (1) hide show
  1. app.py +44 -11
app.py CHANGED
@@ -19,17 +19,32 @@ def load_model(repo_id):
19
  traffic_cones_model = load_model("ExStella/Traffic-cones")
20
  license_plate_model = load_model("ExStella/License-plate")
21
 
22
- # Function to process an image with the selected model
23
  def process_image(img, model_type):
24
- model = traffic_cones_model if model_type == "Traffic Cones" else license_plate_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  result = model.predict(img, conf=0.5, iou=0.6)
26
  img_bgr = result[0].plot()
27
- output_image = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB for PIL
28
  return output_image
29
 
30
- # Function to process a video with the selected model
31
  def process_video(video_path, model_type):
32
- model = traffic_cones_model if model_type == "Traffic Cones" else license_plate_model
33
  temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
34
  cap = cv2.VideoCapture(video_path)
35
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -45,9 +60,24 @@ def process_video(video_path, model_type):
45
  if not ret:
46
  break
47
 
48
- # Perform object detection
49
- results = model.predict(frame, conf=0.5, iou=0.6)
50
- frame = results[0].plot() # Annotated frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Write the processed frame
53
  out.write(frame)
@@ -59,6 +89,9 @@ def process_video(video_path, model_type):
59
 
60
  # Gradio prediction function
61
  def predict(input_file, model_type):
 
 
 
62
  if input_file.name.endswith(('.jpg', '.jpeg', '.png')):
63
  # Image input
64
  img = Image.open(input_file.name)
@@ -76,12 +109,12 @@ gr.Interface(
76
  fn=predict,
77
  inputs=[
78
  gr.File(label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)"),
79
- gr.Radio(["Traffic Cones", "License Plate"], label="Choose Detection Type"),
80
  ],
81
  outputs=[
82
  gr.Image(type="pil", label="Processed Image"), # Output for images
83
  gr.Video(label="Processed Video"), # Output for videos
84
  ],
85
  title="Object Detection for Traffic Cones and License Plates",
86
- description="Upload an image or video to perform object detection. Select between Traffic Cones or License Plates detection.",
87
- ).launch(share=True)
 
19
  traffic_cones_model = load_model("ExStella/Traffic-cones")
20
  license_plate_model = load_model("ExStella/License-plate")
21
 
22
+ # Function to process an image with the selected model(s)
23
  def process_image(img, model_type):
24
+ if model_type == "Traffic Cones":
25
+ model = traffic_cones_model
26
+ elif model_type == "License Plate":
27
+ model = license_plate_model
28
+ elif model_type == "Both":
29
+ # Process with both models
30
+ result1 = traffic_cones_model.predict(img, conf=0.5, iou=0.6)
31
+ result2 = license_plate_model.predict(img, conf=0.5, iou=0.6)
32
+ img_bgr1 = result1[0].plot()
33
+ img_bgr2 = result2[0].plot()
34
+ img_bgr_combined = cv2.addWeighted(img_bgr1, 0.5, img_bgr2, 0.5, 0) # Combine results
35
+ output_image = Image.fromarray(img_bgr_combined[..., ::-1]) # Convert BGR to RGB
36
+ return output_image
37
+ else:
38
+ raise ValueError("Invalid detection type.")
39
+
40
+ # Process for single model type
41
  result = model.predict(img, conf=0.5, iou=0.6)
42
  img_bgr = result[0].plot()
43
+ output_image = Image.fromarray(img_bgr[..., ::-1]) # Convert BGR to RGB
44
  return output_image
45
 
46
+ # Function to process a video with the selected model(s)
47
  def process_video(video_path, model_type):
 
48
  temp_output = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
49
  cap = cv2.VideoCapture(video_path)
50
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
60
  if not ret:
61
  break
62
 
63
+ if model_type == "Traffic Cones":
64
+ model = traffic_cones_model
65
+ elif model_type == "License Plate":
66
+ model = license_plate_model
67
+ elif model_type == "Both":
68
+ # Process with both models
69
+ results1 = traffic_cones_model.predict(frame, conf=0.5, iou=0.6)
70
+ results2 = license_plate_model.predict(frame, conf=0.5, iou=0.6)
71
+ frame1 = results1[0].plot()
72
+ frame2 = results2[0].plot()
73
+ frame = cv2.addWeighted(frame1, 0.5, frame2, 0.5, 0) # Combine results
74
+ else:
75
+ raise ValueError("Invalid detection type.")
76
+
77
+ # Annotate frame for single model type
78
+ if model_type in ["Traffic Cones", "License Plate"]:
79
+ results = model.predict(frame, conf=0.5, iou=0.6)
80
+ frame = results[0].plot() # Annotated frame
81
 
82
  # Write the processed frame
83
  out.write(frame)
 
89
 
90
  # Gradio prediction function
91
  def predict(input_file, model_type):
92
+ if not model_type:
93
+ raise ValueError("Please select a detection type before submitting.")
94
+
95
  if input_file.name.endswith(('.jpg', '.jpeg', '.png')):
96
  # Image input
97
  img = Image.open(input_file.name)
 
109
  fn=predict,
110
  inputs=[
111
  gr.File(label="Upload an image or video (JPG, PNG, MP4, AVI, etc.)"),
112
+ gr.Radio(["Traffic Cones", "License Plate", "Both"], label="Choose Detection Type"),
113
  ],
114
  outputs=[
115
  gr.Image(type="pil", label="Processed Image"), # Output for images
116
  gr.Video(label="Processed Video"), # Output for videos
117
  ],
118
  title="Object Detection for Traffic Cones and License Plates",
119
+ description="Upload an image or video to perform object detection. Select between Traffic Cones, License Plates, or Both detection types.",
120
+ ).launch(share=True)