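# WavMark / app.py
# Repository page header captured together with this file (not part of the program):
#   uploader avatar: xxxfortest's picture
#   commit: edb49b9, message "d"
#   page links: raw | history | blame
#   file size: 4.49 kB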
import pdb
import time
import wavmark
import streamlit as st
import os
import torch
import uuid
import datetime
import numpy as np
import soundfile
from huggingface_hub import hf_hub_download, HfApi
from wavmark.utils import file_reader
# Function to add watermark to audio
def add_watermark(audio_path, watermark_text):
    """Embed a 16-bit watermark into an audio file and save the result as WAV.

    Args:
        audio_path: Path of the audio file to read.
        watermark_text: String of exactly 16 characters, each "0" or "1".

    Returns:
        Path of the newly written watermarked file under ``/tmp/``. The file
        name combines a timestamp and a UUID, so every call writes a new file.

    Raises:
        ValueError: If ``watermark_text`` is not exactly 16 binary digits.
    """
    # Validate explicitly instead of with ``assert`` (asserts vanish under
    # ``python -O``), and reject digits other than 0/1, which ``int()`` alone
    # would happily pass on to the encoder.
    if len(watermark_text) != 16 or set(watermark_text) - {"0", "1"}:
        raise ValueError(
            "watermark_text must be exactly 16 characters, each '0' or '1'; got %r"
            % (watermark_text,)
        )
    watermark_npy = np.array([int(bit) for bit in watermark_text])

    # TODO: bound the processing time
    signal, sr, _ = file_reader.read_as_single_channel_16k(audio_path, 16000)
    watermarked_signal, _ = wavmark.encode_watermark(
        model, signal, watermark_npy, show_progress=False
    )

    tmp_file_name = (
        datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        + "_" + str(uuid.uuid4()) + ".wav"
    )
    tmp_file_path = '/tmp/' + tmp_file_name
    soundfile.write(tmp_file_path, watermarked_signal, sr)
    return tmp_file_path
# Function to decode watermark from audio
def decode_watermark(audio_path):
    """Try to read a watermark payload back out of an audio file.

    Args:
        audio_path: Path of the audio file to inspect.

    Returns:
        The payload returned by ``wavmark.decode_watermark``, or the string
        ``"No Watermark"`` when that payload is ``None``.
    """
    signal, _sr, _duration = file_reader.read_as_single_channel_16k(audio_path, 16000)
    payload, _ = wavmark.decode_watermark(model, signal, show_progress=False)
    return "No Watermark" if payload is None else payload
# Main web app
def _store_upload(uploaded_file):
    """Save an uploaded file into this session's temp directory and return its path."""
    if "upload_dir" not in st.session_state:
        # One private directory per session so that two users uploading files
        # with the same name cannot overwrite each other's data.
        st.session_state.upload_dir = os.path.join("/tmp/", "wavmark_" + uuid.uuid4().hex)
    os.makedirs(st.session_state.upload_dir, exist_ok=True)

    # The file name is supplied by the client: keep only its last path
    # component so it cannot point outside the upload directory.
    safe_name = os.path.basename(uploaded_file.name)
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    target_path = os.path.join(st.session_state.upload_dir, safe_name)
    with open(target_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return target_path


def _checked_upload(label, file_types, max_bytes):
    """Show a single-file uploader; return the file, or None if absent or too large."""
    # NOTE(review): the size limit is enforced here rather than through a
    # ``max_upload_size`` keyword, which is not part of the documented
    # ``st.file_uploader`` signature in the Streamlit releases this app was
    # written for - confirm against the pinned Streamlit version.
    uploaded_file = st.file_uploader(label, type=file_types, accept_multiple_files=False)
    if not uploaded_file:
        return None
    if uploaded_file.getbuffer().nbytes > max_bytes:
        st.error("File is too large: the limit is %d MB." % (max_bytes // (1024 * 1024)))
        return None
    return uploaded_file


def main():
    """Render the watermark demo page.

    Lets the user pick between two actions: embedding a watermark into an
    uploaded WAV file, or decoding a watermark from an uploaded WAV/MP3 file.
    Relies on the module-level names ``len_start_bit`` and (through
    ``add_watermark`` / ``decode_watermark``) ``model``.
    """
    max_upload_size = 20 * 1024 * 1024  # 20 MB in bytes
    payload_len = 32 - len_start_bit

    # Generate the default watermark once and keep it in the session state.
    if "def_value" not in st.session_state:
        def_val_npy = np.random.choice([0, 1], size=payload_len)
        st.session_state.def_value = "".join(str(i) for i in def_val_npy)

    st.title("Neural Audio Watermark")
    st.write("Choose the action you want to perform:")

    action = st.selectbox("Select Action", ["Add Watermark", "Decode Watermark"])

    if action == "Add Watermark":
        audio_file = _checked_upload("Upload Audio File (WAV)", ["wav"], max_upload_size)
        tmp_input_audio_file = None
        if audio_file:
            tmp_input_audio_file = _store_upload(audio_file)
            st.audio(tmp_input_audio_file, format="audio/wav")

        watermark_text = st.text_input("Enter Watermark", value=st.session_state.def_value)

        # Runs only after the button has been clicked.
        if st.button("Add Watermark", key="add_watermark_btn"):
            if tmp_input_audio_file is None:
                st.warning("Please upload an audio file first.")
            elif len(watermark_text) != payload_len or set(watermark_text) - {"0", "1"}:
                # Report bad input in the page instead of letting the encoder
                # helper fail with a traceback.
                st.error("Watermark must be exactly %d characters, each 0 or 1." % payload_len)
            else:
                with st.spinner("Adding Watermark..."):
                    t1 = time.time()
                    watermarked_audio = add_watermark(tmp_input_audio_file, watermark_text)
                    encode_time_cost = time.time() - t1

                    st.write("Watermarked Audio:")
                    st.audio(watermarked_audio, format="audio/wav")
                    st.write("Time Cost:%d seconds" % encode_time_cost)

    elif action == "Decode Watermark":
        audio_file = _checked_upload("Upload Audio File (WAV/MP3)", ["wav", "mp3"], max_upload_size)
        if audio_file and st.button("Decode Watermark"):
            # 1. Save the upload to disk.
            tmp_file_for_decode_path = _store_upload(audio_file)

            # 2. Run the decoder.
            with st.spinner("Decoding..."):
                t1 = time.time()
                decoded_watermark = decode_watermark(tmp_file_for_decode_path)
                decode_cost = time.time() - t1

                print("decoded_watermark", decoded_watermark)
                # Display the decoded watermark
                st.write("Decoded Watermark:", decoded_watermark)
                st.write("Time Cost:%d seconds" % (decode_cost))
if __name__ == "__main__":
    # Script entry point: define the module-level names the functions in this
    # file read (len_start_bit, model), then start the app.

    # main() subtracts this from 32 to size the default watermark string.
    len_start_bit = 16

    # Use the first CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    model = wavmark.load_model().to(device)

    main()