# clap/similarity.py
import torch
import laion_clap
import numpy as np
class CLAPSimilarity:
def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
if clap_model is None:
# Load CLAP model
self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
self.clap_model.load_ckpt()
self.clap_model.eval()
else:
self.clap_model = clap_model
# Load precomputed training embeddings from files
self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)
        # L2-normalize the training embeddings so that dot products with
        # normalized input embeddings yield cosine similarities
        self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)
def load_embeddings(self, filename_prefix):
# Load embeddings
embeddings = np.load(f'{filename_prefix}_embeddings.npy')
embeddings = torch.tensor(embeddings, device=self.device)
# Load filenames
with open(f'{filename_prefix}_filenames.txt', 'r') as f:
filenames = [line.strip() for line in f]
return embeddings, filenames
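    # The loader above expects `{prefix}_embeddings.npy` and `{prefix}_filenames.txt`
    # to already exist. A minimal sketch of how they could be produced with the same
    # CLAP model (the helper name and save format are assumptions, not part of this
    # upload):
    #
    #     def build_embeddings(clap_model, audio_files, filename_prefix):
    #         with torch.no_grad():
    #             emb = clap_model.get_audio_embedding_from_filelist(
    #                 x=audio_files, use_tensor=True
    #             )
    #         np.save(f'{filename_prefix}_embeddings.npy', emb.cpu().numpy())
    #         with open(f'{filename_prefix}_filenames.txt', 'w') as f:
    #             f.write('\n'.join(audio_files))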
def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
"""
Compute similarity scores between input data and training embeddings.
Parameters:
- input_data: Either a string (text prompt or path to audio file) or a list of strings.
- input_type: 'audio' or 'text'
- max_tracks: Maximum number of tracks to include in the results. 0 means all tracks.
Returns:
- similarity_scores: A dictionary mapping training filenames to normalized similarity scores.
"""
with torch.no_grad():
if input_type == 'audio':
# If input_data is a path to an audio file
if isinstance(input_data, str):
input_files = [input_data]
else:
input_files = input_data
embeddings = self.clap_model.get_audio_embedding_from_filelist(
x=input_files, use_tensor=True
).to(self.device)
elif input_type == 'text':
# If input_data is a text string or list of strings
if isinstance(input_data, str):
input_texts = [input_data]
else:
input_texts = input_data
embeddings = self.clap_model.get_text_embedding(
input_texts, use_tensor=True
).to(self.device)
else:
raise ValueError("input_type must be 'audio' or 'text'")
            # L2-normalize the input embeddings; since the training embeddings are
            # also normalized, the matrix product below gives cosine similarities
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)
            # Compute similarity scores, shape (input_samples, training_samples)
            similarity_matrix = embeddings @ self.training_embeddings.T
            # Convert one row of the similarity matrix into a ranked score dict
            def row_to_scores(similarities):
                similarities = similarities.cpu().numpy()
                # Shift to non-negative values so the scores can be normalized
                similarities = similarities - similarities.min()
                # Normalize scores to sum to 100
                total = similarities.sum()
                if total > 0:
                    normalized_scores = (similarities / total) * 100
                else:
                    normalized_scores = similarities
                # Map training filenames to scores, sorted in descending order
                scores = dict(zip(self.training_filenames, normalized_scores))
                scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
                # Limit the number of tracks if max_tracks is specified
                if max_tracks > 0:
                    scores = dict(list(scores.items())[:max_tracks])
                return scores

            # A single input returns one dict; multiple inputs return a list of dicts
            if similarity_matrix.shape[0] == 1:
                similarity_scores = row_to_scores(similarity_matrix[0])
            else:
                similarity_scores = [
                    row_to_scores(similarity_matrix[i])
                    for i in range(similarity_matrix.shape[0])
                ]
        return similarity_scores
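
# A minimal usage sketch, assuming `training_embeddings.npy` and
# `training_filenames.txt` exist next to this script (the prompt and the
# prefix value are placeholders, not part of the original upload):
if __name__ == '__main__':
    similarity = CLAPSimilarity(training_embeddings_prefix='training')
    # Rank training tracks against a text prompt, keeping the top 5
    scores = similarity.compute_similarity(
        'an upbeat electronic track with heavy bass',
        input_type='text',
        max_tracks=5,
    )
    for filename, score in scores.items():
        print(f'{filename}: {score:.2f}')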