import os
import shutil
import time
from typing import Optional, Tuple

import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import onnxruntime as ort
import pandas as pd
from huggingface_hub import hf_hub_download

from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
from metrics_storage import MetricsStorage


def download_model() -> str:
    """Download the model file from the Hugging Face Hub into MODEL_PATH.

    Fetches ``FILENAME`` from ``REPO_ID`` into ``MODEL_DIR`` (always
    re-downloading), moves it to ``MODEL_PATH`` if the hub stored it at a
    different path, and removes a leftover ``MODEL_DIR/tune`` directory.

    Returns:
        str: ``MODEL_PATH``.

    Raises:
        Exception: Any error from downloading or moving the file is printed
            and re-raised unchanged.
    """
    # Ensure model directory exists
    os.makedirs(MODEL_DIR, exist_ok=True)

    try:
        print(f"Downloading model from {REPO_ID}...")

        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            local_dir=MODEL_DIR,
            force_download=True,
            cache_dir=None,
        )

        # Move the file to the expected location. os.replace (unlike
        # os.rename) also overwrites an existing destination on Windows,
        # which can exist here because force_download re-fetches the model.
        if os.path.exists(model_path) and model_path != MODEL_PATH:
            os.replace(model_path, MODEL_PATH)

        # Remove the leftover "tune" directory if it exists
        empty_dir = os.path.join(MODEL_DIR, "tune")
        if os.path.exists(empty_dir):
            shutil.rmtree(empty_dir)

        print("Model downloaded successfully!")
        return MODEL_PATH

    except Exception as e:
        print(f"Error downloading model: {e}")
        # Bare raise keeps the original traceback intact.
        raise


class SignatureDetector:
    """Detect signatures in images with an ONNX model via ONNX Runtime.

    Each call to ``detect`` records its inference latency in a
    ``MetricsStorage``; the metric helpers read those timings back and can
    render them as matplotlib figures.
    """

    def __init__(self, model_path: str = MODEL_PATH):
        self.model_path = model_path
        self.classes = ["signature"]
        self.input_width = 640
        self.input_height = 640
        # One random color per class; built lazily by draw_detections.
        self.color_palette: Optional[np.ndarray] = None

        # Initialize ONNX Runtime session
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        self.session = ort.InferenceSession(self.model_path, options)
        self.session.set_providers(
            ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
        )

        self.metrics_storage = MetricsStorage()

    def update_metrics(self, inference_time: float) -> None:
        """
        Updates metrics in persistent storage.

        Args:
            inference_time (float): The time taken for inference in milliseconds.
        """
        self.metrics_storage.add_metric(inference_time)

    def get_metrics(self) -> dict:
        """
        Retrieves current metrics from storage.

        Returns:
            dict: Keys ``times`` (recent timings), ``total_inferences``,
            ``avg_time`` and ``start_index`` (the inference number of the
            first entry in ``times``).
        """
        times = self.metrics_storage.get_recent_metrics()
        total = self.metrics_storage.get_total_inferences()
        avg = self.metrics_storage.get_average_time()

        # `times` holds only the most recent entries; offset so plot indices
        # line up with the overall inference count.
        start_index = max(0, total - len(times))

        return {
            "times": times,
            "total_inferences": total,
            "avg_time": avg,
            "start_index": start_index,
        }

    def load_initial_metrics(
        self,
    ) -> Tuple[
        None,
        Optional[str],
        Optional[plt.Figure],
        Optional[plt.Figure],
        Optional[str],
        Optional[str],
    ]:
        """
        Loads initial metrics for display.

        Returns:
            tuple: ``(None, total inferences, histogram figure, line figure,
            average time, last time)``. The numeric values are formatted as
            strings. When no timings are stored, every element is ``None``.
        """
        metrics = self.get_metrics()

        if not metrics["times"]:
            return None, None, None, None, None, None

        hist_data = pd.DataFrame({"Time (ms)": metrics["times"]})
        indices = range(
            metrics["start_index"], metrics["start_index"] + len(metrics["times"])
        )

        line_data = pd.DataFrame(
            {
                "Inference": indices,
                "Time (ms)": metrics["times"],
                "Mean": [metrics["avg_time"]] * len(metrics["times"]),
            }
        )

        hist_fig, line_fig = self.create_plots(hist_data, line_data)

        return (
            None,
            f"{metrics['total_inferences']}",
            hist_fig,
            line_fig,
            f"{metrics['avg_time']:.2f}",
            f"{metrics['times'][-1]:.2f}",
        )

    def create_plots(
        self, hist_data: pd.DataFrame, line_data: pd.DataFrame
    ) -> Tuple[plt.Figure, plt.Figure]:
        """
        Builds the latency histogram and the per-inference line plot.

        Note: calls ``plt.style.use("dark_background")``, which changes the
        global matplotlib style for the rest of the process.

        Args:
            hist_data (pd.DataFrame): Must contain a ``Time (ms)`` column.
            line_data (pd.DataFrame): Must contain ``Inference``,
                ``Time (ms)`` and ``Mean`` columns.

        Returns:
            tuple: The histogram figure and the line figure. Both are closed
            in pyplot (so they are not shown interactively) but remain usable
            as objects.
        """
        plt.style.use("dark_background")

        # Histogram plot
        hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
        hist_ax.set_facecolor("#f0f0f5")
        hist_data.hist(
            bins=20, ax=hist_ax, color="#4F46E5", alpha=0.7, edgecolor="white"
        )
        hist_ax.set_title(
            "Distribution of Inference Times",
            pad=15,
            fontsize=12,
            color="#1f2937",
        )
        hist_ax.set_xlabel("Time (ms)", color="#374151")
        hist_ax.set_ylabel("Frequency", color="#374151")
        hist_ax.tick_params(colors="#4b5563")
        hist_ax.grid(True, linestyle="--", alpha=0.3)

        # Line plot
        line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
        line_ax.set_facecolor("#f0f0f5")
        line_data.plot(
            x="Inference",
            y="Time (ms)",
            ax=line_ax,
            color="#4F46E5",
            alpha=0.7,
            label="Time",
        )
        line_data.plot(
            x="Inference",
            y="Mean",
            ax=line_ax,
            color="#DC2626",
            linestyle="--",
            label="Mean",
        )
        line_ax.set_title(
            "Inference Time per Execution", pad=15, fontsize=12, color="#1f2937"
        )
        line_ax.set_xlabel("Inference Number", color="#374151")
        line_ax.set_ylabel("Time (ms)", color="#374151")
        line_ax.tick_params(colors="#4b5563")
        line_ax.grid(True, linestyle="--", alpha=0.3)
        line_ax.legend(
            frameon=True, facecolor="#f0f0f5", edgecolor="white", labelcolor="black"
        )

        hist_fig.tight_layout()
        line_fig.tight_layout()

        # Detach from pyplot's figure manager so figures don't accumulate.
        plt.close(hist_fig)
        plt.close(line_fig)

        return hist_fig, line_fig

    def preprocess(self, img: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocesses the image for inference.

        Side effect: stores the original size in ``self.img_height`` and
        ``self.img_width``, which ``postprocess`` uses to rescale boxes.

        Args:
            img: The image to process.

        Returns:
            tuple: The model input, a float32 array of shape
            ``(1, 3, input_height, input_width)`` with values in [0, 1], and
            the original image as a BGR array.
        """
        # Normalize PIL mode first so grayscale / palette / RGBA inputs become
        # 3-channel RGB (cv2.cvtColor rejects single-channel arrays).
        if isinstance(img, Image.Image):
            img = img.convert("RGB")
        img_rgb = np.array(img)

        # BGR copy of the original, used by OpenCV drawing in postprocess.
        img_cv2 = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        self.img_height, self.img_width = img_cv2.shape[:2]

        # Resize, scale to [0, 1], HWC -> CHW, and add the batch dimension.
        img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
        image_data = img_resized / 255.0
        image_data = np.transpose(image_data, (2, 0, 1))
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        return image_data, img_cv2

    def draw_detections(
        self, img: np.ndarray, box: list, score: float, class_id: int
    ) -> None:
        """
        Draws one detection (box plus "class: score" label) onto ``img`` in place.

        Args:
            img: The BGR image to draw on.
            box (list): ``[left, top, width, height]`` in original-image pixels.
            score (float): The confidence score.
            class_id (int): Index into ``self.classes``.
        """
        x1, y1, w, h = box

        # Build the palette once so each class keeps the same color across
        # detections and calls. getattr also covers instances built without
        # __init__.
        if getattr(self, "color_palette", None) is None:
            self.color_palette = np.random.uniform(
                0, 255, size=(len(self.classes), 3)
            )
        # OpenCV expects a plain numeric tuple for the color.
        color = tuple(float(c) for c in self.color_palette[class_id])

        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        label = f"{self.classes[class_id]}: {score:.2f}"
        (label_width, label_height), _ = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )

        # Put the label above the box, or inside it when too close to the top.
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        cv2.rectangle(
            img,
            (int(label_x), int(label_y - label_height)),
            (int(label_x + label_width), int(label_y + label_height)),
            color,
            cv2.FILLED,
        )

        cv2.putText(
            img,
            label,
            (int(label_x), int(label_y)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )

    def postprocess(
        self,
        input_image: np.ndarray,
        output: np.ndarray,
        conf_thres: float,
        iou_thres: float,
    ) -> np.ndarray:
        """
        Decodes raw model output, applies NMS and draws surviving detections.

        Each row of the transposed output is decoded as
        ``[cx, cy, w, h, class scores...]`` in model-input pixels and rescaled
        to the original image size recorded by ``preprocess``.

        Args:
            input_image: The original BGR image. It is drawn on in place.
            output: The list of outputs from ``session.run``; only
                ``output[0]`` is used.
            conf_thres (float): Minimum best-class score for a candidate box.
            iou_thres (float): IoU threshold used by non-maximum suppression.

        Returns:
            np.ndarray: The annotated image converted to RGB.
        """
        outputs = np.transpose(np.squeeze(output[0]))

        # Vectorized candidate filtering: keep rows whose best class score
        # reaches the threshold.
        class_scores = outputs[:, 4:]
        max_scores = np.amax(class_scores, axis=1)
        keep = max_scores >= conf_thres

        kept = outputs[keep].astype(np.float64)
        scores = max_scores[keep].tolist()
        class_ids = np.argmax(class_scores[keep], axis=1)

        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # Convert center-format boxes to top-left/size in original pixels;
        # astype(int) truncates toward zero like int().
        cx, cy, w, h = kept[:, 0], kept[:, 1], kept[:, 2], kept[:, 3]
        left = ((cx - w / 2) * x_factor).astype(int)
        top = ((cy - h / 2) * y_factor).astype(int)
        width = (w * x_factor).astype(int)
        height = (h * y_factor).astype(int)
        boxes = np.stack([left, top, width, height], axis=1).tolist()

        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)

        # Depending on the OpenCV version, NMSBoxes returns a flat sequence,
        # an (N, 1) array, or an empty tuple; flatten to plain indices.
        for i in np.asarray(indices, dtype=int).reshape(-1):
            self.draw_detections(input_image, boxes[i], scores[i], int(class_ids[i]))

        return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

    def detect(
        self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
    ) -> Tuple[np.ndarray, dict]:
        """
        Detects signatures in the given image.

        Also records the inference latency in ms (only the
        ``session.run`` call is timed) in the metrics storage.

        Args:
            image: The image to process.
            conf_thres (float): Confidence threshold for detection.
            iou_thres (float): Intersection over Union threshold for detection.

        Returns:
            tuple: The annotated RGB image (numpy array) and the current
            metrics dict from ``get_metrics``.
        """
        img_data, original_image = self.preprocess(image)

        input_name = self.session.get_inputs()[0].name

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        outputs = self.session.run(None, {input_name: img_data})
        inference_time = (time.perf_counter() - start_time) * 1000  # ms

        output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
        self.update_metrics(inference_time)

        return output_image, self.get_metrics()

    def detect_example(
        self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
    ) -> np.ndarray:
        """
        Wrapper method for examples that returns only the image.

        Args:
            image: The image to process.
            conf_thres (float): Confidence threshold for detection.
            iou_thres (float): Intersection over Union threshold for detection.

        Returns:
            The annotated RGB image as a numpy array.
        """
        output_image, _ = self.detect(image, conf_thres, iou_thres)
        return output_image