import os

# install ByteTrack and its dependencies before importing yolox (run at app startup)
os.system("pip install git+https://github.com/ifzhang/ByteTrack")
os.system("pip3 install cython_bbox gdown 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'")
os.system("pip3 install -v -e .")

from dataclasses import dataclass
from typing import List

import numpy as np
import gradio as gr
import supervision
from tqdm import tqdm
from ultralytics import YOLO
from onemetric.cv.utils.iou import box_iou_batch
from supervision import (
    BoxAnnotator,
    Color,
    Detections,
    Point,
    VideoInfo,
    VideoSink,
    draw_text,
    get_video_frames_generator,
)
from yolox.tracker.byte_tracker import BYTETracker, STrack

# path to the trained YOLO weights bundled with the app
MODEL = "./best.pt"

# path the annotated output video is written to
TARGET_VIDEO_PATH = "test.mp4"

# class ids to keep from the detector output
CLASS_ID = [0, 1, 2, 3, 4, 5, 6]

video_examples = [['example.mp4']]

model = YOLO(MODEL)
model.fuse()

classes = CLASS_ID

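# arguments expected by BYTETracker (thresholds for starting, matching and keeping tracks)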
@dataclass(frozen=True)
class BYTETrackerArgs:
    track_thresh: float = 0.25
    track_buffer: int = 30
    match_thresh: float = 0.8
    aspect_ratio_thresh: float = 3.0
    min_box_area: float = 1.0
    mot20: bool = False


# converts Detections into the (x1, y1, x2, y2, score) array consumed by BYTETracker.update
def detections2boxes(detections: Detections) -> np.ndarray:
    return np.hstack((
        detections.xyxy,
        detections.confidence[:, np.newaxis]
    ))


# converts List[STrack] into the box array consumed by match_detections_with_tracks
def tracks2boxes(tracks: List[STrack]) -> np.ndarray:
    return np.array([track.tlbr for track in tracks], dtype=float)


# matches detections with ByteTrack tracks via IoU and returns one tracker id per detection
def match_detections_with_tracks(
    detections: Detections,
    tracks: List[STrack],
) -> List:
    if not np.any(detections.xyxy) or len(tracks) == 0:
        return [None] * len(detections)

    tracks_boxes = tracks2boxes(tracks=tracks)
    iou = box_iou_batch(tracks_boxes, detections.xyxy)
    track2detection = np.argmax(iou, axis=1)

    tracker_ids = [None] * len(detections)

    for tracker_index, detection_index in enumerate(track2detection):
        if iou[tracker_index, detection_index] != 0:
            tracker_ids[detection_index] = tracks[tracker_index].track_id

    return tracker_ids

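# full pipeline for one uploaded video: detect with YOLO, track with ByteTrack,
# keep detections inside the polygon zone, draw per-class counts and boxes, write the result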
def ObjectDetection(video_path):
    byte_tracker = BYTETracker(BYTETrackerArgs())
    video_info = VideoInfo.from_video_path(video_path)
    generator = get_video_frames_generator(video_path)
    box_annotator = BoxAnnotator(thickness=5, text_thickness=5, text_scale=1)
    # counting zone defined as a polygon in pixel coordinates of the source video
    polygon = np.array([[200, 300], [200, 1420], [880, 1420], [880, 300]])
    zone = supervision.PolygonZone(polygon=polygon, frame_resolution_wh=video_info.resolution_wh)
    # annotator that draws the zone outline and its current count
    zone_annotator = supervision.PolygonZoneAnnotator(zone=zone, color=Color.white(), thickness=4)
    # open target video file
    with VideoSink(TARGET_VIDEO_PATH, video_info) as sink:
        # loop over video frames
        for frame in tqdm(generator, total=video_info.total_frames):
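            # run the YOLO model on the current frame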
            results = model(frame)
            detections = Detections(
                xyxy=results[0].boxes.xyxy.cpu().numpy(),
                confidence=results[0].boxes.conf.cpu().numpy(),
                class_id=results[0].boxes.cls.cpu().numpy().astype(int)
            )
            # filtering out detections with unwanted classes
            detections = detections[np.isin(detections.class_id, CLASS_ID)]
            # tracking detections
            tracks = byte_tracker.update(
                output_results=detections2boxes(detections=detections),
                img_info=frame.shape,
                img_size=frame.shape
            )
            tracker_id = match_detections_with_tracks(detections=detections, tracks=tracks)
            detections.tracker_id = np.array(tracker_id)
            # filtering out detections without trackers
            detections = detections[np.not_equal(detections.tracker_id, None)]
            # keep only detections inside the polygon zone
            mask = zone.trigger(detections=detections)
            detections_filtered = detections[mask]
            # format custom labels for the detections inside the zone
            labels = [
                f"#{tracker_id} {classes[class_id]} {confidence:0.2f}"
                for _, _, confidence, class_id, tracker_id
                in detections_filtered
            ]
            # per-class counts inside the zone, drawn onto the frame
            class_ids, counts = np.unique(detections_filtered.class_id, return_counts=True)
            for class_id, count in zip(class_ids, counts):
                frame = draw_text(
                    background_color=Color.white(),
                    scene=frame,
                    text=f"{classes[class_id]} : {count}",
                    text_anchor=Point(x=500, y=1550 + (50 * class_id)),
                    text_scale=2,
                    text_thickness=4,
                )
            # annotate boxes, labels and the zone outline, then write the frame
            frame = box_annotator.annotate(scene=frame, detections=detections_filtered, labels=labels)
            frame = zone_annotator.annotate(scene=frame)
            sink.write_frame(frame)

    return TARGET_VIDEO_PATH

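# Gradio interface: upload a video (or use the example) and get back the annotated video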
demo = gr.Interface(fn=ObjectDetection, inputs=gr.Video(), outputs=gr.Video(), examples=video_examples, cache_examples=False)
demo.queue().launch()