anahita-b committed on
Commit b305022
1 Parent(s): aa88645

Create app.py

Files changed (1)
  app.py +90 -0
app.py ADDED
import cv2
import gradio as gr
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

model_id = "BridgeTower/bridgetower-large-itm-mlm-gaudi"
processor = BridgeTowerProcessor.from_pretrained(model_id)
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)


# Score a single frame against one or more comma-separated text queries
def process_frame(image, texts):
    scores = {}
    texts = texts.split(",")
    for t in texts:
        encoding = processor(image, t, return_tensors="pt")
        outputs = model(**encoding)
        # logits[0, 1] is the image-text match score
        scores[t] = "{:.2f}".format(outputs.logits[0, 1].item())
    # Sort scores in descending order (numerically, since the values are strings)
    scores = dict(sorted(scores.items(), key=lambda item: float(item[1]), reverse=True))
    return scores

# Process a video: sample one frame every `sample_rate` seconds and
# group consecutive matching frames into clips
def process(video, text, sample_rate, min_score):
    video = cv2.VideoCapture(video)
    fps = round(video.get(cv2.CAP_PROP_FPS))
    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    length = frames // fps
    print(f"{fps} fps, {frames} frames, {length} seconds")

    frame_count = 0
    clips = []
    clip_images = []
    clip_started = False
    while True:
        ret, frame = video.read()
        if not ret:
            break

        if frame_count % (fps * sample_rate) == 0:
            # OpenCV decodes frames as BGR; convert to RGB for PIL and the model
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            score = process_frame(frame, text)
            # print(f"{frame_count} {score}")

            if float(score[text]) > min_score:
                if clip_started:
                    end_time = frame_count / fps
                else:
                    clip_started = True
                    start_time = frame_count / fps
                    end_time = start_time
                    start_score = score[text]
                clip_images.append(frame)
            elif clip_started:
                # The clip just ended: record it
                clip_started = False
                end_time = frame_count / fps
                clips.append((start_score, start_time, end_time))
        frame_count += 1
    # Close a clip still open when the video ends
    if clip_started:
        clips.append((start_score, start_time, end_time))
    video.release()
    return clip_images, clips

# Inputs
video = gr.Video(label="Video")
text = gr.Text(label="Text query")
sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
min_score = gr.Number(value=3, label="Minimum score")

# Outputs
gallery = gr.Gallery(label="Images")
clips = gr.Text(label="Clips (score, start time, end time)")

description = "This Space lets you run semantic search on a video."

iface = gr.Interface(
    description=description,
    fn=process,
    inputs=[video, text, sample_rate, min_score],
    outputs=[gallery, clips],
    examples=[
        [
            "video.mp4",
            "wild bears",
            5,
            3,
        ]
    ],
    allow_flagging="never",
)

iface.launch()
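
For a quick check outside the Gradio UI, process_frame can be called directly on a single image. A minimal sketch, assuming a local test image frame.jpg (a hypothetical filename, not part of this commit):

from PIL import Image

# frame.jpg is a hypothetical local test image
image = Image.open("frame.jpg").convert("RGB")
print(process_frame(image, "wild bears,city skyline"))
# Prints a dict mapping each query to its match score, highest first

Because process_frame splits its second argument on commas, several queries can be scored against the same frame in one call, which is what the sorted dict it returns is for.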