Kabatubare's picture
Update app.py
9b38565 verified
raw
history blame
6.63 kB
import gradio as gr
import torch
import torchaudio
import tempfile
import logging
from audioseal import AudioSeal
import random
import string
from pathlib import Path
from datetime import datetime
import json
import os
# Logging: every record from DEBUG upwards goes to app.log, which is
# overwritten (filemode 'w') each time the application starts.
logging.basicConfig(
    filename='app.log',
    filemode='w',
    level=logging.DEBUG,
    format='%(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# JSON file that persists the sequential counter and per-audio metadata.
# It is seeded with an empty object on first run so that later
# read-modify-write helpers can always parse it.
metadata_file = 'audio_metadata.json'
if not os.path.exists(metadata_file):
    with open(metadata_file, 'w') as f:
        f.write(json.dumps({}))
def generate_unique_message(length=16):
    """Return a random string of ASCII letters and digits.

    Args:
        length: Number of characters to generate (default 16).

    Returns:
        A string of ``length`` characters drawn from ``string.ascii_letters``
        and ``string.digits`` using the ``random`` module.
    """
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)
def message_to_binary(message, bit_length=16):
    """Encode ``message`` as a bit string of exactly ``bit_length`` characters.

    Each character is rendered as its code point in binary, zero-padded to at
    least 8 digits. The concatenation is then cut to ``bit_length`` characters
    and right-padded with ``'0'`` when it is shorter.

    Args:
        message: Text to encode.
        bit_length: Desired length of the resulting bit string (default 16).

    Returns:
        A string made of ``'0'`` and ``'1'`` characters.
    """
    per_char_bits = [f"{ord(ch):08b}" for ch in message]
    all_bits = ''.join(per_char_bits)
    truncated = all_bits[:bit_length]
    return truncated.ljust(bit_length, '0')
def binary_to_hex(binary_str):
    """Convert a base-2 digit string to lowercase hexadecimal.

    Args:
        binary_str: String of binary digits, parsed with ``int(..., 2)``.

    Returns:
        The hexadecimal digits without the ``0x`` prefix, left-padded with
        zeros to at least 4 characters.

    Raises:
        ValueError: If ``binary_str`` is not a valid base-2 literal
            (for example, an empty string).
    """
    value = int(binary_str, 2)
    digits = hex(value)[2:]  # drop the leading "0x"
    return digits.zfill(4)
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and bring it to ``target_sample_rate``.

    Args:
        audio_file_path: Path handed to ``torchaudio.load``.
        target_sample_rate: Sample rate in Hz the caller needs (default 16000).

    Returns:
        A ``(waveform, target_sample_rate)`` tuple. The waveform is returned
        as loaded when the file already has the target rate; otherwise it is
        passed through ``torchaudio.transforms.Resample`` first.
    """
    waveform, source_rate = torchaudio.load(audio_file_path)

    # Nothing to do when the file is already at the requested rate.
    if source_rate == target_sample_rate:
        return waveform, target_sample_rate

    to_target_rate = torchaudio.transforms.Resample(
        orig_freq=source_rate,
        new_freq=target_sample_rate,
    )
    return to_target_rate(waveform), target_sample_rate
def generate_enhanced_identifier():
    """Build an identifier of the form ``<timestamp>-<sequence>``.

    The timestamp is the current local time formatted as
    ``%Y%m%d%H%M%S%f``; the sequence is the next value from
    ``get_next_sequential_number()``, zero-padded to 6 characters.

    Returns:
        The combined identifier string.
    """
    # Take the timestamp before touching the counter file.
    stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
    counter = get_next_sequential_number()
    padded_counter = str(counter).zfill(6)
    return "-".join((stamp, padded_counter))
def get_next_sequential_number():
    """Return the current sequential number and persist the incremented one.

    Reads the JSON document in ``metadata_file``, takes the value stored
    under ``'next_sequential_number'`` (1 when the key is absent), writes the
    value plus one back under the same key, and returns the value that was
    read.

    Returns:
        The sequential number to use for the current call.
    """
    with open(metadata_file, 'r+') as handle:
        metadata = json.load(handle)
        current = metadata.get('next_sequential_number', 1)
        metadata['next_sequential_number'] = current + 1

        # Rewrite the file in place: rewind, write, then drop any leftover
        # bytes from the previous (possibly longer) content.
        handle.seek(0)
        handle.write(json.dumps(metadata, indent=4))
        handle.truncate()
    return current
def save_audio_metadata(unique_id, original_hex, enhanced_id):
    """Record the metadata of one watermarked audio file in ``metadata_file``.

    The entry is stored under ``data['audio_files'][unique_id]``; an existing
    entry with the same ``unique_id`` is overwritten.

    Args:
        unique_id: Key under which the entry is stored.
        original_hex: Hex message saved as ``'original_hex'``.
        enhanced_id: Identifier saved as ``'enhanced_id'``.
    """
    record = {'original_hex': original_hex, 'enhanced_id': enhanced_id}
    with open(metadata_file, 'r+') as handle:
        metadata = json.load(handle)
        # Create the 'audio_files' section on first use, then add the entry.
        metadata.setdefault('audio_files', {})[unique_id] = record

        # In-place rewrite: rewind, dump, and trim stale trailing bytes.
        handle.seek(0)
        json.dump(metadata, handle, indent=4)
        handle.truncate()
def watermark_audio(audio_file_path, unique_message):
    """Embed a 16-bit watermark derived from ``unique_message`` into an audio file.

    The audio is loaded at 16 kHz, clamped to [-1.0, 1.0], and passed through
    the AudioSeal ``audioseal_wm_16bits`` generator together with the first
    16 bits of ``unique_message``. The result is written to a temporary WAV
    file, and the hex message plus a freshly generated enhanced ID are stored
    in the metadata file under ``unique_message``.

    Args:
        audio_file_path: Path of the audio file to watermark.
        unique_message: Text whose leading bits become the watermark payload;
            also used as the metadata key.

    Returns:
        A ``(wav_path, hex_message, enhanced_id)`` tuple, where ``wav_path`` is
        the path of the temporary watermarked WAV file. The file is not
        deleted by this function.
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, target_sample_rate=16000)
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)

    generator = AudioSeal.load_generator("audioseal_wm_16bits")
    binary_message = message_to_binary(unique_message, bit_length=16)
    hex_message = binary_to_hex(binary_message)
    message_tensor = torch.tensor(
        [int(bit) for bit in binary_message], dtype=torch.int32
    ).unsqueeze(0)

    # Inference only: skip autograd bookkeeping so the output is a plain
    # tensor that can be written to disk.
    # NOTE(review): waveform is (channels, time) as loaded; the added batch
    # dimension assumes the generator accepts that channel count -- confirm
    # for non-mono input.
    with torch.no_grad():
        watermarked_audio = generator(
            waveform.unsqueeze(0), sample_rate=sample_rate, message=message_tensor
        ).squeeze(0)

    # Reserve a unique path, then close the handle *before* torchaudio writes
    # to it. Leaving it open leaked a file descriptor on every call and keeps
    # the file locked on platforms that forbid a second open.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        output_path = temp_file.name
    torchaudio.save(output_path, watermarked_audio, sample_rate)

    # Enhanced ID generation and metadata saving
    enhanced_id = generate_enhanced_identifier()
    save_audio_metadata(unique_message, hex_message, enhanced_id)
    return output_path, hex_message, enhanced_id
def detect_watermark(audio_file_path, original_hex_message=None):
    """Run the AudioSeal detector on an audio file and decode its message.

    Args:
        audio_file_path: Path of the audio file to inspect.
        original_hex_message: Optional hex string to compare against the
            decoded message. Comparison ignores letter case and surrounding
            whitespace.

    Returns:
        A ``(result, detected_hex_message, match_result)`` tuple:
        ``result`` is the first value returned by the detector's
        ``detect_watermark`` (presumably the detection score -- confirm
        against the AudioSeal API), ``detected_hex_message`` is the decoded
        message as hex, and ``match_result`` is ``"Match"``, ``"No Match"``,
        or ``"Not compared"`` when no original message was supplied.
    """
    waveform, sample_rate = load_and_resample_audio(audio_file_path, target_sample_rate=16000)
    detector = AudioSeal.load_detector("audioseal_detector_16bits")
    result, message_tensor = detector.detect_watermark(waveform.unsqueeze(0), sample_rate=sample_rate)
    binary_message = ''.join(str(bit) for bit in message_tensor[0].tolist())
    detected_hex_message = binary_to_hex(binary_message)

    # Compare the detected message with the original, if provided.
    match_result = "Not compared"
    if original_hex_message:
        # The detected hex is produced in lowercase, so comparing it against
        # an upper-cased input could never match once a-f digits appear.
        # Normalize both sides to the same case instead.
        expected_hex = original_hex_message.strip().lower()
        if detected_hex_message.lower() == expected_hex:
            match_result = "Match"
        else:
            match_result = "No Match"
    return result, detected_hex_message, match_result
# Optional custom stylesheet: use style.css from the working directory when
# it is present, otherwise fall back to no extra CSS.
style_path = Path("style.css")
style = style_path.read_text() if style_path.exists() else ""
# Gradio UI: one tab for embedding a watermark and one for detecting it.
with gr.Blocks(css=style) as demo:
    # --- Tab 1: generate a message and watermark an uploaded file ---
    with gr.Tab("Watermark Audio"):
        with gr.Column(scale=6):
            audio_input_watermark = gr.Audio(
                label="Upload Audio File for Watermarking",
                type="filepath",
            )
            unique_message_output = gr.Textbox(label="Unique Message")
            watermarked_audio_output = gr.Audio(label="Watermarked Audio")
            message_output = gr.Textbox(label="Message Used for Watermarking")
            enhanced_id_output = gr.Textbox(label="Enhanced ID")
            generate_message_button = gr.Button("Generate Unique Message")
            watermark_button = gr.Button("Apply Watermark")

            # Fill the "Unique Message" box with a fresh random message.
            generate_message_button.click(
                fn=generate_unique_message,
                inputs=None,
                outputs=unique_message_output,
            )
            # Watermark the upload using whatever is in the message box.
            watermark_button.click(
                fn=watermark_audio,
                inputs=[audio_input_watermark, unique_message_output],
                outputs=[watermarked_audio_output, message_output, enhanced_id_output],
            )

    # --- Tab 2: detect a watermark and optionally compare its message ---
    with gr.Tab("Detect Watermark"):
        with gr.Column(scale=6):
            audio_input_detect_watermark = gr.Audio(
                label="Upload Audio File for Watermark Detection",
                type="filepath",
            )
            original_hex_input = gr.Textbox(
                label="Original Hex Message for Comparison",
                placeholder="Enter the original hex message here",
            )
            detect_watermark_button = gr.Button("Detect Watermark")
            watermark_detection_output = gr.Textbox(label="Watermark Detection Result")
            detected_message_output = gr.Textbox(label="Detected Hex Message")
            match_result_output = gr.Textbox(label="Match Result")

            detect_watermark_button.click(
                fn=detect_watermark,
                inputs=[audio_input_detect_watermark, original_hex_input],
                outputs=[
                    watermark_detection_output,
                    detected_message_output,
                    match_result_output,
                ],
            )

demo.launch()