File size: 12,101 Bytes
6eb1d7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

import csv
import logging
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import av
import torch
from torch.utils.data.dataset import Dataset

from detectron2.utils.file_io import PathManager

from ..utils import maybe_prepend_base_path
from .frame_selector import FrameSelector, FrameTsList

# List of decoded PyAV video frames (the return type of `read_keyframes`)
FrameList = List[av.frame.Frame]  # pyre-ignore[16]
# Callable applied to a batch of frames as a tensor (see `VideoKeyframeDataset.transform`)
FrameTransform = Callable[[torch.Tensor], torch.Tensor]


def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
    """
    Traverses all keyframes of a video file. Returns a list of keyframe
    timestamps. Timestamps are counts in timebase units.

    The traversal repeatedly requests a forward keyframe seek to the position
    right after the last seen timestamp and records the timestamp of the first
    demuxed packet if it is flagged as a keyframe. It stops when seeking fails
    with an AV error (e.g. when the video length is exceeded), when the
    demuxer yields no more packets, or when a packet without a timestamp is
    returned.

    Args:
       video_fpath (str): Video file path
       video_stream_idx (int): Video stream index (default: 0)
    Returns:
       List[int]: list of keyframe timestamps (timestamp is a count in timebase
           units). An empty list is returned if the container cannot be opened
           (OS / runtime error), if an OS error occurs while seeking, or if
           too many consecutive backward seeks are encountered.
    """
    logger = logging.getLogger(__name__)
    # Even though forward seeks for keyframes are requested, sometimes a
    # keyframe in backwards direction is returned. The traversal gives up
    # (returning an empty list) once this many consecutive backward seeks
    # have been observed.
    max_backward_seeks = 2
    try:
        with PathManager.open(video_fpath, "rb") as io:
            # pyre-fixme[16]: Module `av` has no attribute `open`.
            container = av.open(io, mode="r")
            # release the container on every exit path, including exceptions
            # that are not handled below
            try:
                stream = container.streams.video[video_stream_idx]
                keyframes = []
                pts = -1
                tolerance_backward_seeks = max_backward_seeks
                while True:
                    try:
                        container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
                    except av.AVError as e:
                        # the exception occurs when the video length is exceeded,
                        # we then return whatever data we've already collected
                        logger.debug(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
                        )
                        return keyframes
                    except OSError as e:
                        logger.warning(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
                        )
                        return []
                    # an exhausted demuxer means there is nothing left to read;
                    # return what was collected instead of leaking StopIteration
                    packet = next(container.demux(video=video_stream_idx), None)
                    if packet is None:
                        return keyframes
                    if packet.pts is not None and packet.pts <= pts:
                        logger.warning(
                            f"Video file {video_fpath}, stream {video_stream_idx}: "
                            f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
                            f"tolerance {tolerance_backward_seeks}."
                        )
                        tolerance_backward_seeks -= 1
                        if tolerance_backward_seeks == 0:
                            return []
                        # step past the offending position and retry the seek
                        pts += 1
                        continue
                    tolerance_backward_seeks = max_backward_seeks
                    pts = packet.pts
                    if pts is None:
                        # a packet without a timestamp ends the traversal
                        return keyframes
                    if packet.is_keyframe:
                        keyframes.append(pts)
            finally:
                container.close()
    except OSError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"Runtime error: {e}"
        )
    return []


def read_keyframes(
    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
) -> FrameList:  # pyre-ignore[11]
    """
    Reads keyframe data from a video file.

    For each timestamp, the container is seeked (keyframe seek) in the selected
    video stream and the first frame decoded from that same stream is taken.

    Args:
        video_fpath (str): Video file path
        keyframes (List[int]): List of keyframe timestamps (as counts in
            timebase units to be used in container seek operations)
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[Frame]: list of frames that correspond to the specified timestamps.
            If seeking or decoding fails for some timestamp, the frames read
            before it are returned; if the container cannot be opened
            (OS / runtime error), an empty list is returned.
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            # pyre-fixme[16]: Module `av` has no attribute `open`.
            container = av.open(io)
            # release the container on every exit path, including exceptions
            # that are not handled below
            try:
                stream = container.streams.video[video_stream_idx]
                frames = []
                for pts in keyframes:
                    try:
                        container.seek(pts, any_frame=False, stream=stream)
                        # decode from the stream that was seeked, rather than
                        # always from the first video stream
                        frame = next(container.decode(video=video_stream_idx))
                    except av.AVError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
                        )
                        return frames
                    except OSError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
                        )
                        return frames
                    except StopIteration:
                        logger.warning(
                            f"Read keyframes: Error decoding frame from {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}"
                        )
                        return frames
                    frames.append(frame)
                return frames
            finally:
                container.close()
    except OSError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
        )
    return []


def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
    """
    Build a list of video paths from a plain text file, one entry per line.

    Args:
        video_list_fpath (str): path to a plain text file with the list of videos
        base_path (str): base path passed to `maybe_prepend_base_path` together
            with every entry from the video list (default: None)
    Returns:
        List[str]: one path per line of the file, in file order; each line has
            its surrounding whitespace stripped before `maybe_prepend_base_path`
            is applied
    """
    with PathManager.open(video_list_fpath, "r") as list_file:
        return [
            maybe_prepend_base_path(base_path, str(entry.strip())) for entry in list_file
        ]


def read_keyframe_helper_data(fpath: str):
    """
    Read keyframe data from a file in CSV format: the header should contain
    "video_id" and "keyframes" fields. Value specifications are:
      video_id: int
      keyframes: list(int)
    Example of contents:
      video_id,keyframes
      2,"[1,11,21,31,41,51,61,71,81]"

    Blank lines are skipped; an empty list ("[]") yields no keyframes, and
    whitespace or empty items inside the brackets are tolerated.
    Reading is best-effort: on any error (file cannot be opened, missing header
    field, malformed value, duplicate video ID) a warning is logged and the
    entries parsed up to that point are returned.

    Args:
        fpath (str): File containing keyframe data

    Return:
        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
          contains a list of keyframes for that video
    """
    video_id_to_keyframes = {}
    try:
        with PathManager.open(fpath, "r") as io:
            csv_reader = csv.reader(io)
            header = next(csv_reader)
            video_id_idx = header.index("video_id")
            keyframes_idx = header.index("keyframes")
            for row in csv_reader:
                # csv.reader yields an empty row for every blank line
                if not row:
                    continue
                video_id = int(row[video_id_idx])
                if video_id in video_id_to_keyframes:
                    # explicit check: unlike `assert`, it is not stripped under -O
                    raise ValueError(
                        f"Duplicate keyframes entry for video {video_id} in {fpath}"
                    )
                # drop the surrounding brackets, then skip blank items so that
                # e.g. "[]", "[ ]" and "[1, 2,]" parse cleanly
                keyframes_str = row[keyframes_idx].strip()[1:-1]
                video_id_to_keyframes[video_id] = [
                    int(v) for v in keyframes_str.split(",") if v.strip()
                ]
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
    return video_id_to_keyframes


class VideoKeyframeDataset(Dataset):
    """
    Dataset that provides keyframes for a set of videos.

    Item ``idx`` corresponds to ``video_list[idx]``. Its keyframe timestamps are
    taken from the keyframe helper data when an entry for ``idx`` exists, and
    are listed from the video file otherwise. They are then optionally filtered
    by ``frame_selector``, decoded, converted to a float tensor of BGR images in
    NCHW layout, and optionally passed through ``transform``.
    """

    # Returned as "images" whenever no keyframes or no frames could be obtained
    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        video_list: List[str],
        category_list: Union[str, List[str], None] = None,
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
        keyframe_helper_fpath: Optional[str] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            category_list (Union[str, List[str], None]): list of animal categories for each
                video file. If it is not a list (a string, or None), this value applies
                to all videos
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of BGR images (float tensors of size [B, 3, H, W]),
                expected to return a tensor of the same size. If None, no transform is
                applied (default: None)
            keyframe_helper_fpath (str): path to a CSV file with precomputed keyframe
                timestamps, read with `read_keyframe_helper_data`; its video IDs are
                used as indices into ``video_list``. Videos without an entry have their
                keyframes listed from the video file. If None, keyframes are always
                listed from the video files (default: None)
        Raises:
            AssertionError: if ``category_list`` is a list whose length differs
                from the length of ``video_list``
        """
        if isinstance(category_list, list):
            self.category_list = category_list
        else:
            self.category_list = [category_list] * len(video_list)
        assert len(video_list) == len(
            self.category_list
        ), "length of video and category lists must be equal"
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform
        self.keyframe_helper_data = (
            read_keyframe_helper_data(keyframe_helper_fpath)
            if keyframe_helper_fpath is not None
            else None
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected keyframes from a given video

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): float tensor of size [N, 3, H, W] with the
                    selected keyframes in BGR channel order, or of size defined by
                    the transform; of size [0, 3, 1, 1] (transform not applied) if
                    no keyframes or frames could be obtained
                categories (List[Optional[str]]): single-element list with the
                    category of the video, or an empty list if no keyframes or
                    frames could be obtained
        """
        categories = [self.category_list[idx]]
        fpath = self.video_list[idx]
        # precomputed keyframes take precedence over listing them from the file
        keyframes = (
            list_keyframes(fpath)
            if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
            else self.keyframe_helper_data[idx]
        )
        transform = self.transform
        frame_selector = self.frame_selector
        if not keyframes:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        if frame_selector is not None:
            keyframes = frame_selector(keyframes)
        frames = read_keyframes(fpath, keyframes)
        if not frames:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
        frames = torch.as_tensor(frames, device=torch.device("cpu"))
        frames = frames[..., [2, 1, 0]]  # RGB -> BGR
        frames = frames.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
        if transform is not None:
            frames = transform(frames)
        return {"images": frames, "categories": categories}

    def __len__(self):
        """Number of videos in the dataset (one item per entry of ``video_list``)."""
        return len(self.video_list)