merve HF staff committed on
Commit
d20f4f3
·
verified ·
1 Parent(s): 67513d2

Create app.py

Files changed (1)
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ import uuid
+
+ import gradio as gr
+ import numpy as np
+ import supervision as sv
+ import torch
+ from tqdm import tqdm
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
+
+ # Load the OWLv2 processor and model once at startup; fall back to CPU so the
+ # demo still runs when no GPU is available.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+ model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
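+
+ # OWLv2 is an open-vocabulary detector: there is no fixed class list, and the
+ # comma-separated labels typed into the UI below are the only classes it will
+ # be asked to find.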
+
+ BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
+ MASK_ANNOTATOR = sv.MaskAnnotator()
+ LABEL_ANNOTATOR = sv.LabelAnnotator()
+
+
+ def calculate_end_frame_index(source_video_path):
+     video_info = sv.VideoInfo.from_video_path(source_video_path)
+     return min(
+         video_info.total_frames,
+         video_info.fps * 2
+     )
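+
+ # Note: this caps the demo at the first two seconds of footage, e.g. a 30 fps
+ # clip is limited to min(total_frames, 60) frames.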
+
+
+ def annotate_image(
+     input_image,
+     detections,
+     labels
+ ) -> np.ndarray:
+     output_image = MASK_ANNOTATOR.annotate(input_image, detections)
+     output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
+     output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
+     return output_image
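+
+ # Layering note: masks are drawn first, then boxes, then label text, so the
+ # text stays readable on top of the heavier annotations.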
+
+
+ def process_video(
+     input_video,
+     labels,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     labels = [label.strip() for label in labels.split(",")]
+     video_info = sv.VideoInfo.from_video_path(input_video)
+     total = calculate_end_frame_index(input_video)
+     frame_generator = sv.get_video_frames_generator(
+         source_path=input_video,
+         end=total
+     )
+
+     os.makedirs("./outputs", exist_ok=True)
+     result_file_name = f"{uuid.uuid4()}.mp4"
+     result_file_path = os.path.join("./outputs", result_file_name)
+     with sv.VideoSink(result_file_path, video_info=video_info) as sink:
+         for _ in tqdm(range(total), desc="Processing video..."):
+             frame = next(frame_generator)
+             # one dict per image: {"scores": ..., "labels": ..., "boxes": ...}
+             results = query(frame, labels)
+
+             detections = sv.Detections.from_transformers(results[0])
+             # Map predicted label indices back to the user-provided strings.
+             final_labels = []
+             for label_id in results[0]["labels"]:
+                 final_labels.append(labels[label_id])
+             frame = annotate_image(
+                 input_image=frame,
+                 detections=detections,
+                 labels=final_labels,
+             )
+             sink.write_frame(frame)
+     return result_file_path
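+
+ # Usage sketch (with the example clip shipped alongside this file):
+ # process_video("./cats.mp4", "cat") writes an annotated copy to
+ # ./outputs/<uuid>.mp4 and returns that path for Gradio to play back.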
+
+
+ def query(image, texts):
+     inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     # Frames arrive as (H, W, C) numpy arrays, so shape[:-1] is (height, width).
+     target_sizes = torch.Tensor([image.shape[:-1]])
+
+     results = processor.post_process_object_detection(outputs=outputs, threshold=0.3, target_sizes=target_sizes)
+     return results
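+
+ # Expected output for a single frame: results[0] is a dict with "scores"
+ # (Tensor[N]), "labels" (Tensor[N], indices into `texts`) and "boxes"
+ # (Tensor[N, 4], xyxy pixel coordinates), the layout that
+ # sv.Detections.from_transformers() consumes in process_video above.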
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Zero-shot Object Tracking with OWLv2 🦉")
+     gr.Markdown("This is a demo for zero-shot object tracking using the [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) model by Google.")
+     gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. 👇")
+     with gr.Tab(label="Video"):
+         with gr.Row():
+             input_video = gr.Video(
+                 label='Input Video'
+             )
+             output_video = gr.Video(
+                 label='Output Video'
+             )
+         with gr.Row():
+             candidate_labels = gr.Textbox(
+                 label='Labels',
+                 placeholder='Labels separated by a comma',
+             )
+             submit = gr.Button()
+         gr.Examples(
+             fn=process_video,
+             examples=[["./cats.mp4", "dog,cat"]],
+             inputs=[
+                 input_video,
+                 candidate_labels,
+             ],
+             outputs=output_video
+         )
+
+     submit.click(
+         fn=process_video,
+         inputs=[input_video, candidate_labels],
+         outputs=output_video
+     )
+
+ demo.launch(debug=False, show_error=True)