Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import faiss | |
| import sqlite3 | |
| import torch | |
| import librosa | |
| import nemo.collections.asr as nemo_asr | |
# Pretrained NeMo speaker-embedding model used by get_embedding below.
# It is only ever run forward here, so its parameters are frozen up front.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("rimelabs/rimecaster")
speaker_model.freeze()
def _to_float_audio(audio: np.ndarray) -> np.ndarray:
    """Return *audio* as floating point, rescaling integer PCM to about [-1, 1].

    librosa only accepts floating-point audio. Signed integers are divided by
    the magnitude of the dtype's minimum (32768 for int16). Unsigned integers
    are centred on the dtype's midpoint. Floating-point input is returned
    unchanged.
    """
    if np.issubdtype(audio.dtype, np.floating):
        return audio
    if np.issubdtype(audio.dtype, np.signedinteger):
        scale = np.float32(-np.iinfo(audio.dtype).min)
        return audio.astype(np.float32) / scale
    if np.issubdtype(audio.dtype, np.unsignedinteger):
        mid = np.float32((np.iinfo(audio.dtype).max + 1) / 2)
        return (audio.astype(np.float32) - mid) / mid
    return audio.astype(np.float32)


def get_embedding(row: dict) -> np.ndarray | None:
    """Compute a speaker embedding for one audio record.

    Args:
        row: Mapping with ``row["audio"]["array"]`` (the samples) and
            ``row["audio"]["sampling_rate"]`` (their sample rate in Hz).
            The record itself is not modified.

    Returns:
        The embedding returned by ``speaker_model`` as a NumPy array with a
        leading batch dimension of 1. Returns ``None`` if the audio could not
        be resampled to 16 kHz or is empty after resampling.
    """
    audio_array = _to_float_audio(np.asarray(row["audio"]["array"]))

    # Down-mix multi-channel audio. librosa.to_mono averages over every axis
    # except the last one. The transpose therefore assumes the array is laid
    # out as (samples, channels) -- TODO confirm against the data source.
    if audio_array.ndim > 1:
        audio_array = librosa.to_mono(audio_array.T)

    # The speaker model is fed 16 kHz audio.
    try:
        audio_resampled = librosa.resample(
            audio_array, orig_sr=row["audio"]["sampling_rate"], target_sr=16_000
        )
    except Exception as e:
        print(f"Error resampling audio: {e}. Skipping embedding for row.")
        return None

    audio_length = audio_resampled.shape[0]
    if audio_length == 0:
        print("Empty audio. Skipping embedding for row.")
        return None

    device = speaker_model.device
    audio_signal = torch.tensor(
        audio_resampled[np.newaxis, :], device=device, dtype=torch.float32
    )  # add batch dimension
    audio_signal_len = torch.tensor([audio_length], device=device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.inference_mode():
        _, emb = speaker_model.forward(
            input_signal=audio_signal, input_signal_length=audio_signal_len
        )
    return emb.detach().cpu().numpy()
def get_embedding_from_array(sample_rate: int, audio_array: np.ndarray):
    """Embed raw audio given as a sample rate plus a sample array.

    This is a thin adapter. It packs the pair into the ``{"audio": {...}}``
    record shape that ``get_embedding`` expects and returns that function's
    result unchanged.
    """
    return get_embedding(
        {"audio": {"sampling_rate": sample_rate, "array": audio_array}}
    )
| class AudioEmbeddingSystem: | |
| def __init__( | |
| self, db_path="audio_db.sqlite", index_path="audio_faiss.index", vector_dim=768 | |
| ): | |
| """ | |
| Initialize the audio embedding system. | |
| Args: | |
| model_name: HuggingFace model to use for embeddings | |
| db_path: Path to SQLite database | |
| index_path: Path to save FAISS index | |
| vector_dim: Dimension of embedding vectors | |
| use_quantization: Whether to use vector quantization (reduces size) | |
| """ | |
| self.db_path = db_path | |
| self.index_path = index_path | |
| self.vector_dim = vector_dim | |
| self._init_db() | |
| if os.path.exists(index_path): | |
| self.index = faiss.read_index(index_path) | |
| else: | |
| self.index = faiss.IndexFlatL2(vector_dim) | |
| def _init_db(self): | |
| """Initialize SQLite database with required tables""" | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS audio_files ( | |
| id INTEGER PRIMARY KEY, | |
| file_path TEXT UNIQUE, | |
| vector_id INTEGER | |
| ) | |
| """) | |
| conn.commit() | |
| conn.close() | |
| def extract_embedding(self, row: dict): | |
| """Extract embedding from audio file""" | |
| return get_embedding(row) | |
| def add_audio(self, row): | |
| """Add audio file to the database and index""" | |
| embedding = self.extract_embedding(row) | |
| embedding_normalized = embedding.reshape(1, -1).astype(np.float32) | |
| current_index_size = self.index.ntotal | |
| self.index.add(embedding_normalized) | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| "INSERT INTO audio_files (file_path, vector_id) VALUES (?, ?)", | |
| (row["path"], current_index_size), | |
| ) | |
| conn.commit() | |
| conn.close() | |
| faiss.write_index(self.index, self.index_path) | |
| return current_index_size | |
| def search(self, row: dict | tuple, top_k=5, least_similar=False): | |
| """ | |
| Search for similar audio files. | |
| Either provide query_audio (path to audio file) or query_embedding (numpy array) | |
| """ | |
| if isinstance(row, dict): | |
| query_embedding = self.extract_embedding(row) | |
| else: | |
| query_embedding = get_embedding_from_array(*row) | |
| query_embedding = query_embedding.reshape(1, -1).astype(np.float32) | |
| if least_similar: | |
| query_embedding = -1 * query_embedding | |
| distances, indices = self.index.search(query_embedding, top_k) | |
| conn = sqlite3.connect(self.db_path) | |
| cursor = conn.cursor() | |
| results = [] | |
| for i, idx in enumerate(indices[0]): | |
| cursor.execute( | |
| "SELECT file_path FROM audio_files WHERE vector_id = ?", | |
| (int(idx),), | |
| ) | |
| row = cursor.fetchone() | |
| if row: | |
| results.append( | |
| { | |
| "path": row[0], | |
| "distance": float(distances[0][i]), | |
| "vector_id": int(idx), | |
| } | |
| ) | |
| conn.close() | |
| return results | |