# app.py -- commit 9d2ce4b ("Update app.py") by Kabatubare (verified)
import functools
import json
import logging
import os
import random
import secrets
import string
import tempfile
from datetime import datetime
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from audioseal import AudioSeal
# Logging: DEBUG and above go to app.log, which is truncated (filemode 'w') on every start.
logging.basicConfig(
    level=logging.DEBUG,
    filename='app.log',
    filemode='w',
    format='%(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# JSON metadata store holding the sequential counter and the per-message watermark records.
# On first run it is seeded with an empty object so the read-modify-write helpers can json.load() it.
metadata_file = 'audio_metadata.json'
if not os.path.exists(metadata_file):
    with open(metadata_file, 'w') as f:
        json.dump({}, f)
def generate_unique_message(length=16):
    """Return a random string of *length* ASCII letters and digits.

    The characters are drawn with :mod:`secrets` (a CSPRNG) instead of
    :mod:`random`, whose Mersenne Twister output is predictable. The string is
    shown to the user and later used as the key the watermark metadata is
    stored under, so it should not be guessable.

    Args:
        length: Number of characters to generate.

    Returns:
        An alphanumeric string of exactly *length* characters.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def message_to_binary(message, bit_length=16):
    """Encode *message* as exactly *bit_length* binary digits.

    The text is UTF-8 encoded and every byte contributes 8 bits, most
    significant bit first. The result is truncated, or right-padded with
    ``'0'``, to *bit_length* characters. With the default of 16 bits, only the
    first two bytes of the message survive.

    Args:
        message: Text to encode.
        bit_length: Number of bits to return.

    Returns:
        A string of ``'0'``/``'1'`` characters of length *bit_length*.
    """
    # Encode bytes rather than code points so every chunk is exactly 8 bits wide.
    # format(ord(c), '08b') emits more than 8 bits for characters above U+00FF,
    # which shifts every later bit.
    bits = ''.join(format(byte, '08b') for byte in message.encode('utf-8'))
    return bits[:bit_length].ljust(bit_length, '0')
def binary_to_hex(binary_str):
    """Convert a string of binary digits to lowercase hexadecimal (no ``0x`` prefix).

    The result has one hex digit per 4 input bits (rounded up), and never fewer
    than 4 digits, so leading zero bits are kept as leading ``0`` digits. For
    the 16-bit messages used by this app, the output is always 4 digits.

    Args:
        binary_str: String of ``'0'``/``'1'`` characters.

    Returns:
        The zero-padded lowercase hexadecimal representation.

    Raises:
        ValueError: If *binary_str* is empty or not a valid base-2 literal.
    """
    # ceil(len / 4) digits; the old fixed zfill(4) dropped leading zero nibbles for inputs > 16 bits.
    width = max(4, -(-len(binary_str) // 4))
    return format(int(binary_str, 2), f'0{width}x')
def load_and_resample_audio(audio_file_path, target_sample_rate=16000):
    """Load an audio file and bring it to *target_sample_rate*.

    Args:
        audio_file_path: Path of the audio file to read with ``torchaudio.load``.
        target_sample_rate: Sample rate the returned waveform should have.

    Returns:
        ``(waveform, target_sample_rate)``. The waveform is the tensor returned by
        ``torchaudio.load``, passed through ``torchaudio.transforms.Resample``
        only when the file's native rate differs from the target.
    """
    audio, native_rate = torchaudio.load(audio_file_path)
    if native_rate == target_sample_rate:
        return audio, target_sample_rate
    to_target_rate = torchaudio.transforms.Resample(orig_freq=native_rate, new_freq=target_sample_rate)
    return to_target_rate(audio), target_sample_rate
def generate_enhanced_identifier():
    """Return an identifier of the form ``<timestamp>-<sequence>``.

    ``timestamp`` is the current local time formatted as ``%Y%m%d%H%M%S%f``.
    ``sequence`` is the next value from ``get_next_sequential_number()``,
    zero-filled to at least 6 characters. Each call advances that persisted
    counter.
    """
    stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
    sequence = get_next_sequential_number()
    padded_sequence = str(sequence).zfill(6)
    return f"{stamp}-{padded_sequence}"
def get_next_sequential_number():
    """Reserve and return the next sequential number kept in the metadata file.

    Reads ``next_sequential_number`` (1 when the key is absent), writes the
    incremented value back to ``metadata_file``, and returns the value that
    was read.
    """
    with open(metadata_file, 'r+') as handle:
        metadata = json.load(handle)
        current = metadata.get('next_sequential_number', 1)
        metadata['next_sequential_number'] = current + 1
        # Rewrite in place: rewind, dump, then cut off any bytes left over from a longer previous version.
        handle.seek(0)
        json.dump(metadata, handle, indent=4)
        handle.truncate()
    return current
def save_audio_metadata(unique_id, original_hex, enhanced_id):
    """Record watermark details for *unique_id* under ``audio_files`` in the metadata file.

    An existing entry with the same id is overwritten. Other top-level keys,
    such as ``next_sequential_number``, are written back unchanged.

    Args:
        unique_id: Key for the entry (the watermark message).
        original_hex: Hex form of the embedded bits.
        enhanced_id: Identifier produced by ``generate_enhanced_identifier``.
    """
    record = {'original_hex': original_hex, 'enhanced_id': enhanced_id}
    with open(metadata_file, 'r+') as handle:
        contents = json.load(handle)
        contents.setdefault('audio_files', {})[unique_id] = record
        # Rewrite in place: rewind, dump, then cut off any bytes left over from a longer previous version.
        handle.seek(0)
        json.dump(contents, handle, indent=4)
        handle.truncate()
@functools.lru_cache(maxsize=1)
def _load_watermark_generator():
    """Load the 16-bit AudioSeal generator once and reuse it for every request."""
    return AudioSeal.load_generator("audioseal_wm_16bits")


def watermark_audio(audio_file_path, unique_message):
    """Embed a 16-bit AudioSeal watermark derived from *unique_message* into an audio file.

    The audio is resampled to 16 kHz and clamped to [-1, 1] before embedding.
    Only the 16 bits produced by ``message_to_binary(unique_message, bit_length=16)``
    are embedded. The full message is used as the metadata key.

    Args:
        audio_file_path: Path of the audio file to watermark.
        unique_message: Text the embedded bits are derived from.

    Returns:
        ``(output_path, hex_message, enhanced_id)``:
        - the path of a new ``.wav`` temp file with the watermarked audio (not
          deleted automatically);
        - the hex form of the embedded bits;
        - the enhanced identifier saved in the metadata file.

    Raises:
        ValueError: If no audio file path or an empty message is supplied.
    """
    if not audio_file_path:
        raise ValueError("No audio file was provided for watermarking.")
    if not unique_message:
        raise ValueError("A unique message is required; generate one before applying the watermark.")
    waveform, sample_rate = load_and_resample_audio(audio_file_path, target_sample_rate=16000)
    waveform = torch.clamp(waveform, min=-1.0, max=1.0)
    generator = _load_watermark_generator()
    binary_message = message_to_binary(unique_message, bit_length=16)
    hex_message = binary_to_hex(binary_message)
    # Shape (1, 16): one row of message bits for the single-item batch passed below.
    message_tensor = torch.tensor([int(bit) for bit in binary_message], dtype=torch.int32).unsqueeze(0)
    # NOTE(review): the waveform keeps whatever channel count torchaudio.load returned;
    # confirm the 16-bit model accepts multi-channel input.
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        watermarked_audio = generator(
            waveform.unsqueeze(0), sample_rate=sample_rate, message=message_tensor
        ).squeeze(0)
    # Reserve a unique .wav path, then close our handle before torchaudio writes to it, so no
    # descriptor leaks. delete=False keeps the file on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        output_path = temp_file.name
    torchaudio.save(output_path, watermarked_audio, sample_rate)
    enhanced_id = generate_enhanced_identifier()
    save_audio_metadata(unique_message, hex_message, enhanced_id)
    return output_path, hex_message, enhanced_id
@functools.lru_cache(maxsize=1)
def _load_watermark_detector():
    """Load the 16-bit AudioSeal detector once and reuse it for every request."""
    return AudioSeal.load_detector("audioseal_detector_16bits")


def _normalize_hex_message(hex_message):
    """Return *hex_message* trimmed, lower-cased and without a leading ``0x``."""
    normalized = hex_message.strip().lower()
    if normalized.startswith("0x"):
        normalized = normalized[2:]
    return normalized


def detect_watermark(audio_file_path, original_hex_message=None):
    """Run the AudioSeal detector on an audio file and optionally compare the decoded message.

    Args:
        audio_file_path: Path of the audio file to analyse (resampled to 16 kHz).
        original_hex_message: Optional expected hex code. The comparison ignores
            case, surrounding whitespace and a leading ``0x``.

    Returns:
        ``(result, detected_hex_message, match_result)``:
        - ``result``: the first value returned by ``detector.detect_watermark``
          (per the AudioSeal docs, the probability that the audio is watermarked);
        - ``detected_hex_message``: the decoded 16 bits as hex;
        - ``match_result``: ``"Match"``, ``"No Match"``, or ``"Not compared"``
          when no expected code was given.

    Raises:
        ValueError: If no audio file path is supplied.
    """
    if not audio_file_path:
        raise ValueError("No audio file was provided for watermark detection.")
    waveform, sample_rate = load_and_resample_audio(audio_file_path, target_sample_rate=16000)
    detector = _load_watermark_detector()
    with torch.no_grad():
        result, message_tensor = detector.detect_watermark(waveform.unsqueeze(0), sample_rate=sample_rate)
    # int() first so bool/float bit tensors still produce '0'/'1' characters.
    binary_message = ''.join(str(int(bit)) for bit in message_tensor[0].tolist())
    detected_hex_message = binary_to_hex(binary_message)
    match_result = "Not compared"
    if original_hex_message and original_hex_message.strip():
        # binary_to_hex returns lowercase. Comparing against an upper-cased input
        # reported "No Match" for every code containing a-f, so normalize both sides.
        is_match = _normalize_hex_message(detected_hex_message) == _normalize_hex_message(original_hex_message)
        match_result = "Match" if is_match else "No Match"
    return result, detected_hex_message, match_result
# Optional custom CSS for the Gradio UI; without a style.css file the UI uses no extra styling.
style_path = Path("style.css")
style = style_path.read_text() if style_path.exists() else ""
def gradio_interface():
    """Build the Gradio UI with a "Watermark Audio" tab and a "Detect Watermark" tab.

    The watermark tab wires ``generate_unique_message`` and ``watermark_audio``.
    The detection tab wires ``detect_watermark``. Custom CSS comes from the
    module-level ``style`` string.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(css=style) as demo:
        with gr.Tab("Watermark Audio"):
            with gr.Column(scale=6):
                # Markdown continuation lines stay at column 0 so no indentation leaks into the text.
                gr.Markdown("""**How to Watermark Your Audio**
This tool embeds a unique, invisible watermark into your audio files to mark them as yours. Follow these steps:
1. **Upload Your Audio**: Choose the audio file you want to watermark.
2. **Generate Unique Message**: Click this button to generate a unique code that will serve as your audio's watermark.
3. **Apply Watermark**: Embed the watermark into your audio file. This process does not alter the audio's quality.
4. **Download Watermarked Audio**: After the watermark is applied, you can download the watermarked audio. It will sound identical to the original but now contains your unique watermark.
5. **View Enhanced ID**: Along with the watermarked audio, you'll get an enhanced ID for additional tracking and verification purposes.
""")
                audio_input_watermark = gr.Audio(label="Upload Audio File for Watermarking", type="filepath")
                unique_message_output = gr.Textbox(label="Unique Message")
                watermarked_audio_output = gr.Audio(label="Watermarked Audio")
                message_output = gr.Textbox(label="Message Used for Watermarking")
                enhanced_id_output = gr.Textbox(label="Enhanced ID")
                generate_message_button = gr.Button("Generate Unique Message")
                watermark_button = gr.Button("Apply Watermark")
                generate_message_button.click(fn=generate_unique_message, inputs=None, outputs=unique_message_output)
                watermark_button.click(
                    fn=watermark_audio,
                    inputs=[audio_input_watermark, unique_message_output],
                    outputs=[watermarked_audio_output, message_output, enhanced_id_output],
                )
        with gr.Tab("Detect Watermark"):
            with gr.Column(scale=6):
                # Step 1's bold label used to be split across two lines, which broke the ** markup.
                gr.Markdown("""**How to Detect a Watermark in Your Audio**
Use this feature to check if an audio file contains a specific watermark. Here's how:
1. **Upload the Audio File**: Select the audio file you want to check.
2. **Enter Original Hex Message**: If you know the hexadecimal code of the watermark, enter it here for a precise search.
3. **Detect Watermark**: Click to analyze the audio for your watermark.
4. **Review Results**: Find out whether your watermark was detected and if the detected code matches your input.
""")
                audio_input_detect_watermark = gr.Audio(label="Upload Audio File for Watermark Detection", type="filepath")
                original_hex_input = gr.Textbox(label="Original Hex Message for Comparison", placeholder="Enter the original hex message here")
                detect_watermark_button = gr.Button("Detect Watermark")
                watermark_detection_output = gr.Textbox(label="Watermark Detection Result")
                detected_message_output = gr.Textbox(label="Detected Hex Message")
                match_result_output = gr.Textbox(label="Match Result")
                detect_watermark_button.click(
                    fn=detect_watermark,
                    inputs=[audio_input_detect_watermark, original_hex_input],
                    outputs=[watermark_detection_output, detected_message_output, match_result_output],
                )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app, then start the Gradio server.
    demo = gradio_interface()
    demo.launch()