# Compute CLAP audio embeddings for training/test tracks and derive attribution scores.
import json
import os

import numpy as np
import pandas as pd
import torch

import laion_clap
# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def compute_clap_embeddings(audio_dir, clap_model, batch_size=32):
    """Embed every regular file in *audio_dir* with the given CLAP model.

    Files are processed in chunks of ``batch_size``. A batch that raises is
    logged and skipped so one bad file cannot abort the whole run.

    Args:
        audio_dir: directory containing the audio files to embed.
        clap_model: loaded CLAP model exposing
            ``get_audio_embedding_from_filelist``.
        batch_size: number of files sent to the model per call.

    Returns:
        ``(filenames, embeddings)`` — basenames of the successfully embedded
        files and a tensor of shape ``(num_embedded, dim)``, or ``([], None)``
        when the directory is empty or every batch failed.
    """
    audio_files = []
    for entry in os.listdir(audio_dir):
        path = os.path.join(audio_dir, entry)
        if os.path.isfile(path):
            audio_files.append(path)

    print(f"Found {len(audio_files)} files in {audio_dir}")
    if not audio_files:
        print(f"No files found in directory: {audio_dir}")
        return [], None

    batch_embeddings = []
    embedded_names = []
    for start in range(0, len(audio_files), batch_size):
        batch = audio_files[start:start + batch_size]
        with torch.no_grad():
            try:
                # One model call per batch; best-effort — skip on failure.
                emb = clap_model.get_audio_embedding_from_filelist(x=batch, use_tensor=True)
            except Exception as e:
                print(f"Error processing batch starting at index {start}: {str(e)}")
                print(f"Problematic files: {batch}")
                continue
            batch_embeddings.append(emb)
            embedded_names.extend(os.path.basename(p) for p in batch)

    if not batch_embeddings:
        print("No embeddings were generated")
        return [], None

    return embedded_names, torch.cat(batch_embeddings, dim=0)
# Load the CLAP model once and switch it to inference mode.
clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=device)
clap_model.load_ckpt()  # loads the default pretrained checkpoint
clap_model.eval()

# Step 1: Compute embeddings for training tracks.
training_dir = "./training"
training_filenames, training_embeddings = compute_clap_embeddings(training_dir, clap_model)

# BUG FIX: compute_clap_embeddings returns ([], None) when the directory is
# empty or every batch failed; the original code then crashed on .to(device).
if training_embeddings is None:
    raise SystemExit(f"No training embeddings could be computed from {training_dir}")

# L2-normalize so the dot products below are cosine similarities.
training_embeddings = torch.nn.functional.normalize(training_embeddings.to(device), dim=1)

# Step 2: Compute embeddings for test tracks.
test_dir = './test'
test_filenames, test_embeddings = compute_clap_embeddings(test_dir, clap_model)

if test_filenames:
    test_embeddings = torch.nn.functional.normalize(test_embeddings.to(device), dim=1)
    # Cosine similarity of every test track against every training track.
    similarity_matrix = test_embeddings @ training_embeddings.T  # (num_test, num_train)
    # Normalize each row to sum to 1 so scores read as attribution weights.
    attribution_scores = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True)

    # Map: test filename -> {training filename -> attribution score}.
    attribution_dict = {}
    for row, test_file in enumerate(test_filenames):
        scores = attribution_scores[row].cpu().numpy()
        # BUG FIX: json.dump cannot serialize numpy float32 scalars — convert
        # each score to a native Python float before building the dict.
        attribution_dict[test_file] = {
            name: float(score) for name, score in zip(training_filenames, scores)
        }

    # Save attribution scores to a JSON file.
    with open('attribution_scores.json', 'w') as f:
        json.dump(attribution_dict, f, indent=4)
else:
    print("No test files found in the directory. Skipping test embeddings computation and similarity calculations.")
def save_embeddings_to_csv(filenames, embeddings, csv_filename):
    """Write one row per track — (filename, embedding-as-list) — to *csv_filename*.

    Args:
        filenames: list of track names, one per embedding row.
        embeddings: 2-D tensor of embeddings (any device; moved to CPU here).
        csv_filename: destination CSV path.
    """
    # .cpu() first so .numpy() works regardless of the tensor's device.
    rows = {
        'filename': filenames,
        'embedding': embeddings.cpu().numpy().tolist(),
    }
    pd.DataFrame(rows).to_csv(csv_filename, index=False)
# Persist the (normalized) training embeddings with their filenames as CSV.
save_embeddings_to_csv(training_filenames, training_embeddings, 'training_embeddings.csv')
# Optional: save test embeddings as well
# save_embeddings_to_csv(test_filenames, test_embeddings, 'test_embeddings.csv')
def save_embeddings(filenames, embeddings, filename_prefix):
    """Save embeddings and filenames as a .npy / .txt pair.

    Writes ``<prefix>_embeddings.npy`` (NumPy array) and
    ``<prefix>_filenames.txt`` (one basename per line, same row order).

    Args:
        filenames: list of track names, one per embedding row.
        embeddings: tensor of embeddings (any device; moved to CPU here).
        filename_prefix: path prefix for the two output files.
    """
    np.save(f'{filename_prefix}_embeddings.npy', embeddings.cpu().numpy())
    with open(f'{filename_prefix}_filenames.txt', 'w') as f:
        f.write("".join("%s\n" % name for name in filenames))
# Also persist the training embeddings in NumPy form (faster reload than CSV).
save_embeddings(training_filenames, training_embeddings, 'training')
# Optional: save test embeddings as well
# save_embeddings(test_filenames, test_embeddings, 'test')