File size: 1,516 Bytes
d16b52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
from typing import Optional

import torch


class SimilarImageFilter:
    """Stochastically drop frames that closely resemble the last emitted frame.

    Each incoming tensor is compared, via cosine similarity over its flattened
    elements, with the most recent tensor this filter returned. The closer the
    similarity is to 1, the more likely the frame is skipped (``None`` is
    returned). At most ``max_skip_frame`` frames are skipped in a row before
    one is forced through.
    """

    def __init__(self, threshold: float = 0.98, max_skip_frame: float = 10) -> None:
        """Create a filter.

        Args:
            threshold: Similarity at or below which a frame is never skipped.
                Between ``threshold`` and 1 the skip probability rises
                linearly from 0 to 1. A value ``>= 1`` disables skipping.
            max_skip_frame: Maximum number of consecutive frames that may be
                skipped before a frame is emitted regardless of similarity.
        """
        self.threshold = threshold
        # Last emitted frame; skipped frames never replace it.
        self.prev_tensor = None
        self.cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        self.max_skip_frame = max_skip_frame
        # Number of frames skipped since the last emitted frame.
        self.skip_count = 0

    def __call__(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return ``x`` if the frame should be processed, else ``None``.

        Args:
            x: The incoming frame.

        Returns:
            ``x`` itself when the frame is emitted, ``None`` when it is skipped.
        """
        prev = self.prev_tensor
        # First frame, or a frame whose element count changed: there is
        # nothing comparable to measure against, so emit it as a new reference.
        if prev is None or prev.numel() != x.numel():
            return self._emit(x)

        cos_sim = self.cos(prev.reshape(-1), x.reshape(-1)).item()
        if self.threshold >= 1:
            skip_prob = 0.0
        else:
            # Linear ramp: 0 at cos_sim == threshold, 1 at cos_sim == 1.
            skip_prob = max(0.0, 1 - (1 - cos_sim) / (1 - self.threshold))

        # Skip only when the sample is strictly below skip_prob, so a
        # probability of 0 can never skip, even if the sample is exactly 0.0.
        sample = random.uniform(0, 1)
        if sample >= skip_prob:
            return self._emit(x)

        # The frame is a skip candidate; force it through once the run of
        # consecutive skips has reached the configured maximum.
        if self.skip_count >= self.max_skip_frame:
            return self._emit(x)

        self.skip_count += 1
        return None

    def _emit(self, x: torch.Tensor) -> torch.Tensor:
        """Record ``x`` as the new reference frame, reset the skip run, return ``x``."""
        self.prev_tensor = x.detach().clone()
        self.skip_count = 0
        return x

    def set_threshold(self, threshold: float) -> None:
        """Replace the similarity threshold used by subsequent calls."""
        self.threshold = threshold

    def set_max_skip_frame(self, max_skip_frame: float) -> None:
        """Replace the maximum number of consecutive skipped frames."""
        self.max_skip_frame = max_skip_frame