arsath-sm commited on
Commit
1c1fcf9
·
verified ·
1 Parent(s): 5c7d4a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -159
app.py CHANGED
@@ -4,133 +4,131 @@ import numpy as np
4
  import onnxruntime as ort
5
  from PIL import Image
6
  import tempfile
 
 
7
 
8
- # Define class labels
9
- CLASSES = {
10
- 0: "Vehicle",
11
- 1: "License_Plate"
12
- }
13
-
14
- # Load the ONNX model
15
  @st.cache_resource
16
- def load_model():
17
- return ort.InferenceSession("model.onnx")
18
-
19
- ort_session = load_model()
20
-
21
- def preprocess_image(image, target_size=(640, 640)):
22
- if isinstance(image, Image.Image):
23
- image = np.array(image)
24
-
25
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
26
- original_shape = image.shape[:2]
27
- image = cv2.resize(image, target_size)
28
- image = image.astype(np.float32) / 255.0
29
- image = np.transpose(image, (2, 0, 1))
30
- image = np.expand_dims(image, axis=0)
31
- return image, original_shape
32
-
33
- def postprocess_results(output, original_shape, confidence_threshold=0.25, iou_threshold=0.45):
34
- if isinstance(output, (list, tuple)):
35
- predictions = output[0]
36
- elif isinstance(output, np.ndarray):
37
- predictions = output
38
- else:
39
- raise ValueError(f"Unexpected output type: {type(output)}")
40
 
41
- if len(predictions.shape) == 4:
42
- predictions = predictions.squeeze((0, 1))
43
- elif len(predictions.shape) == 3:
44
- predictions = predictions.squeeze(0)
45
 
46
- # Extract boxes, scores, and class_ids
47
- boxes = predictions[:, :4]
48
- scores = predictions[:, 4]
49
- class_ids = predictions[:, 5]
50
 
51
- # Filter by confidence
52
- mask = scores > confidence_threshold
53
- boxes = boxes[mask]
54
- scores = scores[mask]
55
- class_ids = class_ids[mask]
56
 
57
- # Convert boxes from [x, y, w, h] to [x1, y1, x2, y2]
58
- boxes[:, 2:] += boxes[:, :2]
59
 
60
- # Scale boxes to original image size
61
- h, w = original_shape
62
- boxes[:, [0, 2]] *= w
63
- boxes[:, [1, 3]] *= h
64
 
65
- # Apply NMS for each class separately
66
- results = []
67
- for class_id in np.unique(class_ids):
68
- class_mask = class_ids == class_id
69
- class_boxes = boxes[class_mask]
70
- class_scores = scores[class_mask]
71
-
72
- indices = cv2.dnn.NMSBoxes(
73
- class_boxes.tolist(),
74
- class_scores.tolist(),
75
- confidence_threshold,
76
- iou_threshold
77
- )
78
-
79
- for i in indices:
80
- box = class_boxes[i]
81
- score = class_scores[i]
82
- x1, y1, x2, y2 = map(int, box)
83
- results.append((x1, y1, x2, y2, float(score), int(class_id)))
84
-
85
- return results
86
 
87
- def process_image(image):
88
- orig_image = image.copy()
89
- processed_image, original_shape = preprocess_image(image)
90
-
91
- # Run inference
92
- inputs = {ort_session.get_inputs()[0].name: processed_image}
93
- outputs = ort_session.run(None, inputs)
94
-
95
- results = postprocess_results(outputs, original_shape)
96
 
97
- # Draw bounding boxes on the image
98
- for x1, y1, x2, y2, score, class_id in results:
99
- # Draw rectangle with white color
100
- cv2.rectangle(orig_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
101
-
102
- # Get class name
103
- class_name = CLASSES.get(class_id, f"Class_{class_id}")
104
- label = f"{class_name}: {score:.2f}"
105
-
106
- # Add label background and text
107
- (text_width, text_height), _ = cv2.getTextSize(
108
- label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
109
- )
110
-
111
- # Draw black background for text
112
- cv2.rectangle(
113
- orig_image,
114
- (x1, y1 - text_height - 4),
115
- (x1 + text_width, y1),
116
- (0, 0, 0),
117
- -1
118
- )
119
-
120
- # Draw white text
121
- cv2.putText(
122
- orig_image,
123
- label,
124
- (x1, y1 - 5),
125
- cv2.FONT_HERSHEY_SIMPLEX,
126
- 0.6,
127
- (255, 255, 255),
128
- 1
129
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- return cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
132
 
133
- def process_video(video_path):
134
  cap = cv2.VideoCapture(video_path)
135
 
136
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -138,12 +136,10 @@ def process_video(video_path):
138
  fps = int(cap.get(cv2.CAP_PROP_FPS))
139
 
140
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
141
- out = cv2.VideoWriter(
142
- temp_file.name,
143
- cv2.VideoWriter_fourcc(*'mp4v'),
144
- fps,
145
- (width, height)
146
- )
147
 
148
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
149
  progress_bar = st.progress(0)
@@ -154,8 +150,8 @@ def process_video(video_path):
154
  if not ret:
155
  break
156
 
157
- processed_frame = process_image(frame)
158
- out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
159
 
160
  frame_count += 1
161
  progress_bar.progress(frame_count / total_frames)
@@ -167,43 +163,37 @@ def process_video(video_path):
167
  return temp_file.name
168
 
169
  # Streamlit UI
170
- st.title("Vehicle and License Plate Detection")
171
-
172
- # Add confidence threshold slider
173
- confidence_threshold = st.slider(
174
- "Confidence Threshold",
175
- min_value=0.0,
176
- max_value=1.0,
177
- value=0.25,
178
- step=0.05
179
- )
180
 
181
- uploaded_file = st.file_uploader(
182
- "Choose an image or video file",
183
- type=["jpg", "jpeg", "png", "mp4"]
184
- )
185
-
186
- if uploaded_file is not None:
187
- file_type = uploaded_file.type.split('/')[0]
188
-
189
- if file_type == "image":
190
- image = Image.open(uploaded_file)
191
- st.image(image, caption="Uploaded Image", use_column_width=True)
192
-
193
- if st.button("Detect Objects"):
194
- with st.spinner("Processing image..."):
195
- processed_image = process_image(np.array(image))
196
- st.image(processed_image, caption="Processed Image", use_column_width=True)
197
 
198
- elif file_type == "video":
199
- tfile = tempfile.NamedTemporaryFile(delete=False)
200
- tfile.write(uploaded_file.read())
 
201
 
202
- st.video(tfile.name)
 
 
 
 
 
 
 
 
 
 
203
 
204
- if st.button("Detect Objects"):
205
- with st.spinner("Processing video..."):
206
- processed_video = process_video(tfile.name)
207
- st.video(processed_video)
208
-
209
- st.write("Upload an image or video to detect vehicles and license plates.")
 
 
 
 
 
 
 
4
  import onnxruntime as ort
5
  from PIL import Image
6
  import tempfile
7
+ import torch
8
+ from ultralytics import YOLO
9
 
10
+ # Load models
 
 
 
 
 
 
11
  @st.cache_resource
12
+ def load_models():
13
+ license_plate_detector = YOLO('license_plate_detector.pt')
14
+ vehicle_detector = YOLO('yolov8n.pt')
15
+ ort_session = ort.InferenceSession("model.onnx")
16
+ return license_plate_detector, vehicle_detector, ort_session
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
19
+ x1, y1 = top_left
20
+ x2, y2 = bottom_right
 
21
 
22
+ # Draw corner lines
23
+ cv2.line(img, (x1, y1), (x1, y1 + line_length_y), color, thickness) # top-left
24
+ cv2.line(img, (x1, y1), (x1 + line_length_x, y1), color, thickness)
 
25
 
26
+ cv2.line(img, (x1, y2), (x1, y2 - line_length_y), color, thickness) # bottom-left
27
+ cv2.line(img, (x1, y2), (x1 + line_length_x, y2), color, thickness)
 
 
 
28
 
29
+ cv2.line(img, (x2, y1), (x2 - line_length_x, y1), color, thickness) # top-right
30
+ cv2.line(img, (x2, y1), (x2, y1 + line_length_y), color, thickness)
31
 
32
+ cv2.line(img, (x2, y2), (x2, y2 - line_length_y), color, thickness) # bottom-right
33
+ cv2.line(img, (x2, y2), (x2 - line_length_x, y2), color, thickness)
 
 
34
 
35
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
38
+ # Detect vehicles
39
+ vehicle_results = vehicle_detector(frame, classes=[2, 3, 5, 7]) # cars, motorcycles, bus, trucks
 
 
 
 
 
 
40
 
41
+ # Process each vehicle
42
+ for vehicle in vehicle_results[0].boxes.data:
43
+ x1, y1, x2, y2, score, class_id = vehicle
44
+ if score > 0.5: # Confidence threshold
45
+ # Draw vehicle border
46
+ draw_border(frame,
47
+ (int(x1), int(y1)),
48
+ (int(x2), int(y2)),
49
+ color=(0, 255, 0),
50
+ thickness=25,
51
+ line_length_x=200,
52
+ line_length_y=200)
53
+
54
+ # Detect license plate in vehicle region
55
+ vehicle_crop = frame[int(y1):int(y2), int(x1):int(x2)]
56
+ license_results = license_plate_detector(vehicle_crop)
57
+
58
+ for license_plate in license_results[0].boxes.data:
59
+ lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate
60
+ if lp_score > 0.5:
61
+ # Adjust coordinates to full frame
62
+ abs_lp_x1 = int(x1 + lp_x1)
63
+ abs_lp_y1 = int(y1 + lp_y1)
64
+ abs_lp_x2 = int(x1 + lp_x2)
65
+ abs_lp_y2 = int(y1 + lp_y2)
66
+
67
+ # Draw license plate box
68
+ cv2.rectangle(frame,
69
+ (abs_lp_x1, abs_lp_y1),
70
+ (abs_lp_x2, abs_lp_y2),
71
+ (0, 0, 255), 12)
72
+
73
+ # Extract and process license plate for OCR
74
+ license_crop = frame[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2]
75
+ if license_crop.size > 0:
76
+ # Prepare license crop for ONNX model
77
+ license_crop_resized = cv2.resize(license_crop, (640, 640))
78
+ license_crop_processed = np.transpose(license_crop_resized, (2, 0, 1)).astype(np.float32) / 255.0
79
+ license_crop_processed = np.expand_dims(license_crop_processed, axis=0)
80
+
81
+ # Run OCR inference
82
+ try:
83
+ inputs = {ort_session.get_inputs()[0].name: license_crop_processed}
84
+ outputs = ort_session.run(None, inputs)
85
+
86
+ # Process OCR output (adjust based on your model's output format)
87
+ # This is a placeholder - adjust based on your ONNX model's output
88
+ license_number = "ABC123" # Replace with actual OCR processing
89
+
90
+ # Display license plate number
91
+ H, W, _ = license_crop.shape
92
+ license_crop_display = cv2.resize(license_crop, (int(W * 400 / H), 400))
93
+
94
+ try:
95
+ # Display license crop and number above vehicle
96
+ h_crop, w_crop, _ = license_crop_display.shape
97
+ center_x = int((x1 + x2) / 2)
98
+
99
+ # Display license plate crop
100
+ frame[int(y1) - h_crop - 100:int(y1) - 100,
101
+ int(center_x - w_crop/2):int(center_x + w_crop/2)] = license_crop_display
102
+
103
+ # White background for text
104
+ cv2.rectangle(frame,
105
+ (int(center_x - w_crop/2), int(y1) - h_crop - 400),
106
+ (int(center_x + w_crop/2), int(y1) - h_crop - 100),
107
+ (255, 255, 255),
108
+ -1)
109
+
110
+ # Draw license number
111
+ (text_width, text_height), _ = cv2.getTextSize(
112
+ license_number,
113
+ cv2.FONT_HERSHEY_SIMPLEX,
114
+ 4.3,
115
+ 17)
116
+
117
+ cv2.putText(frame,
118
+ license_number,
119
+ (int(center_x - text_width/2), int(y1 - h_crop - 250 + text_height/2)),
120
+ cv2.FONT_HERSHEY_SIMPLEX,
121
+ 4.3,
122
+ (0, 0, 0),
123
+ 17)
124
+ except Exception as e:
125
+ st.error(f"Error displaying results: {str(e)}")
126
+ except Exception as e:
127
+ st.error(f"Error in OCR processing: {str(e)}")
128
 
129
+ return frame
130
 
131
+ def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
132
  cap = cv2.VideoCapture(video_path)
133
 
134
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
136
  fps = int(cap.get(cv2.CAP_PROP_FPS))
137
 
138
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
139
+ out = cv2.VideoWriter(temp_file.name,
140
+ cv2.VideoWriter_fourcc(*'mp4v'),
141
+ fps,
142
+ (width, height))
 
 
143
 
144
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
145
  progress_bar = st.progress(0)
 
150
  if not ret:
151
  break
152
 
153
+ processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
154
+ out.write(processed_frame)
155
 
156
  frame_count += 1
157
  progress_bar.progress(frame_count / total_frames)
 
163
  return temp_file.name
164
 
165
  # Streamlit UI
166
+ st.title("Advanced Vehicle and License Plate Detection")
 
 
 
 
 
 
 
 
 
167
 
168
+ try:
169
+ license_plate_detector, vehicle_detector, ort_session = load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
+ uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
172
+
173
+ if uploaded_file is not None:
174
+ file_type = uploaded_file.type.split('/')[0]
175
 
176
+ if file_type == "image":
177
+ image = Image.open(uploaded_file)
178
+ st.image(image, caption="Uploaded Image", use_column_width=True)
179
+
180
+ if st.button("Detect"):
181
+ with st.spinner("Processing image..."):
182
+ # Convert PIL Image to CV2 format
183
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
184
+ processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
185
+ processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
186
+ st.image(processed_image, caption="Processed Image", use_column_width=True)
187
 
188
+ elif file_type == "video":
189
+ tfile = tempfile.NamedTemporaryFile(delete=False)
190
+ tfile.write(uploaded_file.read())
191
+ st.video(tfile.name)
192
+
193
+ if st.button("Detect"):
194
+ with st.spinner("Processing video..."):
195
+ processed_video = process_video(tfile.name, license_plate_detector, vehicle_detector, ort_session)
196
+ st.video(processed_video)
197
+
198
+ except Exception as e:
199
+ st.error(f"Error loading models: {str(e)}")