xiaoyao9184 commited on
Commit
216efed
·
verified ·
1 Parent(s): e4a91cc

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (4) hide show
  1. app.py +38 -0
  2. gradio_app.py +385 -0
  3. gradio_run.py +7 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import git
4
+ import subprocess
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ REPO_URL = "https://github.com/facebookresearch/videoseal.git"
8
+ REPO_BRANCH = '5897ac50b5b0f5c806f42d2f7d1ef208a0780a28'
9
+ LOCAL_PATH = "./videoseal"
10
+
11
+ def install_src():
12
+ if not os.path.exists(LOCAL_PATH):
13
+ print(f"Cloning repository from {REPO_URL}...")
14
+ repo = git.Repo.clone_from(REPO_URL, LOCAL_PATH)
15
+ repo.git.checkout(REPO_BRANCH)
16
+ else:
17
+ print(f"Repository already exists at {LOCAL_PATH}")
18
+
19
+ requirements_path = os.path.join(LOCAL_PATH, "requirements.txt")
20
+ if os.path.exists(requirements_path):
21
+ print("Installing requirements...")
22
+ subprocess.check_call(["pip", "install", "-r", requirements_path])
23
+ else:
24
+ print("No requirements.txt found.")
25
+
26
+ # clone repo
27
+ install_src()
28
+
29
+ # change directory
30
+ print(f"Current Directory: {os.getcwd()}")
31
+ os.chdir(LOCAL_PATH)
32
+ print(f"New Directory: {os.getcwd()}")
33
+
34
+ # fix sys.path for import
35
+ sys.path.append(os.getcwd())
36
+
37
+ # run gradio
38
+ import gradio_app
gradio_app.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ if "APP_PATH" in os.environ:
5
+ app_path = os.path.abspath(os.environ["APP_PATH"])
6
+ if os.getcwd() != app_path:
7
+ # fix sys.path for import
8
+ os.chdir(app_path)
9
+ if app_path not in sys.path:
10
+ sys.path.append(app_path)
11
+
12
+ import gradio as gr
13
+
14
+ import torch
15
+ import torchaudio
16
+ import torchvision
17
+ import matplotlib.pyplot as plt
18
+ import re
19
+ import random
20
+ import string
21
+ from audioseal import AudioSeal
22
+ import videoseal
23
+ from videoseal.utils.display import save_video_audio_to_mp4
24
+
25
+ # Load video_model if not already loaded in reload mode
26
+ if 'video_model' not in globals():
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ # Load the VideoSeal model
30
+ video_model = videoseal.load("videoseal")
31
+ video_model.eval()
32
+ video_model.to(device)
33
+ video_model_nbytes = int(video_model.embedder.msg_processor.nbits / 8)
34
+
35
+ # Load the AudioSeal model
36
+ # Load audio_generator if not already loaded in reload mode
37
+ if 'audio_generator' not in globals():
38
+ audio_generator = AudioSeal.load_generator("audioseal_wm_16bits")
39
+ audio_generator = audio_generator.to(device)
40
+ audio_generator_nbytes = int(audio_generator.msg_processor.nbits / 8)
41
+
42
+ # Load audio_detector if not already loaded in reload mode
43
+ if 'audio_detector' not in globals():
44
+ audio_detector = AudioSeal.load_detector("audioseal_detector_16bits")
45
+ audio_detector = audio_detector.to(device)
46
+
47
+
48
+ def load_video(file):
49
+ # Read the video and convert to tensor format
50
+ video, audio, info = torchvision.io.read_video(file, output_format="TCHW", pts_unit="sec")
51
+ assert "audio_fps" in info, "The input video must contain an audio track. Simply refer to the main videoseal inference code if not."
52
+
53
+ # Normalize the video frames to the range [0, 1]
54
+ # audio = audio.float()
55
+ # video = video.float() / 255.0
56
+
57
+ # Normalize the video frames to the range [0, 1] and trim to 3 second
58
+ fps = 24
59
+ video = video[:fps * 3].float() / 255.0
60
+
61
+ sample_rate = info["audio_fps"]
62
+ audio = audio[:, :int(sample_rate * 3)].float()
63
+
64
+ return video, info["video_fps"], audio, info["audio_fps"]
65
+
66
+ def generate_msg_pt_by_format_string(format_string, bytes_count):
67
+ msg_hex = format_string.replace("-", "")
68
+ hex_length = bytes_count * 2
69
+ binary_list = []
70
+ for i in range(0, len(msg_hex), hex_length):
71
+ chunk = msg_hex[i:i+hex_length]
72
+ binary = bin(int(chunk, 16))[2:].zfill(bytes_count * 8)
73
+ binary_list.append([int(b) for b in binary])
74
+ # torch.randint(0, 2, (1, 16), dtype=torch.int32)
75
+ msg_pt = torch.tensor(binary_list, dtype=torch.int32)
76
+ return msg_pt.to(device)
77
+
78
+ def embed_watermark(output_file, msg_v, msg_a, video_only, video, fps, audio, sample_rate):
79
+ # Perform watermark embedding on video
80
+ with torch.no_grad():
81
+ outputs = video_model.embed(video, is_video=True, msgs=msg_v)
82
+
83
+ # Extract the results
84
+ video_w = outputs["imgs_w"] # Watermarked video frames
85
+ video_msgs = outputs["msgs"] # Watermark messages
86
+
87
+ if not video_only:
88
+ # Resample the audio to 16kHz for watermarking
89
+ audio_16k = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
90
+
91
+ # If the audio has more than one channel, average all channels to 1 channel
92
+ if audio_16k.shape[0] > 1:
93
+ audio_16k_mono = torch.mean(audio_16k, dim=0, keepdim=True)
94
+ else:
95
+ audio_16k_mono = audio_16k
96
+
97
+ # Add batch dimension to the audio tensor
98
+ audio_16k_mono_batched = audio_16k_mono.unsqueeze(0).to(device)
99
+
100
+ # Get the watermark for the audio
101
+ with torch.no_grad():
102
+ watermark = audio_generator.get_watermark(
103
+ audio_16k_mono_batched, 16000, message=msg_a
104
+ )
105
+
106
+ # Embed the watermark in the audio
107
+ audio_16k_w = audio_16k_mono_batched + watermark
108
+
109
+ # Remove batch dimension from the watermarked audio tensor
110
+ audio_16k_w = audio_16k_w.squeeze(0)
111
+
112
+ # If the original audio had more than one channel, duplicate the watermarked audio to all channels
113
+ if audio_16k.shape[0] > 1:
114
+ audio_16k_w = audio_16k_w.repeat(audio_16k.shape[0], 1)
115
+
116
+ # Resample the watermarked audio back to the original sample rate
117
+ audio_w = torchaudio.transforms.Resample(16000, sample_rate).to(device)(audio_16k_w)
118
+ else:
119
+ audio_w = audio
120
+
121
+ # for Incompatible pixel format 'rgb24' for codec 'libx264', auto-selecting format 'yuv444p'
122
+ video_w = video_w.flip(1)
123
+
124
+ # Save the watermarked video and audio
125
+ save_video_audio_to_mp4(
126
+ video_tensor=video_w,
127
+ audio_tensor=audio_w,
128
+ fps=int(fps),
129
+ audio_sample_rate=int(sample_rate),
130
+ output_filename=output_file,
131
+ )
132
+
133
+ print(f"encoded message: \n Audio: {msg_a} \n Video {video_msgs[0]}")
134
+
135
+ return video_w, audio_w
136
+
137
+ def generate_format_string_by_msg_pt(msg_pt, bytes_count):
138
+ hex_length = bytes_count * 2
139
+ binary_int = 0
140
+ for bit in msg_pt:
141
+ binary_int = (binary_int << 1) | int(bit.item())
142
+ hex_string = format(binary_int, f'0{hex_length}x')
143
+
144
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
145
+ format_hex = "-".join(split_hex)
146
+
147
+ return hex_string, format_hex
148
+
149
+ def detect_watermark(video_only, video, audio, sample_rate):
150
+ # Detect watermarks in the video
151
+ with torch.no_grad():
152
+ msg_extracted = video_model.extract_message(video)
153
+
154
+ print(f"Extracted message from video: {msg_extracted}")
155
+
156
+ if not video_only:
157
+ if len(audio.shape) == 2:
158
+ audio = audio.unsqueeze(0).to(device) # batchify
159
+
160
+ # if stereo convert to mono
161
+ if audio.shape[1] > 1:
162
+ audio = torch.mean(audio, dim=1, keepdim=True)
163
+
164
+ # Resample the audio to 16kHz for detectting
165
+ audio_16k = torchaudio.transforms.Resample(sample_rate, 16000).to(device)(audio)
166
+
167
+ # Detect watermarks in the audio
168
+ with torch.no_grad():
169
+ result, message = audio_detector.detect_watermark(audio_16k, 16000)
170
+
171
+ # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
172
+ # A watermarked audio should have pred_prob[:, 1, :] > 0.5
173
+ # message_prob is a tensor of size batch x 16, indicating of the probability of each bit to be 1.
174
+ # message will be a random tensor if the detector detects no watermarking from the audio
175
+ pred_prob, message_prob = audio_detector(audio_16k, sample_rate)
176
+
177
+ print(f"Detection result for audio: {result}")
178
+ print(f"Extracted message from audio: {message}")
179
+
180
+ return msg_extracted, (result, message, pred_prob, message_prob)
181
+ else:
182
+ return msg_extracted, None
183
+
184
+ def get_waveform_and_specgram(waveform, sample_rate):
185
+ # If the audio has more than one channel, average all channels to 1 channel
186
+ if waveform.shape[0] > 1:
187
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
188
+
189
+ waveform = waveform.squeeze().detach().cpu().numpy()
190
+
191
+ num_frames = waveform.shape[-1]
192
+ time_axis = torch.arange(0, num_frames) / sample_rate
193
+
194
+ figure, (ax1, ax2) = plt.subplots(2, 1)
195
+
196
+ ax1.plot(time_axis, waveform, linewidth=1)
197
+ ax1.grid(True)
198
+ ax2.specgram(waveform, Fs=sample_rate)
199
+
200
+ figure.suptitle(f"Waveform and specgram")
201
+
202
+ return figure
203
+
204
+ def generate_hex_format_regex(bytes_count):
205
+ hex_length = bytes_count * 2
206
+ hex_string = 'F' * hex_length
207
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
208
+ format_like = "-".join(split_hex)
209
+ regex_pattern = '^' + '-'.join([r'[0-9A-Fa-f]{4}'] * len(split_hex)) + '$'
210
+ return format_like, regex_pattern
211
+
212
+ def generate_hex_random_message(bytes_count):
213
+ hex_length = bytes_count * 2
214
+ hex_string = ''.join(random.choice(string.hexdigits) for _ in range(hex_length))
215
+ split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
216
+ random_str = "-".join(split_hex)
217
+ return random_str, "".join(split_hex)
218
+
219
+ with gr.Blocks(title="VideoSeal") as demo:
220
+ gr.Markdown("""
221
+ # VideoSeal Demo
222
+
223
+ The current video will be YUV444P encoded, truncated to 3 seconds for use, and multi-channel audio will be merged into a single channel for processing.
224
+
225
+ Find the project [here](https://github.com/facebookresearch/videoseal.git).
226
+ """)
227
+
228
+ with gr.Tabs():
229
+ with gr.TabItem("Embed Watermark"):
230
+ with gr.Row():
231
+ with gr.Column():
232
+ embedding_vid = gr.Video(label="Input Video")
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
237
+
238
+ format_like, regex_pattern = generate_hex_format_regex(video_model_nbytes)
239
+ msg, _ = generate_hex_random_message(video_model_nbytes)
240
+ embedding_msg = gr.Textbox(
241
+ label=f"Message ({video_model_nbytes} bytes hex string)",
242
+ info=f"format like {format_like}",
243
+ value=msg,
244
+ interactive=False, show_copy_button=True)
245
+ with gr.Column():
246
+ embedding_only_vid = gr.Checkbox(label="Only Video", value=False)
247
+
248
+ embedding_specgram = gr.Checkbox(label="Show specgram", value=False, info="Show debug information")
249
+
250
+ format_like_a, regex_pattern_a = generate_hex_format_regex(audio_generator_nbytes)
251
+ msg_a, _ = generate_hex_random_message(audio_generator_nbytes)
252
+ embedding_msg_a = gr.Textbox(
253
+ label=f"Audio Message ({audio_generator_nbytes} bytes hex string)",
254
+ info=f"format like {format_like_a}",
255
+ value=msg_a,
256
+ interactive=False, show_copy_button=True)
257
+
258
+ embedding_btn = gr.Button("Embed Watermark")
259
+ with gr.Column():
260
+ marked_vid = gr.Video(label="Output Audio", show_download_button=True)
261
+ specgram_original = gr.Plot(label="Original Audio", format="png", visible=False)
262
+ specgram_watermarked = gr.Plot(label="Watermarked Audio", format="png", visible=False)
263
+
264
+ def change_embedding_type(video_only):
265
+ return [gr.update(visible=not video_only, value=False),gr.update(visible=not video_only)]
266
+ embedding_only_vid.change(
267
+ fn=change_embedding_type,
268
+ inputs=[embedding_only_vid],
269
+ outputs=[embedding_specgram, embedding_msg_a]
270
+ )
271
+
272
+ def change_embedding_type(type):
273
+ if type == "random":
274
+ msg, _ = generate_hex_random_message(video_model_nbytes)
275
+ msg_a,_ = generate_hex_random_message(audio_generator_nbytes)
276
+ return [gr.update(interactive=False, value=msg),gr.update(interactive=False, value=msg_a)]
277
+ else:
278
+ return [gr.update(interactive=True),gr.update(interactive=True)]
279
+ embedding_type.change(
280
+ fn=change_embedding_type,
281
+ inputs=[embedding_type],
282
+ outputs=[embedding_msg, embedding_msg_a]
283
+ )
284
+
285
+ def check_embedding_msg(msg, msg_a):
286
+ if not re.match(regex_pattern, msg):
287
+ gr.Warning(
288
+ f"Invalid format. Please use like '{format_like}'",
289
+ duration=0)
290
+ if not re.match(regex_pattern_a, msg_a):
291
+ gr.Warning(
292
+ f"Invalid format. Please use like '{format_like_a}'",
293
+ duration=0)
294
+ embedding_msg.change(
295
+ fn=check_embedding_msg,
296
+ inputs=[embedding_msg, embedding_msg_a],
297
+ outputs=[]
298
+ )
299
+
300
+ def run_embed_watermark(file, video_only, show_specgram, msg, msg_a):
301
+ if file is None:
302
+ raise gr.Error("No file uploaded", duration=5)
303
+ if not re.match(regex_pattern, msg):
304
+ raise gr.Error(f"Invalid format. Please use like '{format_like}'", duration=5)
305
+ if not re.match(regex_pattern_a, msg_a):
306
+ raise gr.Error(f"Invalid format. Please use like '{format_like_a}'", duration=5)
307
+
308
+ msg_pt = generate_msg_pt_by_format_string(msg, video_model_nbytes)
309
+ msg_pt_a = generate_msg_pt_by_format_string(msg_a, audio_generator_nbytes)
310
+ video, fps, audio, rate = load_video(file)
311
+
312
+ output_path = file + '.marked.mp4'
313
+ _, audio_w = embed_watermark(output_path, msg_pt, msg_pt_a, video_only, video, fps, audio, rate)
314
+
315
+ if show_specgram:
316
+ fig_original = get_waveform_and_specgram(audio, rate)
317
+ fig_watermarked = get_waveform_and_specgram(audio_w, rate)
318
+ return [
319
+ output_path,
320
+ gr.update(visible=True, value=fig_original),
321
+ gr.update(visible=True, value=fig_watermarked)]
322
+ else:
323
+ return [
324
+ output_path,
325
+ gr.update(visible=False),
326
+ gr.update(visible=False)]
327
+ embedding_btn.click(
328
+ fn=run_embed_watermark,
329
+ inputs=[embedding_vid, embedding_only_vid, embedding_specgram, embedding_msg, embedding_msg_a],
330
+ outputs=[marked_vid, specgram_original, specgram_watermarked]
331
+ )
332
+
333
+ with gr.TabItem("Detect Watermark"):
334
+ with gr.Row():
335
+ with gr.Column():
336
+ detecting_vid = gr.Video(label="Input Video")
337
+ detecting_only_vid = gr.Checkbox(label="Only Video", value=False)
338
+ detecting_btn = gr.Button("Detect Watermark")
339
+ with gr.Column():
340
+ predicted_messages = gr.JSON(label="Detected Messages")
341
+
342
+ def run_detect_watermark(file, video_only):
343
+ if file is None:
344
+ raise gr.Error("No file uploaded", duration=5)
345
+
346
+ video, _, audio, rate = load_video(file)
347
+
348
+ if video_only:
349
+ msg_extracted, _ = detect_watermark(video_only, video, audio, rate)
350
+
351
+ audio_json = None
352
+ else:
353
+ msg_extracted, (result, message, pred_prob, message_prob) = detect_watermark(video_only, video, audio, rate)
354
+
355
+ _, fromat_msg = generate_format_string_by_msg_pt(message[0], audio_generator_nbytes)
356
+
357
+ sum_above_05 = (pred_prob[:, 1, :] > 0.5).sum(dim=1)
358
+
359
+ audio_json = {
360
+ "socre": result,
361
+ "message": fromat_msg,
362
+ "frames_count_all": pred_prob.shape[2],
363
+ "frames_count_above_05": sum_above_05[0].item(),
364
+ "bits_probability": message_prob[0].tolist(),
365
+ "bits_massage": message[0].tolist()
366
+ }
367
+
368
+ _, fromat_msg = generate_format_string_by_msg_pt(msg_extracted[0], video_model_nbytes)
369
+
370
+ # Create message output as JSON
371
+ message_json = {
372
+ "video": {
373
+ "message": fromat_msg,
374
+ },
375
+ "audio:": audio_json
376
+ }
377
+ return message_json
378
+ detecting_btn.click(
379
+ fn=run_detect_watermark,
380
+ inputs=[detecting_vid, detecting_only_vid],
381
+ outputs=[predicted_messages]
382
+ )
383
+
384
+ if __name__ == "__main__":
385
+ demo.launch()
gradio_run.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # NOTE: copy from gradio bin
2
+ import re
3
+ import sys
4
+ from gradio.cli import cli
5
+ if __name__ == '__main__':
6
+ sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
7
+ sys.exit(cli())
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ gradio==5.8.0
3
+ huggingface-hub==0.26.3
4
+ audioseal==0.1.4
5
+ matplotlib==3.10.0
6
+ soundfile==0.12.1
7
+ torchaudio==2.5.1