File size: 5,559 Bytes
26a5a6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import laion_clap
import pandas as pd
import ast  # To convert string representation of list back to list
import numpy as np

class CLAPSimilarity:
    """Rank precomputed training-track embeddings against audio or text queries.

    The training embeddings are loaded once from
    ``<prefix>_embeddings.npy`` / ``<prefix>_filenames.txt`` and L2-normalised,
    so a dot product with an L2-normalised query embedding is a cosine
    similarity.
    """

    def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
        """
        Set up the CLAP model and load the training embeddings.

        Parameters:
        - training_embeddings_prefix: Path prefix of the ``_embeddings.npy`` and
          ``_filenames.txt`` files to load.
        - clap_model: An already constructed model to reuse. When None, a
          ``laion_clap.CLAP_Module`` is created with fusion enabled,
          ``load_ckpt()`` is called without arguments and the model is put in
          eval mode.
        - device: Torch device to use. When None, CUDA is used if available,
          otherwise the CPU.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        if clap_model is None:
            # Load CLAP model
            self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
            self.clap_model.load_ckpt()
            self.clap_model.eval()
        else:
            self.clap_model = clap_model

        # Load precomputed training embeddings from files
        self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)

        # Normalize once here so every query only needs a matrix product.
        self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)

    def load_embeddings(self, filename_prefix):
        """
        Load an embedding matrix and its matching list of filenames.

        Parameters:
        - filename_prefix: Path prefix; ``_embeddings.npy`` and
          ``_filenames.txt`` are appended to it.

        Returns:
        - embeddings: Tensor on ``self.device`` holding the array stored in the
          ``.npy`` file (its dtype is kept as stored).
        - filenames: List with one whitespace-stripped entry per line of the
          text file.
        """
        # Load embeddings
        embeddings = np.load(f'{filename_prefix}_embeddings.npy')
        embeddings = torch.tensor(embeddings, device=self.device)

        # Load filenames
        with open(f'{filename_prefix}_filenames.txt', 'r') as f:
            filenames = [line.strip() for line in f]

        return embeddings, filenames

    def _rank_row(self, similarities, max_tracks):
        """
        Convert one row of raw similarities into a ranked score dictionary.

        Parameters:
        - similarities: 1-D NumPy array with one similarity per training entry.
        - max_tracks: Keep only the top ``max_tracks`` entries when > 0.

        Returns:
        - A dictionary mapping training filenames to scores, ordered from
          highest to lowest score.
        """
        # Shift so the smallest similarity becomes 0 and all scores are >= 0.
        shifted = similarities - similarities.min()
        total = shifted.sum()
        # Scale to sum to 100; if every similarity is equal the total is 0 and
        # the (all-zero) shifted values are returned unchanged.
        normalized_scores = (shifted / total) * 100 if total > 0 else shifted

        # Build the mapping first so duplicate filenames collapse exactly as a
        # plain dict would, then order by score (stable for ties).
        scores = dict(zip(self.training_filenames, normalized_scores))
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        if max_tracks > 0:
            ranked = ranked[:max_tracks]
        return dict(ranked)

    def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
        """
        Compute similarity scores between input data and training embeddings.

        Parameters:
        - input_data: Either a string (text prompt or path to audio file) or a list of strings.
        - input_type: 'audio' or 'text'
        - max_tracks: Maximum number of tracks to include in the results. 0 means all tracks.

        Returns:
        - similarity_scores: A dictionary mapping training filenames to normalized
          similarity scores (sorted in descending order) when the model yields
          exactly one embedding; otherwise a list of such dictionaries, one per
          input.

        Raises:
        - ValueError: If input_type is neither 'audio' nor 'text'.
        """
        if input_type not in ('audio', 'text'):
            raise ValueError("input_type must be 'audio' or 'text'")

        # A bare string is treated as a single-item batch.
        inputs = [input_data] if isinstance(input_data, str) else input_data

        with torch.no_grad():
            if input_type == 'audio':
                embeddings = self.clap_model.get_audio_embedding_from_filelist(
                    x=inputs, use_tensor=True
                )
            else:
                embeddings = self.clap_model.get_text_embedding(
                    inputs, use_tensor=True
                )

            # Match the dtype of the stored training matrix as well as the
            # device: the .npy file may hold e.g. float64 while the model
            # returns float32, and a mixed-dtype matmul raises.
            embeddings = embeddings.to(device=self.device, dtype=self.training_embeddings.dtype)

            # Normalize embeddings
            embeddings = torch.nn.functional.normalize(embeddings, dim=1)

            # (input_samples, training_samples); copied to the CPU once.
            similarity_matrix = (embeddings @ self.training_embeddings.T).cpu().numpy()

        ranked = [self._rank_row(row, max_tracks) for row in similarity_matrix]

        # Single input: return the dictionary itself rather than a 1-item list.
        if len(ranked) == 1:
            return ranked[0]
        return ranked