imrafarafarafa commited on
Commit
26a5a6b
·
verified ·
1 Parent(s): dc1446d

Upload 3 files

Browse files
Files changed (3) hide show
  1. attribution.py +118 -0
  2. similarity.py +122 -0
  3. training_filenames.txt +10 -0
attribution.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import laion_clap
5
+ import pandas as pd
6
+
7
+ # Set device to GPU if available
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+ def compute_clap_embeddings(audio_dir, clap_model, batch_size=32):
11
+ # Collect all audio file paths in the directory
12
+ audio_files = [os.path.join(audio_dir, fn) for fn in os.listdir(audio_dir)
13
+ if os.path.isfile(os.path.join(audio_dir, fn))]
14
+
15
+ # Add debug print
16
+ print(f"Found {len(audio_files)} files in {audio_dir}")
17
+
18
+ if not audio_files:
19
+ print(f"No files found in directory: {audio_dir}")
20
+ return [], None
21
+
22
+ embeddings_list = []
23
+ filenames_list = []
24
+
25
+ # Process audio files in batches
26
+ for i in range(0, len(audio_files), batch_size):
27
+ batch_files = audio_files[i:i + batch_size]
28
+ with torch.no_grad():
29
+ try:
30
+ # Get embeddings for the batch
31
+ embeddings = clap_model.get_audio_embedding_from_filelist(x=batch_files, use_tensor=True)
32
+ embeddings_list.append(embeddings)
33
+ filenames_list.extend([os.path.basename(f) for f in batch_files])
34
+ except Exception as e:
35
+ print(f"Error processing batch starting at index {i}: {str(e)}")
36
+ print(f"Problematic files: {batch_files}")
37
+ continue
38
+
39
+ if not embeddings_list:
40
+ print("No embeddings were generated")
41
+ return [], None
42
+
43
+ # Concatenate all embeddings
44
+ all_embeddings = torch.cat(embeddings_list, dim=0)
45
+ return filenames_list, all_embeddings
46
+
47
+ # Load CLAP model
48
+ clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=device)
49
+ clap_model.load_ckpt()
50
+ clap_model.eval()
51
+
52
+ # Step 1: Compute embeddings for training tracks
53
+ training_dir = "./training"
54
+ training_filenames, training_embeddings = compute_clap_embeddings(training_dir, clap_model)
55
+ training_embeddings = training_embeddings.to(device)
56
+
57
+ # Step 2: Compute embeddings for test tracks
58
+ test_dir = './test'
59
+ test_filenames, test_embeddings = compute_clap_embeddings(test_dir, clap_model)
60
+
61
+ # Normalize training embeddings
62
+ training_embeddings = torch.nn.functional.normalize(training_embeddings, dim=1)
63
+
64
+ if test_filenames:
65
+ test_embeddings = test_embeddings.to(device)
66
+ test_embeddings = torch.nn.functional.normalize(test_embeddings, dim=1)
67
+
68
+ # Compute similarity matrix (test samples x training samples)
69
+ similarity_matrix = test_embeddings @ training_embeddings.T # Shape: (num_test_samples, num_training_samples)
70
+
71
+ # Convert similarities to attribution scores by normalizing
72
+ attribution_scores = similarity_matrix / similarity_matrix.sum(dim=1, keepdim=True)
73
+
74
+ # Map filenames to attribution scores
75
+ attribution_dict = {}
76
+ for i, test_file in enumerate(test_filenames):
77
+ scores = attribution_scores[i].cpu().numpy()
78
+ attribution_dict[test_file] = dict(zip(training_filenames, scores))
79
+
80
+ # Optional: Save attribution scores to a JSON file
81
+ import json
82
+ with open('attribution_scores.json', 'w') as f:
83
+ json.dump(attribution_dict, f, indent=4)
84
+ else:
85
+ print("No test files found in the directory. Skipping test embeddings computation and similarity calculations.")
86
+
87
+ # Function to save embeddings to CSV
88
+ def save_embeddings_to_csv(filenames, embeddings, csv_filename):
89
+ # Convert embeddings to a list
90
+ embeddings_list = embeddings.cpu().numpy().tolist()
91
+ # Create a DataFrame
92
+ df = pd.DataFrame({
93
+ 'filename': filenames,
94
+ 'embedding': embeddings_list
95
+ })
96
+ # Save to CSV
97
+ df.to_csv(csv_filename, index=False)
98
+
99
+ # Save training embeddings
100
+ save_embeddings_to_csv(training_filenames, training_embeddings, 'training_embeddings.csv')
101
+
102
+ # Optional: Save test embeddings if needed
103
+ # save_embeddings_to_csv(test_filenames, test_embeddings, 'test_embeddings.csv')
104
+
105
+ # Function to save embeddings and filenames
106
+ def save_embeddings(filenames, embeddings, filename_prefix):
107
+ # Save embeddings
108
+ np.save(f'{filename_prefix}_embeddings.npy', embeddings.cpu().numpy())
109
+ # Save filenames
110
+ with open(f'{filename_prefix}_filenames.txt', 'w') as f:
111
+ for item in filenames:
112
+ f.write("%s\n" % item)
113
+
114
+ # Save training embeddings
115
+ save_embeddings(training_filenames, training_embeddings, 'training')
116
+
117
+ # Optional: Save test embeddings if needed
118
+ # save_embeddings(test_filenames, test_embeddings, 'test')
similarity.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import laion_clap
4
+ import pandas as pd
5
+ import ast # To convert string representation of list back to list
6
+ import numpy as np
7
+
8
+ class CLAPSimilarity:
9
+ def __init__(self, training_embeddings_prefix='training', clap_model=None, device=None):
10
+ if device is None:
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ self.device = device
13
+
14
+ if clap_model is None:
15
+ # Load CLAP model
16
+ self.clap_model = laion_clap.CLAP_Module(enable_fusion=True, device=self.device)
17
+ self.clap_model.load_ckpt()
18
+ self.clap_model.eval()
19
+ else:
20
+ self.clap_model = clap_model
21
+
22
+ # Load precomputed training embeddings from files
23
+ self.training_embeddings, self.training_filenames = self.load_embeddings(training_embeddings_prefix)
24
+
25
+ # Normalize training embeddings
26
+ self.training_embeddings = torch.nn.functional.normalize(self.training_embeddings, dim=1)
27
+
28
+ def load_embeddings(self, filename_prefix):
29
+ # Load embeddings
30
+ embeddings = np.load(f'{filename_prefix}_embeddings.npy')
31
+ embeddings = torch.tensor(embeddings, device=self.device)
32
+
33
+ # Load filenames
34
+ with open(f'{filename_prefix}_filenames.txt', 'r') as f:
35
+ filenames = [line.strip() for line in f]
36
+
37
+ return embeddings, filenames
38
+
39
+ def compute_similarity(self, input_data, input_type='audio', max_tracks=0):
40
+ """
41
+ Compute similarity scores between input data and training embeddings.
42
+
43
+ Parameters:
44
+ - input_data: Either a string (text prompt or path to audio file) or a list of strings.
45
+ - input_type: 'audio' or 'text'
46
+ - max_tracks: Maximum number of tracks to include in the results. 0 means all tracks.
47
+
48
+ Returns:
49
+ - similarity_scores: A dictionary mapping training filenames to normalized similarity scores.
50
+ """
51
+ with torch.no_grad():
52
+ if input_type == 'audio':
53
+ # If input_data is a path to an audio file
54
+ if isinstance(input_data, str):
55
+ input_files = [input_data]
56
+ else:
57
+ input_files = input_data
58
+ embeddings = self.clap_model.get_audio_embedding_from_filelist(
59
+ x=input_files, use_tensor=True
60
+ ).to(self.device)
61
+ elif input_type == 'text':
62
+ # If input_data is a text string or list of strings
63
+ if isinstance(input_data, str):
64
+ input_texts = [input_data]
65
+ else:
66
+ input_texts = input_data
67
+ embeddings = self.clap_model.get_text_embedding(
68
+ input_texts, use_tensor=True
69
+ ).to(self.device)
70
+ else:
71
+ raise ValueError("input_type must be 'audio' or 'text'")
72
+
73
+
74
+ # Normalize embeddings
75
+ embeddings = torch.nn.functional.normalize(embeddings, dim=1)
76
+
77
+ # Compute similarity scores
78
+ similarity_matrix = embeddings @ self.training_embeddings.T # (input_samples, training_samples)
79
+
80
+ # For single input, process accordingly
81
+ if similarity_matrix.shape[0] == 1:
82
+ similarities = similarity_matrix[0]
83
+ similarities = similarities.cpu().numpy()
84
+ # Shift to positive values
85
+ similarities = similarities - similarities.min()
86
+ # Normalize scores to sum to 100
87
+ total = similarities.sum()
88
+ if total > 0:
89
+ normalized_scores = (similarities / total) * 100
90
+ else:
91
+ normalized_scores = similarities
92
+ # Create a dictionary of filenames and scores
93
+ similarity_scores = dict(zip(self.training_filenames, normalized_scores))
94
+ # Sort the scores in descending order
95
+ similarity_scores = dict(sorted(similarity_scores.items(), key=lambda item: item[1], reverse=True))
96
+ # Limit the number of tracks if max_tracks is specified
97
+ if max_tracks > 0:
98
+ similarity_scores = dict(list(similarity_scores.items())[:max_tracks])
99
+ else:
100
+ # For multiple inputs, return a list of dictionaries
101
+ similarity_scores = []
102
+ for i in range(similarity_matrix.shape[0]):
103
+ similarities = similarity_matrix[i]
104
+ similarities = similarities.cpu().numpy()
105
+ # Shift to positive values
106
+ similarities = similarities - similarities.min()
107
+ # Normalize scores to sum to 100
108
+ total = similarities.sum()
109
+ if total > 0:
110
+ normalized_scores = (similarities / total) * 100
111
+ else:
112
+ normalized_scores = similarities
113
+ # Create a dictionary of filenames and scores
114
+ scores = dict(zip(self.training_filenames, normalized_scores))
115
+ # Sort the scores in descending order
116
+ scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
117
+ # Limit the number of tracks if max_tracks is specified
118
+ if max_tracks > 0:
119
+ scores = dict(list(scores.items())[:max_tracks])
120
+ similarity_scores.append(scores)
121
+
122
+ return similarity_scores
training_filenames.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ES_BUZZER - Cushy.wav
2
+ ES_Life (Instrumental Version) - Northside.wav
3
+ ES_A Christmas Dance - Arthur Benson.wav
4
+ ES_Droppin Buckets (Instrumental Version) - Nyck Caution.wav
5
+ ES_EPIC FIGHT SONG NO. 1 - Def Lev.wav
6
+ ES_JACKHAMMER - Cushy.wav
7
+ ES_Sunday Blues - Hara Noda.wav
8
+ ES_Christmas Magic - Megan Wofford.wav
9
+ ES_Moving Up - Origo.wav
10
+ ES_Breeze - Basixx.wav