xiaoyao9184 committed (verified)
Commit f07b84d · 1 Parent(s): a968ae9

Synced repo using 'sync_with_huggingface' Github Action

Files changed (1):
  1. gradio_app.py (+535, -198)
gradio_app.py CHANGED
@@ -16,8 +16,13 @@ import torchaudio
 import torchvision
 import matplotlib.pyplot as plt
 import re
+import math
 import random
 import string
+import ffmpeg
+import subprocess
+import numpy as np
+import tqdm
 from audioseal import AudioSeal
 import videoseal
 from videoseal.utils.display import save_video_audio_to_mp4
@@ -44,25 +49,6 @@ if 'audio_detector' not in globals():
 audio_detector = AudioSeal.load_detector("audioseal_detector_16bits")
 audio_detector = audio_detector.to(device)
 
-
-def load_video(file):
-    # Read the video and convert to tensor format
-    video, audio, info = torchvision.io.read_video(file, output_format="TCHW", pts_unit="sec")
-    assert "audio_fps" in info, "The input video must contain an audio track. Simply refer to the main videoseal inference code if not."
-
-    # Normalize the video frames to the range [0, 1]
-    # audio = audio.float()
-    # video = video.float() / 255.0
-
-    # Normalize the video frames to the range [0, 1] and trim to 3 seconds
-    fps = 24
-    video = video[:fps * 3].float() / 255.0
-
-    sample_rate = info["audio_fps"]
-    audio = audio[:, :int(sample_rate * 3)].float()
-
-    return video, info["video_fps"], audio, info["audio_fps"]
-
 def generate_msg_pt_by_format_string(format_string, bytes_count):
     msg_hex = format_string.replace("-", "")
     hex_length = bytes_count * 2
@@ -75,66 +61,8 @@ def generate_msg_pt_by_format_string(format_string, bytes_count):
     msg_pt = torch.tensor(binary_list, dtype=torch.int32)
     return msg_pt.to(device)
 
-def embed_watermark(output_file, msg_v, msg_a, video_only, video, fps, audio, sample_rate):
-    # Perform watermark embedding on video
-    with torch.no_grad():
-        outputs = video_model.embed(video, is_video=True, msgs=msg_v)
-
-    # Extract the results
-    video_w = outputs["imgs_w"]  # Watermarked video frames
-    video_msgs = outputs["msgs"]  # Watermark messages
-
-    if not video_only:
-        # Resample the audio to 16kHz for watermarking
-        audio_16k = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
-
-        # If the audio has more than one channel, average all channels to 1 channel
-        if audio_16k.shape[0] > 1:
-            audio_16k_mono = torch.mean(audio_16k, dim=0, keepdim=True)
-        else:
-            audio_16k_mono = audio_16k
-
-        # Add batch dimension to the audio tensor
-        audio_16k_mono_batched = audio_16k_mono.unsqueeze(0).to(device)
-
-        # Get the watermark for the audio
-        with torch.no_grad():
-            watermark = audio_generator.get_watermark(
-                audio_16k_mono_batched, 16000, message=msg_a
-            )
-
-        # Embed the watermark in the audio
-        audio_16k_w = audio_16k_mono_batched + watermark
-
-        # Remove batch dimension from the watermarked audio tensor
-        audio_16k_w = audio_16k_w.squeeze(0)
-
-        # If the original audio had more than one channel, duplicate the watermarked audio to all channels
-        if audio_16k.shape[0] > 1:
-            audio_16k_w = audio_16k_w.repeat(audio_16k.shape[0], 1)
-
-        # Resample the watermarked audio back to the original sample rate
-        audio_w = torchaudio.transforms.Resample(16000, sample_rate).to(device)(audio_16k_w)
-    else:
-        audio_w = audio
-
-    # for Incompatible pixel format 'rgb24' for codec 'libx264', auto-selecting format 'yuv444p'
-    video_w = video_w.flip(1)
-
-    # Save the watermarked video and audio
-    save_video_audio_to_mp4(
-        video_tensor=video_w,
-        audio_tensor=audio_w,
-        fps=int(fps),
-        audio_sample_rate=int(sample_rate),
-        output_filename=output_file,
-    )
-
-    print(f"encoded message: \n Audio: {msg_a} \n Video {video_msgs[0]}")
-
-    return video_w, audio_w
-
 def generate_format_string_by_msg_pt(msg_pt, bytes_count):
+    if msg_pt is None: return '', None
     hex_length = bytes_count * 2
     binary_int = 0
     for bit in msg_pt:
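Reviewer note on the two message helpers kept above: generate_msg_pt_by_format_string unpacks a hyphen-grouped hex string into a per-bit torch tensor, and generate_format_string_by_msg_pt packs such a tensor back into grouped hex. Parts of their bodies fall outside this diff, so the sketch below is a hedged reconstruction of the round trip rather than the exact implementation; the zero-padding and upper-casing are assumptions:

    import torch

    bytes_count = 16  # assumed message size; the app uses video_model_nbytes / audio_generator_nbytes
    msg = "0123-4567-89AB-CDEF-0123-4567-89AB-CDEF"

    # hex string -> bit tensor (one 0/1 entry per watermark bit)
    bits = [int(b) for b in bin(int(msg.replace("-", ""), 16))[2:].zfill(bytes_count * 8)]
    msg_pt = torch.tensor(bits, dtype=torch.int32)

    # bit tensor -> grouped hex string
    value = 0
    for bit in msg_pt:
        value = (value << 1) | int(bit)
    hex_string = format(value, 'X').zfill(bytes_count * 2)
    format_hex = "-".join(hex_string[i:i + 4] for i in range(0, len(hex_string), 4))
    assert format_hex == msg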
@@ -143,84 +71,491 @@ def generate_format_string_by_msg_pt(msg_pt, bytes_count):
 
     split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
     format_hex = "-".join(split_hex)
-
     return hex_string, format_hex
 
-def detect_watermark(video_only, video, audio, sample_rate):
-    # Detect watermarks in the video
-    with torch.no_grad():
-        msg_extracted = video_model.extract_message(video)
-
-    print(f"Extracted message from video: {msg_extracted}")
-
-    if not video_only:
-        if len(audio.shape) == 2:
-            audio = audio.unsqueeze(0).to(device)  # batchify
-
-        # if stereo convert to mono
-        if audio.shape[1] > 1:
-            audio = torch.mean(audio, dim=1, keepdim=True)
-
-        # Resample the audio to 16kHz for detecting
-        audio_16k = torchaudio.transforms.Resample(sample_rate, 16000).to(device)(audio)
-
-        # Detect watermarks in the audio
-        with torch.no_grad():
-            result, message = audio_detector.detect_watermark(audio_16k, 16000)
-
-        # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
-        # A watermarked audio should have pred_prob[:, 1, :] > 0.5
-        # message_prob is a tensor of size batch x 16, indicating the probability of each bit to be 1.
-        # message will be a random tensor if the detector detects no watermarking from the audio
-        pred_prob, message_prob = audio_detector(audio_16k, sample_rate)
-
-        print(f"Detection result for audio: {result}")
-        print(f"Extracted message from audio: {message}")
-
-        return msg_extracted, (result, message, pred_prob, message_prob)
-    else:
-        return msg_extracted, None
-
-def get_waveform_and_specgram(waveform, sample_rate):
-    # If the audio has more than one channel, average all channels to 1 channel
-    if waveform.shape[0] > 1:
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
-
-    waveform = waveform.squeeze().detach().cpu().numpy()
-
-    num_frames = waveform.shape[-1]
-    time_axis = torch.arange(0, num_frames) / sample_rate
-
-    figure, (ax1, ax2) = plt.subplots(2, 1)
-
-    ax1.plot(time_axis, waveform, linewidth=1)
-    ax1.grid(True)
-    ax2.specgram(waveform, Fs=sample_rate)
-
-    figure.suptitle(f"Waveform and specgram")
-
-    return figure
-
-def generate_hex_format_regex(bytes_count):
-    hex_length = bytes_count * 2
-    hex_string = 'F' * hex_length
-    split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
-    format_like = "-".join(split_hex)
-    regex_pattern = '^' + '-'.join([r'[0-9A-Fa-f]{4}'] * len(split_hex)) + '$'
-    return format_like, regex_pattern
-
-def generate_hex_random_message(bytes_count):
-    hex_length = bytes_count * 2
-    hex_string = ''.join(random.choice(string.hexdigits) for _ in range(hex_length))
-    split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
-    random_str = "-".join(split_hex)
-    return random_str, "".join(split_hex)
+def generate_hex_format_regex(bytes_count):
+    hex_length = bytes_count * 2
+    hex_string = 'F' * hex_length
+    split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
+    format_like = "-".join(split_hex)
+    regex_pattern = '^' + '-'.join([r'[0-9A-Fa-f]{4}'] * len(split_hex)) + '$'
+    return format_like, regex_pattern
+
+def generate_hex_random_message(bytes_count):
+    hex_length = bytes_count * 2
+    hex_string = ''.join(random.choice(string.hexdigits) for _ in range(hex_length))
+    split_hex = [hex_string[i:i + 4] for i in range(0, len(hex_string), 4)]
+    random_str = "-".join(split_hex)
+    return random_str, "".join(split_hex)
+
+def embed_video_clip(
+    model,
+    clip: np.ndarray,
+    msgs: torch.Tensor
+) -> np.ndarray:
+    clip_tensor = torch.tensor(clip, dtype=torch.float32).to(device).permute(0, 3, 1, 2) / 255.0
+    outputs = model.embed(clip_tensor, msgs=msgs, is_video=True)
+    processed_clip = outputs["imgs_w"]
+    processed_clip = (processed_clip * 255.0).byte().permute(0, 2, 3, 1).cpu().numpy()
+    return processed_clip
+
+def embed_video(
+    model,
+    input_path: str,
+    output_path: str,
+    msgs: torch.Tensor,
+    chunk_size: int,
+    crf: int = 23
+) -> None:
+    # Read video dimensions
+    probe = ffmpeg.probe(input_path)
+    video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
+    width = int(video_info['width'])
+    height = int(video_info['height'])
+    fps = float(video_info['r_frame_rate'].split('/')[0]) / float(video_info['r_frame_rate'].split('/')[1])
+    num_frames = int(video_info['nb_frames'])
+
+    # Open the input video
+    process1 = (
+        ffmpeg
+        .input(input_path)
+        .output('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height), r=fps)
+        .run_async(pipe_stdout=True, pipe_stderr=subprocess.PIPE)
+    )
+    # Open the output video
+    process2 = (
+        ffmpeg
+        .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height), r=fps)
+        .output(output_path, vcodec='libx264', pix_fmt='yuv420p', r=fps, crf=crf)
+        .overwrite_output()
+        .run_async(pipe_stdin=True, pipe_stderr=subprocess.PIPE)
+    )
+
+    # Process the video
+    frame_size = width * height * 3
+    chunk = np.zeros((chunk_size, height, width, 3), dtype=np.uint8)
+    frame_count = 0
+    pbar = tqdm.tqdm(total=num_frames, unit='frame', desc="Watermark video embedding")
+    while True:
+        # TODO: read may block instead of hitting EOF on Windows
+        in_bytes = process1.stdout.read(frame_size)
+        if not in_bytes:
+            break
+        frame = np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3])
+        chunk[frame_count % chunk_size] = frame
+        frame_count += 1
+        pbar.update(1)
+        if frame_count % chunk_size == 0:
+            processed_frame = embed_video_clip(model, chunk, msgs)
+            process2.stdin.write(processed_frame.tobytes())
+
+    process1.stdout.close()
+    process2.stdin.close()
+    process1.wait()
+    process2.wait()
+
+    return
+
+def get_sample_size(sample_fmt):
+    if sample_fmt == 's16':
+        return 2, np.int16
+    elif sample_fmt == 's16p':
+        return 2, np.int16
+    elif sample_fmt == 'flt':
+        return 4, np.float32
+    elif sample_fmt == 'fltp':
+        return 4, np.float32
+    elif sample_fmt == 's32':
+        return 4, np.int32
+    elif sample_fmt == 's32p':
+        return 4, np.int32
+    elif sample_fmt == 'u8':
+        return 1, np.uint8
+    else:
+        raise ValueError(f"Unsupported sample_fmt: {sample_fmt}")
+
+def embed_audio_clip(
+    model,
+    clip: np.ndarray,
+    msgs: torch.Tensor,
+    sample_rate
+) -> np.ndarray:
+    clip_tensor = torch.tensor(clip, dtype=torch.float32).to(device)
+
+    # Resample the audio to 16kHz for watermarking
+    audio_16k = torchaudio.transforms.Resample(sample_rate, 16000).to(device)(clip_tensor)
+
+    # If the audio has more than one channel, average all channels to 1 channel
+    if audio_16k.shape[0] > 1:
+        audio_16k_mono = torch.mean(audio_16k, dim=0, keepdim=True)
+    else:
+        audio_16k_mono = audio_16k
+
+    # Add batch dimension to the audio tensor
+    audio_16k_mono_batched = audio_16k_mono.unsqueeze(0)
+
+    # Get the watermark for the audio
+    with torch.no_grad():
+        watermark = model.get_watermark(
+            audio_16k_mono_batched, 16000, message=msgs
+        )
+
+    # Embed the watermark in the audio
+    audio_16k_w = audio_16k_mono_batched + watermark
+
+    # Remove batch dimension from the watermarked audio tensor
+    audio_16k_w = audio_16k_w.squeeze(0)
+
+    # If the original audio had more than one channel, duplicate the watermarked audio to all channels
+    if audio_16k.shape[0] > 1:
+        audio_16k_w = audio_16k_w.repeat(audio_16k.shape[0], 1)
+
+    # Resample the watermarked audio back to the original sample rate
+    audio_w = torchaudio.transforms.Resample(16000, sample_rate).to(device)(audio_16k_w)
+
+    processed_clip = audio_w.cpu().numpy()
+    return processed_clip
+
+def embed_audio(
+    model,
+    input_path: str,
+    output_path: str,
+    msgs: torch.Tensor,
+    chunk_size: int
+) -> None:
+    # Read audio dimensions
+    probe = ffmpeg.probe(input_path)
+    audio_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'audio')
+    sample_rate = int(audio_info['sample_rate'])
+    sample_fmt = audio_info['sample_fmt']
+    channels = int(audio_info['channels'])
+    duration = float(audio_info['duration'])
+
+    # CASE 1 Read audio all at once
+
+    # audio_data, stderr_output = (
+    #     ffmpeg
+    #     .input(input_path, loglevel='debug')
+    #     .output('pipe:', format='f32le', acodec='pcm_f32le', ar=sample_rate, ac=channels)
+    #     .run(capture_stdout=True, capture_stderr=True)
+    # )
+    # audio_data = process.stdout.read()
+    # print("audio numpy total size:", len(audio_data))
+    # process.stdout.close()
+    # process.wait()
+    # stderr_output = process.stderr.read().decode('utf-8')
+    # print(stderr_output)
+
+    # CASE 2 Read async
+    # NOTE loglevel='debug' does not work on Windows
+    # NOTE format='wav' data size (4104768) is bigger than format='s16le' (4104688)
+
+    # process = (
+    #     ffmpeg
+    #     .input(input_path, loglevel='debug')
+    #     .output('pipe:', format='f32le', acodec='pcm_f32le', ar=sample_rate, ac=channels)
+    #     .run_async(pipe_stdout=True, pipe_stderr=subprocess.PIPE)
+    # )
+    # audio_data = process.stdout.read()
+    # print("audio numpy total size:", len(audio_data))
+    # process.stdout.close()
+    # process.wait()
+    # stderr_output = process.stderr.read().decode('utf-8')
+    # print(stderr_output)
+
+    # stderr_output example:
+    #
+    # # AVIOContext @ 0x5d878ea02e80] Statistics: 4104688 bytes written, 0 seeks, 251 writeouts
+    # # [out#0/f32le @ 0x5d878eaf31c0] Output file #0 (pipe:):
+    # # [out#0/f32le @ 0x5d878eaf31c0] Output stream #0:0 (audio): 251 frames encoded (513086 samples); 251 packets muxed (4104688 bytes);
+    # # [out#0/f32le @ 0x5d878eaf31c0] Total: 251 packets (4104688 bytes) muxed
+
+    # CASE 3 Read by torchaudio
+    # NOTE torchvision reads audio as f32le
+
+    # _, audio, info = torchvision.io.read_video(input_path, output_format="TCHW")
+    # print("audio numpy total size:", audio.nbytes)
+
+
+    # Open the input audio
+    process1 = (
+        ffmpeg
+        .input(input_path)
+        .output('pipe:', format='f32le', acodec='pcm_f32le', ac=channels, ar=sample_rate)
+        .run_async(pipe_stdout=True, pipe_stderr=subprocess.PIPE)
+    )
+    # Open the output audio
+    process2 = (
+        ffmpeg
+        .input('pipe:', format='f32le', ac=channels, ar=sample_rate)
+        .output(output_path, format='wav', acodec='pcm_f32le', ac=channels, ar=sample_rate)
+        # does not work
+        # .output(output_path, acodec='libfdk_aac', ac=channels, ar=sample_rate)
+        .overwrite_output()
+        .run_async(pipe_stdin=True, pipe_stderr=subprocess.PIPE)
+    )
+
+    # CASE read all and write all
+
+    # while True:
+    #     audio_data = process1.stdout.read()
+    #     if not audio_data:
+    #         break
+    #     try:
+    #         process2.stdin.write(audio_data)
+    #     except BrokenPipeError:
+    #         print("Broken pipe: process2 has closed the input.")
+    #         break
+
+    # Process the audio
+    sample_size, sample_type = get_sample_size(sample_fmt)
+    second_size = sample_size * channels * sample_rate
+    chunk = np.zeros((chunk_size, sample_rate, channels), dtype=sample_type)
+    second_count = 0
+    pbar = tqdm.tqdm(total=math.ceil(duration), unit='second', desc="Watermark audio embedding")
+    while True:
+        in_bytes = process1.stdout.read(second_size)
+        if not in_bytes:
+            break
+        frame = np.frombuffer(in_bytes, sample_type)
+        frame = frame.reshape((-1, channels))
+        chunk[second_count % chunk_size, :len(frame)] = frame
+        second_count += 1
+        pbar.update(1)
+        if second_count % chunk_size == 0:
+            if msgs is None:
+                process2.stdin.write(in_bytes)
+            else:
+                clip = np.concatenate(chunk, axis=0).T
+                processed_frame = embed_audio_clip(model, clip, msgs, sample_rate)
+                process2.stdin.write(processed_frame.T.tobytes())
+
+    process1.stdout.close()
+    process2.stdin.close()
+    process1.wait()
+    process2.wait()
+
+    # CASE print stderr
+
+    # stderr_output1 = process1.stderr.read().decode('utf-8')
+    # stderr_output2 = process2.stderr.read().decode('utf-8')
+    # print("Process 1 stderr:")
+    # print(stderr_output1)
+    # print("Process 2 stderr:")
+    # print(stderr_output2)
+    return
+
+def embed_watermark(input_path, output_path, msg_v, msg_a, video_only, progress):
+    output_path_video = output_path + ".video.mp4"
+    embed_video(video_model, input_path, output_path_video, msg_v, 16)
+
+    output_path_audio = output_path + ".audio.m4a"
+    if video_only:
+        msg_a = None
+    embed_audio(audio_generator, input_path, output_path_audio, msg_a, 3)
+
+    # Use FFmpeg to add audio to the video
+    final_command = [
+        'ffmpeg',
+        '-i', output_path_video,
+        '-i', output_path_audio,
+        '-c:v', 'copy',
+        '-c:a', 'aac',
+        '-strict', 'experimental',
+        '-y', output_path
+    ]
+    subprocess.run(final_command, check=True)
+    return
+
+def detect_video_clip(
+    model,
+    clip: np.ndarray
+) -> torch.Tensor:
+    clip_tensor = torch.tensor(clip, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0
+    outputs = model.detect(clip_tensor, is_video=True)
+    output_bits = outputs["preds"][:, 1:]  # exclude the first bit, which may be used for detection
+    return output_bits
+
+def detect_video(
+    model,
+    input_path: str,
+    chunk_size: int
+) -> torch.Tensor:
+    # Read video dimensions
+    probe = ffmpeg.probe(input_path)
+    video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
+    width = int(video_info['width'])
+    height = int(video_info['height'])
+    fps = float(video_info['r_frame_rate'].split('/')[0]) / float(video_info['r_frame_rate'].split('/')[1])
+    num_frames = int(video_info['nb_frames'])
+
+    # Open the input video
+    process1 = (
+        ffmpeg
+        .input(input_path)
+        .output('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width, height), r=fps)
+        .run_async(pipe_stdout=True, pipe_stderr=subprocess.PIPE)
+    )
+
+    # Process the video
+    frame_size = width * height * 3
+    chunk = np.zeros((chunk_size, height, width, 3), dtype=np.uint8)
+    frame_count = 0
+    soft_msgs = []
+    pbar = tqdm.tqdm(total=num_frames, unit='frame', desc="Watermark video detecting")
+    while True:
+        in_bytes = process1.stdout.read(frame_size)
+        if not in_bytes:
+            break
+        frame = np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3])
+        chunk[frame_count % chunk_size] = frame
+        frame_count += 1
+        pbar.update(1)
+        if frame_count % chunk_size == 0:
+            soft_msgs.append(detect_video_clip(model, chunk))
+
+    process1.stdout.close()
+    process1.wait()
+
+    soft_msgs = torch.cat(soft_msgs, dim=0)
+    return soft_msgs
+
+def detect_audio_clip(
+    model,
+    clip: np.ndarray,
+    sample_rate
+) -> tuple:
+    clip_tensor = torch.tensor(clip, dtype=torch.float32).to(device)
+
+    # Resample the audio to 16kHz for detecting
+    audio_16k = torchaudio.transforms.Resample(sample_rate, 16000).to(device)(clip_tensor)
+
+    # If the audio has more than one channel, average all channels to 1 channel
+    if audio_16k.shape[0] > 1:
+        audio_16k_mono = torch.mean(audio_16k, dim=0, keepdim=True)
+    else:
+        audio_16k_mono = audio_16k
+
+    # Add batch dimension to the audio tensor
+    audio_16k_mono_batched = audio_16k_mono.unsqueeze(0)
+
+    # Detect watermarks in the audio
+    with torch.no_grad():
+        result, message = model.detect_watermark(
+            audio_16k_mono_batched, 16000
+        )
+
+    # pred_prob is a tensor of size batch x 2 x frames, indicating the probability (positive and negative) of watermarking for each frame
+    # A watermarked audio should have pred_prob[:, 1, :] > 0.5
+    # message_prob is a tensor of size batch x 16, indicating the probability of each bit to be 1.
+    # message will be a random tensor if the detector detects no watermarking from the audio
+    pred_prob, message_prob = model(audio_16k_mono_batched, sample_rate)
+
+    # print(f"Detection result for audio: {result}")
+    # _, format_msg = generate_format_string_by_msg_pt(message[0], audio_generator_nbytes)
+    # print(f"Extracted message from audio: {message}: {format_msg}")
+    # print(f"Extracted pred_prob from audio: {pred_prob.shape}")
+    # print(f"Extracted message_prob from audio: {message_prob}")
+    # print(f"Extracted shape from audio 16k: {audio_16k_mono_batched.shape}")
+    # print(f"Extracted shape from audio original: {clip_tensor.shape}")
+    return result, message, pred_prob, message_prob
+
+def detect_audio(
+    model,
+    input_path: str,
+    chunk_size: int
+) -> tuple:
+    # Read audio dimensions
+    probe = ffmpeg.probe(input_path)
+    audio_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'audio']
+    if len(audio_streams) == 0:
+        gr.Warning("No audio stream found in the input file.")
+        return None, None, None, None
+    audio_info = audio_streams[0]
+    sample_rate = int(audio_info['sample_rate'])
+    sample_fmt = audio_info['sample_fmt']
+    channels = int(audio_info['channels'])
+    duration = float(audio_info['duration'])
+
+    # Open the input audio
+    process1 = (
+        ffmpeg
+        .input(input_path)
+        .output('pipe:', format='f32le', acodec='pcm_f32le', ac=channels, ar=sample_rate)
+        .run_async(pipe_stdout=True, pipe_stderr=subprocess.PIPE)
+    )
+
+    # Process the audio
+    sample_size, sample_type = get_sample_size(sample_fmt)
+    second_size = sample_size * channels * sample_rate
+    chunk = np.zeros((chunk_size, sample_rate, channels), dtype=sample_type)
+    second_count = 0
+    soft_result = []
+    soft_message = []
+    soft_pred_prob = []
+    soft_message_prob = []
+    pbar = tqdm.tqdm(total=math.ceil(duration), unit='second', desc="Watermark audio detecting")
+    while True:
+        in_bytes = process1.stdout.read(second_size)
+        if not in_bytes:
+            break
+        frame = np.frombuffer(in_bytes, sample_type)
+        frame = frame.reshape((-1, channels))
+        chunk[second_count % chunk_size, :len(frame)] = frame
+        second_count += 1
+        pbar.update(1)
+        if second_count % chunk_size == 0:
+            clip = np.concatenate(chunk, axis=0).T
+            # print(f"Detection audio second: {second_count-chunk_size}-{second_count}")
+            result, message, pred_prob, message_prob = detect_audio_clip(model, clip, sample_rate)
+            soft_result.append(result)
+            soft_message.append(message)
+            soft_pred_prob.append(pred_prob)
+            soft_message_prob.append(message_prob)
+
+    process1.stdout.close()
+    process1.wait()
+
+    soft_message = torch.cat(soft_message, dim=0)
+    soft_pred_prob = torch.cat(soft_pred_prob, dim=0)
+    soft_message_prob = torch.cat(soft_message_prob, dim=0)
+    return (soft_result, soft_message, soft_pred_prob, soft_message_prob)
+
+def detect_watermark(input_path, video_only):
+    msgs_v_frame = detect_video(video_model, input_path, 16)
+    msgs_v_avg = msgs_v_frame.mean(dim=0)  # Average the predictions across all frames
+    msgs_v_frame = (msgs_v_frame > 0).to(int)
+    msgs_v_avg = (msgs_v_avg > 0).to(int)
+    msgs_v_unique, msgs_v_counts = torch.unique(msgs_v_frame, dim=0, return_counts=True)
+    msgs_v_most = None
+    if len(msgs_v_frame) > len(msgs_v_counts) > 0:
+        msgs_v_most_idx = torch.argmax(msgs_v_counts)
+        msgs_v_most = msgs_v_unique[msgs_v_most_idx]
+
+    msgs_a_most = msgs_a_res = msgs_a_frame = msgs_a_pred = msgs_a_prob = None
+    if not video_only:
+        msgs_a_res, msgs_a_frame, msgs_a_pred, msgs_a_prob = detect_audio(audio_detector, input_path, 1)
+        if msgs_a_res is not None:
+            msgs_a_res_not_zero = [i for i, x in enumerate(msgs_a_res) if x > 0.5]
+            msgs_a_frame_not_zero = msgs_a_frame[msgs_a_res_not_zero]
+            msgs_a_unique, msgs_a_counts = torch.unique(msgs_a_frame_not_zero, dim=0, return_counts=True)
+            if len(msgs_a_counts) > 0:
+                msgs_a_most_idx = torch.argmax(msgs_a_counts)
+                msgs_a_most = msgs_a_unique[msgs_a_most_idx]
+
+    return msgs_v_most, msgs_v_avg, msgs_v_frame, msgs_a_most, msgs_a_res, msgs_a_frame, msgs_a_pred, msgs_a_prob
 
 with gr.Blocks(title="VideoSeal") as demo:
     gr.Markdown("""
     # VideoSeal Demo
 
-    The current video will be YUV444P encoded, truncated to 3 seconds for use, and multi-channel audio will be merged into a single channel for processing.
+    For video, each frame is watermarked and detected.
+    For audio, each 3-second chunk is watermarked, and each second is detected.
+
+    **NOTE: The watermarking process modifies both the audio and the video.
+    The video is re-encoded to yuv420p using libx264,
+    and the audio is watermarked as mono at 16kHz, then duplicated back to the original channels and resampled to the original sample rate.**
 
     Find the project [here](https://github.com/facebookresearch/videoseal.git).
     """)
@@ -230,23 +565,21 @@ with gr.Blocks(title="VideoSeal") as demo:
         with gr.Row():
             with gr.Column():
                 embedding_vid = gr.Video(label="Input Video")
-
+
         with gr.Row():
             with gr.Column():
                 embedding_type = gr.Radio(["random", "input"], value="random", label="Type", info="Type of watermarks")
 
-                format_like, regex_pattern = generate_hex_format_regex(video_model_nbytes)
-                msg, _ = generate_hex_random_message(video_model_nbytes)
-                embedding_msg = gr.Textbox(
+                format_like_v, regex_pattern_v = generate_hex_format_regex(video_model_nbytes)
+                msg_v, _ = generate_hex_random_message(video_model_nbytes)
+                embedding_msg_v = gr.Textbox(
                     label=f"Message ({video_model_nbytes} bytes hex string)",
-                    info=f"format like {format_like}",
-                    value=msg,
+                    info=f"format like {format_like_v}",
+                    value=msg_v,
                     interactive=False, show_copy_button=True)
             with gr.Column():
                 embedding_only_vid = gr.Checkbox(label="Only Video", value=False)
 
-                embedding_specgram = gr.Checkbox(label="Show specgram", value=False, info="Show debug information")
-
                 format_like_a, regex_pattern_a = generate_hex_format_regex(audio_generator_nbytes)
                 msg_a, _ = generate_hex_random_message(audio_generator_nbytes)
                 embedding_msg_a = gr.Textbox(
@@ -258,76 +591,70 @@ with gr.Blocks(title="VideoSeal") as demo:
                 embedding_btn = gr.Button("Embed Watermark")
             with gr.Column():
                 marked_vid = gr.Video(label="Output Video", show_download_button=True)
-                specgram_original = gr.Plot(label="Original Audio", format="png", visible=False)
-                specgram_watermarked = gr.Plot(label="Watermarked Audio", format="png", visible=False)
 
         def change_embedding_type(video_only):
-            return [gr.update(visible=not video_only, value=False), gr.update(visible=not video_only)]
+            return gr.update(visible=not video_only)
         embedding_only_vid.change(
             fn=change_embedding_type,
             inputs=[embedding_only_vid],
-            outputs=[embedding_specgram, embedding_msg_a]
+            outputs=[embedding_msg_a]
         )
 
         def change_embedding_type(type):
             if type == "random":
-                msg, _ = generate_hex_random_message(video_model_nbytes)
+                msg_v, _ = generate_hex_random_message(video_model_nbytes)
                 msg_a, _ = generate_hex_random_message(audio_generator_nbytes)
-                return [gr.update(interactive=False, value=msg), gr.update(interactive=False, value=msg_a)]
+                return [gr.update(interactive=False, value=msg_v), gr.update(interactive=False, value=msg_a)]
             else:
                 return [gr.update(interactive=True), gr.update(interactive=True)]
         embedding_type.change(
             fn=change_embedding_type,
             inputs=[embedding_type],
-            outputs=[embedding_msg, embedding_msg_a]
+            outputs=[embedding_msg_v, embedding_msg_a]
         )
 
-        def check_embedding_msg(msg, msg_a):
-            if not re.match(regex_pattern, msg):
+        def check_embedding_msg(msg_v, msg_a):
+            if not re.match(regex_pattern_v, msg_v):
                 gr.Warning(
-                    f"Invalid format. Please use like '{format_like}'",
+                    f"Invalid format. Please use like '{format_like_v}'",
                     duration=0)
             if not re.match(regex_pattern_a, msg_a):
                 gr.Warning(
                     f"Invalid format. Please use like '{format_like_a}'",
                     duration=0)
-        embedding_msg.change(
+        embedding_msg_v.change(
             fn=check_embedding_msg,
-            inputs=[embedding_msg, embedding_msg_a],
+            inputs=[embedding_msg_v, embedding_msg_a],
+            outputs=[]
+        )
+        embedding_msg_a.change(
+            fn=check_embedding_msg,
+            inputs=[embedding_msg_v, embedding_msg_a],
             outputs=[]
         )
 
-        def run_embed_watermark(file, video_only, show_specgram, msg, msg_a):
-            if file is None:
+        def run_embed_watermark(input_path, video_only, msg_v, msg_a, progress=gr.Progress(track_tqdm=True)):
+            if input_path is None:
                 raise gr.Error("No file uploaded", duration=5)
-            if not re.match(regex_pattern, msg):
-                raise gr.Error(f"Invalid format. Please use like '{format_like}'", duration=5)
+            if not re.match(regex_pattern_v, msg_v):
+                raise gr.Error(f"Invalid format. Please use like '{format_like_v}'", duration=5)
             if not re.match(regex_pattern_a, msg_a):
                 raise gr.Error(f"Invalid format. Please use like '{format_like_a}'", duration=5)
 
-            msg_pt = generate_msg_pt_by_format_string(msg, video_model_nbytes)
+            msg_pt_v = generate_msg_pt_by_format_string(msg_v, video_model_nbytes)
             msg_pt_a = generate_msg_pt_by_format_string(msg_a, audio_generator_nbytes)
-            video, fps, audio, rate = load_video(file)
-
-            output_path = file + '.marked.mp4'
-            _, audio_w = embed_watermark(output_path, msg_pt, msg_pt_a, video_only, video, fps, audio, rate)
-
-            if show_specgram:
-                fig_original = get_waveform_and_specgram(audio, rate)
-                fig_watermarked = get_waveform_and_specgram(audio_w, rate)
-                return [
-                    output_path,
-                    gr.update(visible=True, value=fig_original),
-                    gr.update(visible=True, value=fig_watermarked)]
+
+            if video_only:
+                output_path = os.path.join(os.path.dirname(input_path), "__".join([msg_v]) + '.mp4')
             else:
-                return [
-                    output_path,
-                    gr.update(visible=False),
-                    gr.update(visible=False)]
+                output_path = os.path.join(os.path.dirname(input_path), "__".join([msg_v, msg_a]) + '.mp4')
+            embed_watermark(input_path, output_path, msg_pt_v, msg_pt_a, video_only, progress)
+
+            return output_path
         embedding_btn.click(
             fn=run_embed_watermark,
-            inputs=[embedding_vid, embedding_only_vid, embedding_specgram, embedding_msg, embedding_msg_a],
-            outputs=[marked_vid, specgram_original, specgram_watermarked]
+            inputs=[embedding_vid, embedding_only_vid, embedding_msg_v, embedding_msg_a],
+            outputs=[marked_vid]
        )
 
         with gr.TabItem("Detect Watermark"):
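Reviewer note: run_embed_watermark above (and run_detect_watermark below) declares progress=gr.Progress(track_tqdm=True), the Gradio mechanism that mirrors any tqdm bar opened while the handler runs, here the per-frame and per-second bars created inside embed_video, embed_audio, detect_video, and detect_audio, into the web UI. A minimal standalone sketch of the pattern, with a hypothetical slow_job handler:

    import time
    import gradio as gr
    import tqdm

    def slow_job(n, progress=gr.Progress(track_tqdm=True)):
        # Any tqdm bar opened while this handler runs is mirrored in the UI
        for _ in tqdm.tqdm(range(int(n)), desc="processing"):
            time.sleep(0.01)
        return f"done after {int(n)} steps"

    with gr.Blocks() as demo:
        steps = gr.Number(value=100, label="Steps")
        out = gr.Textbox(label="Result")
        gr.Button("Run").click(fn=slow_job, inputs=[steps], outputs=[out])

    # demo.launch()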
@@ -339,39 +666,49 @@ with gr.Blocks(title="VideoSeal") as demo:
             with gr.Column():
                 predicted_messages = gr.JSON(label="Detected Messages")
 
-        def run_detect_watermark(file, video_only):
+        def run_detect_watermark(file, video_only, progress=gr.Progress(track_tqdm=True)):
             if file is None:
                 raise gr.Error("No file uploaded", duration=5)
 
-            video, _, audio, rate = load_video(file)
-
-            if video_only:
-                msg_extracted, _ = detect_watermark(video_only, video, audio, rate)
+            msgs_v_most, msgs_v_avg, msgs_v_frame, msgs_a_most, msgs_a_res, msgs_a_frame, msgs_a_pred, msgs_a_prob = detect_watermark(file, video_only)
+
+            _, format_msg_v_most = generate_format_string_by_msg_pt(msgs_v_most, video_model_nbytes)
+            _, format_msg_v_avg = generate_format_string_by_msg_pt(msgs_v_avg, video_model_nbytes)
+            format_msg_v_frames = {}
+            for idx, msg in enumerate(msgs_v_frame):
+                _, format_msg = generate_format_string_by_msg_pt(msg, video_model_nbytes)
+                format_msg_v_frames[f"{idx}"] = format_msg
+            video_json = {
+                "most": format_msg_v_most,
+                "avg": format_msg_v_avg,
+                "frames": format_msg_v_frames
+            }
 
+            if msgs_a_res is None:
                 audio_json = None
             else:
-                msg_extracted, (result, message, pred_prob, message_prob) = detect_watermark(video_only, video, audio, rate)
-
-                _, fromat_msg = generate_format_string_by_msg_pt(message[0], audio_generator_nbytes)
-
-                sum_above_05 = (pred_prob[:, 1, :] > 0.5).sum(dim=1)
-
+                _, format_msg_a_most = generate_format_string_by_msg_pt(msgs_a_most, audio_generator_nbytes)
+                format_msg_a_seconds = {}
+                for idx, (result, message, pred_prob, message_prob) in enumerate(zip(msgs_a_res, msgs_a_frame, msgs_a_pred, msgs_a_prob)):
+                    _, format_msg = generate_format_string_by_msg_pt(message, audio_generator_nbytes)
+
+                    sum_above_05 = (pred_prob[1, :] > 0.5).sum(dim=0)
+                    format_msg_a_seconds[f"{idx}"] = {
+                        "score": result,
+                        "message": format_msg,
+                        "frames_count_all": pred_prob.shape[1],
+                        "frames_count_above_05": sum_above_05.item(),
+                        "bits_probability": message_prob.tolist(),
+                        "bits_message": message.tolist()
+                    }
                 audio_json = {
-                    "socre": result,
-                    "message": fromat_msg,
-                    "frames_count_all": pred_prob.shape[2],
-                    "frames_count_above_05": sum_above_05[0].item(),
-                    "bits_probability": message_prob[0].tolist(),
-                    "bits_massage": message[0].tolist()
+                    "most": format_msg_a_most,
+                    "seconds": format_msg_a_seconds
                 }
 
-            _, fromat_msg = generate_format_string_by_msg_pt(msg_extracted[0], video_model_nbytes)
-
             # Create message output as JSON
             message_json = {
-                "video": {
-                    "message": fromat_msg,
-                },
+                "video": video_json,
                 "audio": audio_json
             }
             return message_json
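Reviewer note on the detection aggregation: detect_watermark reports the video message three ways: per frame, as a threshold of the per-bit average, and as a majority vote over distinct decoded rows via torch.unique(..., return_counts=True) (the "most" entry is only filled when at least one decoded row repeats). A standalone sketch of that vote on synthetic, already-binarized bits (the app thresholds raw logits at 0 instead):

    import torch

    # Synthetic per-frame decoded bits: 5 frames x 8 bits, one corrupted frame
    frames = torch.tensor([
        [1, 0, 1, 1, 0, 0, 1, 0],
        [1, 0, 1, 1, 0, 0, 1, 0],
        [1, 0, 1, 1, 0, 0, 1, 0],
        [0, 1, 1, 0, 0, 1, 1, 0],  # corrupted frame
        [1, 0, 1, 1, 0, 0, 1, 0],
    ])

    # "avg" message: threshold the per-bit mean across frames
    avg_msg = (frames.float().mean(dim=0) > 0.5).to(int)

    # "most" message: most frequent distinct row, as in detect_watermark
    unique_rows, counts = torch.unique(frames, dim=0, return_counts=True)
    most_msg = unique_rows[torch.argmax(counts)]

    print(avg_msg.tolist())   # [1, 0, 1, 1, 0, 0, 1, 0]
    print(most_msg.tolist())  # [1, 0, 1, 1, 0, 0, 1, 0]

Both estimators recover the repeated message here; they can disagree when corruption is heavy, which is why the JSON output exposes "most", "avg", and the raw per-frame decodes side by side.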