File size: 2,305 Bytes
11ce781
 
1afa0e4
 
 
b535dd1
1afa0e4
 
 
 
 
a59862a
 
11ce781
1afa0e4
11ce781
 
 
 
 
 
 
 
 
 
 
1afa0e4
 
11ce781
1afa0e4
 
11ce781
1afa0e4
 
11ce781
1afa0e4
 
11ce781
 
 
 
 
 
1afa0e4
11ce781
1afa0e4
 
 
 
 
 
 
 
11ce781
1afa0e4
11ce781
1afa0e4
 
11ce781
 
1afa0e4
11ce781
 
1afa0e4
11ce781
1afa0e4
 
 
 
 
 
 
 
11ce781
1afa0e4
 
11ce781
 
1afa0e4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85


from collections import defaultdict
from typing import Optional

import cv2
import torch
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

def detect_fire_in_video(
    input_video_path: str,
    output_video_path: str,
    model_path: str,
    device: Optional[str] = None,
) -> str:
    """
    Detect fire in a video with a YOLO segmentation/tracking model.

    Runs YOLO tracking on every frame, draws the segmentation masks and
    track IDs on the frame, and writes the annotated frames to a new
    video file.

    Args:
        input_video_path (str): Path to the input video file.
        output_video_path (str): Path to save the annotated output video.
        model_path (str): Path to the YOLO .pt file.
        device (str, optional): 'cpu', 'cuda', or 'mps'. Defaults to
            'cuda' when available, otherwise 'cpu'.

    Returns:
        str: The path to the output annotated video.

    Raises:
        IOError: If the input video cannot be opened.
    """
    # Pick a device automatically unless the caller specified one.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the YOLO model and move it to the requested device.
    # NOTE: YOLO(...) takes no `device` kwarg; the device is set via .to().
    model = YOLO(model_path)
    model.to(device)

    # Open the video and fail fast if the path/codec is bad.
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open input video: {input_video_path}")

    # Retrieve video properties (width, height, frames-per-second).
    w, h, fps = (int(cap.get(prop)) for prop in [
        cv2.CAP_PROP_FRAME_WIDTH,
        cv2.CAP_PROP_FRAME_HEIGHT,
        cv2.CAP_PROP_FPS
    ])

    # Prepare output video writer (MJPG keeps the original behavior).
    out = cv2.VideoWriter(
        output_video_path,
        cv2.VideoWriter_fourcc(*"MJPG"),
        fps,
        (w, h)
    )

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Reached end of video or no frame retrieved.")
                break

            # Create an annotator to draw on the frame in place.
            annotator = Annotator(frame, line_width=2)

            # Perform object tracking; persist=True keeps IDs across frames.
            results = model.track(frame, persist=True)

            # Annotate only when both track IDs and masks are present
            # (a frame with no detections yields None for boxes.id).
            if results[0].boxes.id is not None and results[0].masks is not None:
                masks = results[0].masks.xy
                track_ids = results[0].boxes.id.int().cpu().tolist()

                for mask, track_id in zip(masks, track_ids):
                    annotator.seg_bbox(
                        mask=mask,
                        mask_color=colors(int(track_id), True),
                        label=f"ID:{track_id}"
                    )

            # Write the annotated frame to output.
            out.write(frame)
    finally:
        # Always release codec/file handles, even if tracking raises.
        out.release()
        cap.release()
        cv2.destroyAllWindows()

    return output_video_path