Detection / app.py
SamiKhokhar's picture
Create app.py
94886ef verified
import time
import uuid

import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import cv2
import numpy as np
from flask import Flask, jsonify, request
# Flask application object; the route handlers in this module register on it.
app = Flask(import_name=__name__)
# Compute device for inference: prefer CUDA, then Apple's MPS backend, else CPU.
# Always a torch.device, so the global has a single consistent type whichever
# branch is taken (previously the CPU fallback was the plain string 'cpu').
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # getattr guard: torch builds without the MPS backend lack torch.backends.mps.
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Hugging Face Hub checkpoint providing both the image processor and the
# object-detection model.
ckpt = 'yainage90/fashion-object-detection'
image_processor = AutoImageProcessor.from_pretrained(ckpt)
# Load the detector, then move its weights onto the selected `device`.
model = AutoModelForObjectDetection.from_pretrained(ckpt)
model = model.to(device)
def detect_objects(frame, *, threshold=0.4):
    """Run the object detector on a single BGR video frame.

    Args:
        frame: Image array in OpenCV's BGR channel order (e.g. the result of
            ``cv2.imdecode``); it is converted to RGB before inference.
        threshold: Minimum confidence a detection must reach to be kept.
            Defaults to 0.4, the value that was previously hard-coded.

    Returns:
        A list of ``(score, label_id, box)`` tuples: ``score`` is a float,
        ``label_id`` an int key into ``model.config.id2label`` and ``box`` a
        list of four floats scaled to the frame size via ``target_sizes``
        (presumably x_min, y_min, x_max, y_max -- confirm for this checkpoint).
    """
    # OpenCV delivers BGR; the processor is given an RGB PIL image.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    # PIL reports (width, height); post-processing wants (height, width).
    width, height = image.size
    target_sizes = torch.tensor([[height, width]])

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))
        results = image_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]

    id2label = model.config.id2label
    items = []
    for score_t, label_t, box_t in zip(results["scores"], results["labels"], results["boxes"]):
        score = score_t.item()
        label = label_t.item()
        box = box_t.tolist()  # plain Python floats, same as per-element .item()
        print(f"{id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))
    return items
def save_data(frame, items):
    """Persist a frame that produced detections and record it.

    Writes *frame* to a JPEG in the current working directory, obtains a
    plate number via ``extract_plate_number`` and passes everything to
    ``save_to_database``.

    Args:
        frame: Image array to write with ``cv2.imwrite``.
        items: Detections returned by ``detect_objects``.

    Returns:
        The filename the frame was written to.

    Raises:
        OSError: If OpenCV could not write the image file.
    """
    # A one-second timestamp alone collides when several frames arrive within
    # the same second (normal for video), silently overwriting earlier images;
    # the random suffix keeps every filename unique.
    filename = f"helmet_violation_{int(time.time())}_{uuid.uuid4().hex[:8]}.jpg"
    # cv2.imwrite signals failure through its return value instead of raising;
    # without this check a non-existent file would be recorded below.
    if not cv2.imwrite(filename, frame):
        raise OSError(f"Failed to write frame to {filename}")
    # TODO: extract_plate_number is a placeholder; plug in real recognition.
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)
    return filename
def extract_plate_number(frame):
    """Return the license-plate text for *frame*.

    Placeholder only: *frame* is ignored and the fixed string "XYZ 1234" is
    returned. Replace with an actual license plate recognition method.
    """
    return "XYZ 1234"
def save_to_database(image_filename, plate_number, items):
    """Record a saved frame (stand-in for real persistence).

    Nothing is stored; the record is written to stdout as two lines: the
    plate number with the image filename, then the detected items.
    """
    report = (
        f"Plate Number: {plate_number}, Image saved as {image_filename}",
        f"Detected items: {items}",
    )
    for line in report:
        print(line)
@app.route("/process_frame", methods=["POST"])
def process_frame():
    """Decode an uploaded frame, run detection and save frames with hits.

    Expects a multipart upload carrying the encoded image in the ``frame``
    file field. Returns ``{"status": "processed"}`` on success, or HTTP 400
    with ``{"status": "error", "message": ...}`` when the field is missing,
    empty, or cannot be decoded as an image.
    """
    upload = request.files.get("frame")
    if upload is None:
        return jsonify({"status": "error", "message": "missing 'frame' file field"}), 400
    data = upload.read()
    if not data:
        # An empty buffer makes cv2.imdecode fail an internal assertion.
        return jsonify({"status": "error", "message": "empty frame upload"}), 400
    # imdecode returns None (rather than raising) for bytes it cannot parse;
    # previously that None crashed detect_objects with a 500.
    img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return jsonify({"status": "error", "message": "could not decode frame as an image"}), 400
    # Detect objects (e.g., helmets) in the frame.
    # NOTE(review): the checkpoint is a fashion detector and *any* detection is
    # saved as a "helmet_violation"; confirm the intended label filtering.
    items = detect_objects(img)
    if items:
        save_data(img, items)
    return jsonify({"status": "processed"})
if __name__ == "__main__":
    import os

    # Werkzeug's interactive debugger can execute arbitrary Python from the
    # browser, so it must not be exposed on 0.0.0.0 by default. Opt in
    # explicitly with FLASK_DEBUG=1 for local development only.
    debug = os.environ.get("FLASK_DEBUG", "0") == "1"
    # Port defaults to 5000 as before; override with the PORT env variable.
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug, host="0.0.0.0", port=port)