import torch
import laion_clap
import numpy as np
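

# Wraps a LAION-CLAP model together with a bank of precomputed training-track
# embeddings and scores how similar a text prompt or an audio file is to each
# training track.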
class CLAPSimilarity:
    def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        if clap_model is None:
            # Load the CLAP model and its default pretrained checkpoint
            self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
            self.clap_model.load_ckpt()
            self.clap_model.eval()
        else:
            self.clap_model = clap_model

        # Load precomputed training embeddings from files
        self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)
        # Normalize training embeddings so cosine similarity reduces to a dot product
        self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)

    def load_embeddings(self, filename_prefix):
        # Load embeddings
        embeddings = np.load(f'{filename_prefix}_embeddings.npy')
        embeddings = torch.tensor(embeddings, device=self.device)
        # Load filenames
        with open(f'{filename_prefix}_filenames.txt', 'r') as f:
            filenames = [line.strip() for line in f]
        return embeddings, filenames
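
    # Expected on-disk format (inferred from the loading code above, assumed
    # rather than documented): '{prefix}_embeddings.npy' holds a float array of
    # shape (num_tracks, embedding_dim) produced by the CLAP audio encoder, and
    # '{prefix}_filenames.txt' lists one track filename per line, in the same
    # row order as the embedding matrix.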

    def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
        """
        Compute similarity scores between input data and the training embeddings.

        Parameters:
        - input_data: Either a string (text prompt or path to an audio file) or a list of strings.
        - input_type: 'audio' or 'text'.
        - max_tracks: Maximum number of tracks to include in the results; 0 means all tracks.

        Returns:
        - similarity_scores: A dictionary mapping training filenames to normalized similarity
          scores (a list of such dictionaries when multiple inputs are given).
        """
        with torch.no_grad():
            if input_type == 'audio':
                # input_data is a path to an audio file, or a list of paths
                if isinstance(input_data, str):
                    input_files = [input_data]
                else:
                    input_files = input_data
                embeddings = self.clap_model.get_audio_embedding_from_filelist(
                    x=input_files, use_tensor=True
                ).to(self.device)
            elif input_type == 'text':
                # input_data is a text prompt, or a list of prompts
                if isinstance(input_data, str):
                    input_texts = [input_data]
                else:
                    input_texts = input_data
                embeddings = self.clap_model.get_text_embedding(
                    input_texts, use_tensor=True
                ).to(self.device)
            else:
                raise ValueError("input_type must be 'audio' or 'text'")

            # Normalize embeddings
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)

            # Cosine similarities against every training track: (input_samples, training_samples)
            similarity_matrix = embeddings @ self.training_embeddings.T

            if similarity_matrix.shape[0] == 1:
                # Single input: return a single dictionary of scores
                similarities = similarity_matrix[0].cpu().numpy()
                # Shift to positive values
                similarities = similarities - similarities.min()
                # Normalize scores to sum to 100
                total = similarities.sum()
                if total > 0:
                    normalized_scores = (similarities / total) * 100
                else:
                    normalized_scores = similarities
                # Map filenames to scores, sorted in descending order
                similarity_scores = dict(zip(self.training_filenames, normalized_scores))
                similarity_scores = dict(sorted(similarity_scores.items(), key=lambda item: item[1], reverse=True))
                # Limit the number of tracks if max_tracks is specified
                if max_tracks > 0:
                    similarity_scores = dict(list(similarity_scores.items())[:max_tracks])
            else:
                # Multiple inputs: return a list of dictionaries, one per input
                similarity_scores = []
                for i in range(similarity_matrix.shape[0]):
                    similarities = similarity_matrix[i].cpu().numpy()
                    # Shift to positive values
                    similarities = similarities - similarities.min()
                    # Normalize scores to sum to 100
                    total = similarities.sum()
                    if total > 0:
                        normalized_scores = (similarities / total) * 100
                    else:
                        normalized_scores = similarities
                    # Map filenames to scores, sorted in descending order
                    scores = dict(zip(self.training_filenames, normalized_scores))
                    scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
                    # Limit the number of tracks if max_tracks is specified
                    if max_tracks > 0:
                        scores = dict(list(scores.items())[:max_tracks])
                    similarity_scores.append(scores)

        return similarity_scores
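

# ---------------------------------------------------------------------------
# The two blocks below are illustrative sketches, not part of the original
# class: a hypothetical helper showing one way the embedding files consumed by
# CLAPSimilarity could be produced, and a minimal usage example. The file name
# 'query.wav' and the prompt text are placeholders.

def build_training_embeddings(audio_files, prefix='training', device=None):
    """Hypothetical helper: embed a list of audio files with CLAP and write the
    '{prefix}_embeddings.npy' / '{prefix}_filenames.txt' pair that
    CLAPSimilarity.load_embeddings expects."""
    model = laion_clap.CLAP_Module(enable_fusion=True, device=device)
    model.load_ckpt()
    model.eval()
    with torch.no_grad():
        # use_tensor=False returns a NumPy array; normalization happens at load time
        embeddings = model.get_audio_embedding_from_filelist(x=audio_files, use_tensor=False)
    np.save(f'{prefix}_embeddings.npy', embeddings)
    with open(f'{prefix}_filenames.txt', 'w') as f:
        f.write('\n'.join(audio_files) + '\n')


if __name__ == '__main__':
    # Assumes 'training_embeddings.npy' and 'training_filenames.txt' already
    # exist in the working directory (e.g. produced by the helper above).
    clap_sim = CLAPSimilarity(training_embeddings_prefix='training')

    # Text query: scores are non-negative and sum to roughly 100 over all tracks.
    text_scores = clap_sim.compute_similarity(
        'upbeat electronic track with heavy bass',
        input_type='text',
        max_tracks=5,
    )
    for filename, score in text_scores.items():
        print(f'{filename}: {score:.2f}')

    # Audio query against the same training set.
    audio_scores = clap_sim.compute_similarity('query.wav', input_type='audio', max_tracks=5)
    print(audio_scores)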