import os
import torch
import laion_clap
import pandas as pd
import ast  # To convert string representation of list back to list
import numpy as np

# NOTE(review): os, pandas, and ast appear unused in this file — confirm
# against other modules before removing.


class CLAPSimilarity:
    """Rank precomputed training-track embeddings by CLAP similarity.

    Loads a set of precomputed (training) audio embeddings plus their
    filenames from disk, keeps them L2-normalized on the chosen device, and
    scores new audio files or text prompts against them with cosine
    similarity, returning per-track percentage scores.
    """

    def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
        """Set up the CLAP model and load the precomputed training embeddings.

        Parameters:
        - training_embeddings_prefix: prefix of the two files
          '<prefix>_embeddings.npy' and '<prefix>_filenames.txt'.
        - clap_model: an already-initialized CLAP module to reuse; if None,
          a fusion-enabled model is created and its default checkpoint loaded.
        - device: torch device; defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        if clap_model is None:
            # Load CLAP model (eval mode: inference only, no dropout/grad updates)
            self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
            self.clap_model.load_ckpt()
            self.clap_model.eval()
        else:
            self.clap_model = clap_model
        # Load precomputed training embeddings from files
        self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)
        # Normalize training embeddings so a plain dot product is cosine similarity
        self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)

    def load_embeddings(self, filename_prefix):
        """Load '<prefix>_embeddings.npy' and '<prefix>_filenames.txt'.

        Returns a (embeddings, filenames) pair: a 2-D float tensor on
        self.device and a list of filename strings aligned row-for-row
        with the embeddings.
        """
        # Load embeddings
        embeddings = np.load(f'{filename_prefix}_embeddings.npy')
        embeddings = torch.tensor(embeddings, device=self.device)
        # Load filenames (one per line, aligned with the embedding rows)
        with open(f'{filename_prefix}_filenames.txt', 'r') as f:
            filenames = [line.strip() for line in f]
        return embeddings, filenames

    @staticmethod
    def _rank_scores(similarities, filenames, max_tracks):
        """Turn one row of raw similarities into a ranked {filename: score} dict.

        Parameters:
        - similarities: 1-D numpy array of raw cosine similarities.
        - filenames: training filenames aligned with `similarities`.
        - max_tracks: keep only the top-N entries; 0 keeps all.

        Scores are shifted to be non-negative, then normalized to sum to 100
        (a degenerate all-equal row stays all-zero), sorted descending.
        """
        # Shift to positive values so the row can be renormalized as percentages
        similarities = similarities - similarities.min()
        # Normalize scores to sum to 100
        total = similarities.sum()
        if total > 0:
            normalized_scores = (similarities / total) * 100
        else:
            normalized_scores = similarities
        # Create a dictionary of filenames and scores
        scores = dict(zip(filenames, normalized_scores))
        # Sort the scores in descending order
        scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
        # Limit the number of tracks if max_tracks is specified
        if max_tracks > 0:
            scores = dict(list(scores.items())[:max_tracks])
        return scores

    def _embed_input(self, input_data, input_type):
        """Embed audio file path(s) or text prompt(s) with the CLAP model.

        A lone string is wrapped in a one-element list since the CLAP APIs
        expect batches. Raises ValueError for an unknown input_type.
        """
        inputs = [input_data] if isinstance(input_data, str) else input_data
        if input_type == 'audio':
            return self.clap_model.get_audio_embedding_from_filelist(
                x=inputs, use_tensor=True
            ).to(self.device)
        if input_type == 'text':
            return self.clap_model.get_text_embedding(
                inputs, use_tensor=True
            ).to(self.device)
        raise ValueError("input_type must be 'audio' or 'text'")

    def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
        """
        Compute similarity scores between input data and training embeddings.

        Parameters:
        - input_data: Either a string (text prompt or path to audio file) or a list of strings.
        - input_type: 'audio' or 'text'
        - max_tracks: Maximum number of tracks to include in the results. 0 means all tracks.

        Returns:
        - For a single input (a lone string, or a list with one element): a
          dictionary mapping training filenames to normalized similarity
          scores, sorted descending.
        - For multiple inputs: a list of such dictionaries, one per input.
        """
        with torch.no_grad():
            embeddings = self._embed_input(input_data, input_type)
            # Normalize embeddings so the matmul below yields cosine similarity
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)
            # Compute similarity scores: (input_samples, training_samples)
            similarity_matrix = embeddings @ self.training_embeddings.T
            ranked_rows = [
                self._rank_scores(
                    similarity_matrix[i].cpu().numpy(),
                    self.training_filenames,
                    max_tracks,
                )
                for i in range(similarity_matrix.shape[0])
            ]
            # Single input unwraps to a plain dict; multiple inputs stay a list
            if similarity_matrix.shape[0] == 1:
                return ranked_rows[0]
            return ranked_rows