# Spaces: Running
import os
import time
import uuid

import cv2
import numpy as np
import torch
from flask import Flask, jsonify, request
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
# Flask application object that serves the HTTP API of this module.
app = Flask(import_name=__name__)
def _select_device():
    """Return the preferred torch device for inference: CUDA, then Apple MPS, else CPU.

    Always returns a ``torch.device`` (never a bare string), so every branch yields
    the same type. The MPS backend is looked up defensively because
    ``torch.backends.mps`` is absent on older torch releases.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Device setup (GPU or CPU): module-level device used for the model and its inputs.
device = _select_device()
# Load the pre-trained detector and its matching image processor from the
# Hugging Face Hub. The processor is used both to prepare input images for the
# model and to turn raw model outputs into scored, labelled boxes.
ckpt = "yainage90/fashion-object-detection"
image_processor = AutoImageProcessor.from_pretrained(ckpt)
model = AutoModelForObjectDetection.from_pretrained(ckpt)
model = model.to(device)  # run inference on the device selected earlier
def detect_objects(frame, threshold=0.4):
    """Detect objects in a single video frame.

    Args:
        frame: Image array in OpenCV's BGR channel order, e.g. the result of
            ``cv2.imdecode``.
        threshold: Minimum confidence score for a detection to be kept.
            Defaults to 0.4, the value this function previously hard-coded.

    Returns:
        A list of ``(score, label_id, box)`` tuples. ``score`` is a float,
        ``label_id`` is an int key of ``model.config.id2label``, and ``box`` is a
        list of four floats rescaled to the frame's pixel size (presumably
        ``(x_min, y_min, x_max, y_max)`` per transformers' post-processing docs;
        confirm for this checkpoint). Each detection is also printed to stdout.

    Raises:
        ValueError: If ``frame`` is None, e.g. when the upload could not be
            decoded as an image.
    """
    if frame is None:
        raise ValueError("frame is None; expected a decoded BGR image array")

    # The processor works on RGB PIL images, while OpenCV frames are BGR.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt").to(device)
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing takes (height, width) per
    # image so the boxes are scaled back to the original frame size.
    width, height = image.size
    target_sizes = torch.tensor([[height, width]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    # Convert tensors to plain Python values so callers need no torch knowledge.
    items = []
    for score_t, label_t, box_t in zip(results["scores"], results["labels"], results["boxes"]):
        score = score_t.item()
        label = label_t.item()
        box = box_t.tolist()
        print(f"{model.config.id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))
    return items
def save_data(frame, items):
    """Persist a frame that had detections: write the image, extract the plate, record both.

    Args:
        frame: Image array to write to disk with ``cv2.imwrite``.
        items: Detections for ``frame``, as returned by ``detect_objects``.

    Raises:
        OSError: If OpenCV fails to write the image file.
    """
    # A per-second timestamp alone collides when several frames arrive within
    # the same second, and each write would overwrite the previous image. The
    # random suffix keeps every filename unique.
    filename = f"helmet_violation_{int(time.time())}_{uuid.uuid4().hex}.jpg"
    # cv2.imwrite signals failure by returning False rather than raising, so
    # check it explicitly. Otherwise we would record an image that does not exist.
    if not cv2.imwrite(filename, frame):
        raise OSError(f"failed to write image file {filename!r}")
    # Plate extraction is currently a placeholder (see extract_plate_number).
    plate_number = extract_plate_number(frame)
    save_to_database(filename, plate_number, items)
def extract_plate_number(frame):
    """Return the licence-plate text for *frame*.

    Placeholder implementation: ``frame`` is ignored and the fixed string
    ``"XYZ 1234"`` is always returned. TODO: replace with real licence-plate
    recognition.
    """
    return "XYZ 1234"
def save_to_database(image_filename, plate_number, items):
    """Record one detection event.

    Stand-in for real persistence: the plate number, saved image filename and
    detected items are only printed to stdout, on two lines.
    """
    summary = f"Plate Number: {plate_number}, Image saved as {image_filename}"
    print(summary)
    print("Detected items:", items)
# NOTE(review): no route decorator was visible for this view, so it was not
# reachable over HTTP. The path below is assumed; confirm it against the client.
@app.route("/process_frame", methods=["POST"])
def process_frame():
    """Handle one uploaded video frame.

    Expects a multipart upload with the encoded image under the ``frame`` field.
    The bytes are decoded with OpenCV and passed to ``detect_objects``. If
    anything was detected, the frame is persisted via ``save_data``.

    Returns:
        ``{"status": "processed"}`` on success. Returns an ``{"error": ...}`` JSON
        body with HTTP 400 when the upload is missing, empty, or not a decodable
        image.
    """
    upload = request.files.get("frame")
    if upload is None:
        return jsonify({"error": "missing 'frame' file in request"}), 400

    data = upload.read()
    if not data:
        return jsonify({"error": "'frame' file is empty"}), 400

    np_array = np.frombuffer(data, np.uint8)
    # imdecode returns None (instead of raising) when the bytes are not an image.
    img = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
    if img is None:
        return jsonify({"error": "could not decode 'frame' as an image"}), 400

    items = detect_objects(img)
    # NOTE(review): any detection at all triggers save_data (the checkpoint name
    # suggests a fashion detector, not a helmet classifier). Confirm intended logic.
    if items:
        save_data(img, items)
    return jsonify({"status": "processed"})
if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary Python, so it must not be on by default while listening on all
    # interfaces. Opt in with FLASK_DEBUG=1 for local development only.
    debug_mode = os.environ.get("FLASK_DEBUG", "0") == "1"
    # Port defaults to 5000 as before; override with the PORT environment variable.
    port = int(os.environ.get("PORT", "5000"))
    app.run(debug=debug_mode, host="0.0.0.0", port=port)