Spaces:
Running
Running
import os | |
import tempfile | |
import zipfile | |
from pathlib import Path | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import torch | |
from huggingface_hub import snapshot_download | |
from loguru import logger | |
from pyannote.audio import Inference, Model | |
# Request accelerated Hub downloads via the HF_HUB_ENABLE_HF_TRANSFER flag.
# NOTE(review): huggingface_hub is already imported above this line; confirm
# the library still honours the flag when it is set after the import.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Dataset repository and local locations of the precomputed artefacts.
HF_REPO_ID = "litagin/galgame_voice_samples"
RESNET34_ROOT = Path("./embeddings")
RESNET34_DIM = 256
AUDIO_ZIP_DIR = Path("./audio_files_zipped_by_game_22_050")

# Fetch the zipped audio dataset only when the target directory is absent.
if not AUDIO_ZIP_DIR.exists():
    logger.info("Downloading audio files...")
    token = os.getenv("HF_TOKEN")
    snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        local_dir=AUDIO_ZIP_DIR,
        token=token,
    )
else:
    logger.info("Audio files already downloaded. Skip downloading.")

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Device: {device}")

# Precomputed embeddings, one row per line of the filelist loaded below.
# Rows are scaled to unit length so that a dot product with a unit-length
# query vector yields the cosine similarity.
logger.info("Loading resnet34 vectors...")
resnet34_embs = np.load(RESNET34_ROOT / "all_embs.npy")
resnet34_embs_normalized = resnet34_embs / np.linalg.norm(
    resnet34_embs, axis=1, keepdims=True
)

# Embedding model; window="whole" makes Inference process each input file
# as a single window.
logger.info("Loading resnet34 model...")
model_resnet34 = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
inference = Inference(model_resnet34, window="whole")
inference.to(device)

# Audio paths, one per line, with surrounding whitespace removed.
logger.info("Loading filelist...")
with open(RESNET34_ROOT / "all_filelists.txt", "r", encoding="utf-8") as file:
    files = [entry.strip() for entry in file]
def get_speaker_name(file_idx: int):
    """Return the "{game_name}/{speaker_name}" label for an entry of ``files``.

    The two parts are the names of the grandparent and parent directories of
    the path stored at ``files[file_idx]``.
    """
    speaker_dir = Path(files[file_idx]).parent
    return f"{speaker_dir.parent.name}/{speaker_dir.name}"
# Build the speaker-id lookup tables and the per-file speaker-id array.
logger.info("Getting speaker ids...")

# Resolve every file's speaker label once; the same list feeds both the set of
# distinct speakers and the per-file id array below.
_speaker_names = [get_speaker_name(i) for i in range(len(files))]

all_speaker_set = set(_speaker_names)
# Ids are assigned in sorted label order: 0 .. num_speakers - 1.
id2speaker = dict(enumerate(sorted(all_speaker_set)))
num_speakers = len(id2speaker)
speaker2id = {speaker: i for i, speaker in id2speaker.items()}
# speaker_id_array[i] is the speaker id of files[i].
speaker_id_array = np.array([speaker2id[name] for name in _speaker_names])
# Disabled alternative of get_zip_archive_path_and_internal_path, kept for
# reference: it resolves one ZIP per speaker at {game_name}/{speaker_name}.zip.
# def get_zip_archive_path_and_internal_path(file_path: Path) -> tuple[str, str]:
#     # Layout: audio_files/{game_name}/{speaker_name}/{audio_file}
#     game_name = file_path.parent.parent.name
#     speaker_name = file_path.parent.name
#     archive_path = AUDIO_ZIP_DIR / game_name / f"{speaker_name}.zip"
#     internal_path = file_path.name  # the path inside the ZIP is just the file name
#     return str(archive_path), str(internal_path)
def get_zip_archive_path_and_internal_path(file_path: Path) -> tuple[str, str]:
    """Map an audio path to its ZIP archive and the member name inside it.

    ``file_path`` is expected to look like
    ``audio_files/{game_name}/{speaker_name}/{audio_file}``.

    Returns:
        ``(archive_path, internal_path)`` where ``archive_path`` is
        ``AUDIO_ZIP_DIR/{game_name}.zip`` and ``internal_path`` is
        ``"{speaker_name}/{audio_file}"``.
    """
    speaker_dir = file_path.parent
    zip_path = AUDIO_ZIP_DIR / (speaker_dir.parent.name + ".zip")
    # Members inside the archive are stored as "speaker_name/file_name".
    member_name = "/".join((speaker_dir.name, file_path.name))
    return str(zip_path), member_name
def load_audio_from_zip(file_path: Path) -> tuple[np.ndarray, int]:
    """Decode one audio clip stored inside the zipped dataset.

    Args:
        file_path: Path of the clip as listed in the filelist; it is mapped to
            an archive and a member name by
            ``get_zip_archive_path_and_internal_path``.

    Returns:
        ``(waveform, sample_rate)`` as produced by ``librosa.load`` with
        ``sr=None``, with the sample rate converted to ``int``.

    Raises:
        Whatever ``zipfile`` raises for a missing archive or member, and
        whatever ``librosa.load`` raises for undecodable audio. The temporary
        file is removed in either case.
    """
    archive_path, internal_path = get_zip_archive_path_and_internal_path(file_path)
    with zipfile.ZipFile(archive_path, "r") as zf:
        audio_bytes = zf.read(internal_path)

    # The bytes are written to a named temporary file that keeps the original
    # suffix and are then loaded by path. delete=False lets the file be closed
    # before it is reopened by name, so removal is done explicitly below.
    tmp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=Path(internal_path).suffix
    )
    tmp_file_path = Path(tmp_file.name)
    try:
        with tmp_file:
            tmp_file.write(audio_bytes)
        waveform, sample_rate = librosa.load(str(tmp_file_path), sr=None)
    finally:
        # Always clean up, including when writing or decoding fails.
        tmp_file_path.unlink(missing_ok=True)
    return waveform, int(sample_rate)
def get_emb(audio_path: Path | str) -> np.ndarray:
    """Compute the embedding of an audio file with the module-level model.

    Args:
        audio_path: Location of the audio file; it is passed to ``inference``
            as a string.

    Returns:
        The embedding as a NumPy array of shape ``(RESNET34_DIM,)``.
    """
    embedding = inference(str(audio_path))
    # Sanity checks on the model output; they also narrow the type for readers.
    assert isinstance(embedding, np.ndarray)
    assert embedding.shape == (RESNET34_DIM,)
    return embedding
def search(audio_path: str):
    """Find the stored clips most similar to the query audio.

    The query is embedded, compared to every precomputed embedding by cosine
    similarity, and the 10 best-matching clips are loaded from the zipped
    dataset.

    Args:
        audio_path: Path of the query audio file, or ``None`` when the user
            has not supplied any audio.

    Returns:
        A list of ``gr.Audio`` components, best match first, each labelled
        with its rank, "{game_name}/{speaker_name}" and similarity score.

    Raises:
        gr.Error: If no audio was supplied.
    """
    if audio_path is None:
        # Without this guard the missing input would be turned into the
        # literal path "None" further down and fail with an opaque error.
        # The message asks the user to provide an audio file.
        raise gr.Error("音声ファイルを入力してください。")

    logger.info("Computing embeddings...")
    emb = get_emb(audio_path).reshape(1, -1)  # (1, dim)
    logger.success("Embeddings computed.")

    # Both sides are unit length, so the dot product is the cosine similarity.
    logger.info("Computing similarities...")
    emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten()
    logger.success("Similarities computed.")

    # Indices of the most similar files, highest similarity first.
    top_k = 10
    top_k_indices = np.argsort(similarities)[::-1][:top_k]

    logger.info("Fetching audio files...")
    result = []
    for rank, file_idx in enumerate(top_k_indices, start=1):
        waveform_np, sample_rate = load_audio_from_zip(Path(files[file_idx]))
        score = similarities[file_idx]
        result.append(
            gr.Audio(
                value=(sample_rate, waveform_np),
                label=f"Top {rank}: {get_speaker_name(file_idx)}, {score:.4f}",
            )
        )
    logger.success("Audio files fetched.")
    return result
def get_label(audio_path: str, num_top_classes: int = 10):
    """Rank speakers by how similar their clips are to the query audio.

    For every speaker, the ``num_top_classes`` highest cosine similarities
    among that speaker's clips are averaged; the 10 speakers with the highest
    averages are returned.

    Args:
        audio_path: Path of the query audio file, or ``None`` when the user
            has not supplied any audio.
        num_top_classes: How many of a speaker's best similarities are
            averaged into that speaker's score.

    Returns:
        A dict mapping "{game_name}/{speaker_name}" to its average score as a
        plain ``float``, ordered from highest to lowest, with at most 10
        entries.

    Raises:
        gr.Error: If no audio was supplied.
    """
    if audio_path is None:
        # Without this guard the missing input would be turned into the
        # literal path "None" further down and fail with an opaque error.
        # The message asks the user to provide an audio file.
        raise gr.Error("音声ファイルを入力してください。")

    logger.info("Computing embeddings...")
    emb = get_emb(audio_path).reshape(1, -1)  # (1, dim)
    logger.success("Embeddings computed.")

    # Both sides are unit length, so the dot product is the cosine similarity.
    emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten()

    logger.info("Calculating average scores...")
    # Order all files once by (speaker id, descending similarity) instead of
    # rescanning the whole id array for every speaker. np.lexsort treats the
    # LAST key as the primary one.
    order = np.lexsort((-similarities, speaker_id_array))
    sorted_ids = speaker_id_array[order]
    sorted_similarities = similarities[order]

    # Each speaker now occupies one contiguous run; locate its boundaries.
    # This relies on the ids being exactly 0 .. num_speakers - 1.
    all_ids = np.arange(num_speakers)
    run_starts = np.searchsorted(sorted_ids, all_ids, side="left")
    run_ends = np.searchsorted(sorted_ids, all_ids, side="right")

    speaker_scores = {}
    for character_id, (start, end) in enumerate(zip(run_starts, run_ends)):
        # The best similarities of this speaker sit at the front of its run.
        top_similarities = sorted_similarities[
            start : min(end, start + num_top_classes)
        ]
        # Store a built-in float rather than a NumPy scalar.
        speaker_scores[id2speaker[character_id]] = float(np.mean(top_similarities))

    # Keep the 10 best speakers, highest average first.
    sorted_scores = dict(
        sorted(speaker_scores.items(), key=lambda item: item[1], reverse=True)[:10]
    )
    logger.success("Average scores calculated.")
    return sorted_scores
# User interface: one audio input and two side-by-side columns, one for
# similar clips and one for similar speakers.
with gr.Blocks() as app:
    # The handlers receive the uploaded or recorded audio as a file path.
    input_audio = gr.Audio(type="filepath")

    with gr.Row():
        # Left column: button plus one audio player per search result.
        with gr.Column():
            btn_audio = gr.Button("似ている音声を検索")
            top_k = 10
            components = [
                gr.Audio(label=f"Top {rank}") for rank in range(1, top_k + 1)
            ]

        # Right column: button plus a label showing the best-matching speakers.
        with gr.Column():
            btn_label = gr.Button("似ている話者を検索")
            label = gr.Label(num_top_classes=10)

    # Wire each button to its handler.
    btn_audio.click(fn=search, inputs=[input_audio], outputs=components)
    btn_label.click(fn=get_label, inputs=[input_audio], outputs=[label])

app.launch()