# CLAP similarity scoring: rank precomputed training-track embeddings against audio/text queries.
import os
import torch
import laion_clap
import pandas as pd
import ast  # To convert string representation of list back to list
import numpy as np
class CLAPSimilarity:
    """Rank precomputed training-track CLAP embeddings by similarity to a query.

    Loads (or accepts) a LAION-CLAP model plus a set of precomputed, file-backed
    training embeddings, L2-normalizes everything once, and scores queries
    (audio file paths or text prompts) against the training set via cosine
    similarity expressed as a plain matrix product of unit vectors.
    """

    def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
        """Initialize the model and load the training embedding bank.

        Parameters:
        - training_embeddings_prefix: prefix for `{prefix}_embeddings.npy` and
          `{prefix}_filenames.txt` on disk.
        - clap_model: an already-initialized CLAP module to reuse; if None, a
          fused LAION-CLAP checkpoint is loaded fresh.
        - device: torch device; defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        if clap_model is None:
            # Only pay the checkpoint-load cost when no model was supplied.
            self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
            self.clap_model.load_ckpt()
            self.clap_model.eval()
        else:
            self.clap_model = clap_model
        # Load precomputed training embeddings and L2-normalize them once, so a
        # plain matrix product against normalized queries yields cosine similarity.
        self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)
        self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)

    def load_embeddings(self, filename_prefix):
        """Load `{prefix}_embeddings.npy` and `{prefix}_filenames.txt` from disk.

        Returns (embeddings, filenames): a torch tensor on `self.device` of
        shape (num_tracks, dim) and the parallel list of track filenames
        (one per line, whitespace-stripped).
        """
        embeddings = np.load(f'{filename_prefix}_embeddings.npy')
        embeddings = torch.tensor(embeddings, device=self.device)
        with open(f'{filename_prefix}_filenames.txt', 'r') as f:
            filenames = [line.strip() for line in f]
        return embeddings, filenames

    def _normalized_scores(self, similarities, max_tracks):
        """Turn one row of similarities into a sorted {filename: score} dict.

        Scores are shifted to be non-negative (subtract the row minimum) and
        scaled to sum to 100; when every similarity is identical the shifted
        row is all zeros and is returned unscaled. Results are sorted in
        descending score order; `max_tracks > 0` truncates to the top entries.
        """
        similarities = similarities.cpu().numpy()
        # Shift to non-negative values so the percentages are well defined.
        similarities = similarities - similarities.min()
        total = similarities.sum()
        if total > 0:
            normalized = (similarities / total) * 100
        else:
            # Degenerate case: all scores equal after shifting — keep the zeros.
            normalized = similarities
        scores = dict(zip(self.training_filenames, normalized))
        scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
        if max_tracks > 0:
            scores = dict(list(scores.items())[:max_tracks])
        return scores

    def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
        """Compute similarity scores between input data and training embeddings.

        Parameters:
        - input_data: a string (text prompt or path to an audio file) or a
          list of strings.
        - input_type: 'audio' or 'text'.
        - max_tracks: maximum number of tracks per result; 0 means all tracks.

        Returns:
        - For a single query embedding: a dict mapping training filenames to
          normalized similarity scores (descending).
        - For multiple queries: a list of such dicts, one per input.

        Raises:
        - ValueError: if `input_type` is neither 'audio' nor 'text'.
        """
        with torch.no_grad():
            if input_type == 'audio':
                # Accept a single path or a list of paths.
                input_files = [input_data] if isinstance(input_data, str) else input_data
                embeddings = self.clap_model.get_audio_embedding_from_filelist(
                    x=input_files, use_tensor=True
                ).to(self.device)
            elif input_type == 'text':
                # Accept a single prompt or a list of prompts.
                input_texts = [input_data] if isinstance(input_data, str) else input_data
                embeddings = self.clap_model.get_text_embedding(
                    input_texts, use_tensor=True
                ).to(self.device)
            else:
                raise ValueError("input_type must be 'audio' or 'text'")
            # Unit-normalize queries so the matrix product is cosine similarity.
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)
            similarity_matrix = embeddings @ self.training_embeddings.T  # (queries, tracks)
            if similarity_matrix.shape[0] == 1:
                # Single query: return one dict rather than a one-element list.
                return self._normalized_scores(similarity_matrix[0], max_tracks)
            return [self._normalized_scores(similarity_matrix[i], max_tracks)
                    for i in range(similarity_matrix.shape[0])]