Spaces:
Running
Running
import json | |
import os | |
import pprint | |
import tempfile | |
import zipfile | |
from dataclasses import dataclass | |
from pathlib import Path | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import torch | |
from huggingface_hub import snapshot_download | |
from loguru import logger | |
from pyannote.audio import Inference, Model | |
HF_REPO_ID = "litagin/voice-samples-22050" | |
EMB_ROOT = Path("./embeddings") | |
RESNET34_DIM = 256 | |
AUDIO_ZIP_DIR = Path("./audio_files_zipped_by_game_22_050") | |
if AUDIO_ZIP_DIR.exists(): | |
logger.info("Audio files already downloaded. Skip downloading.") | |
else: | |
logger.info("Downloading audio files...") | |
token = os.getenv("HF_TOKEN") | |
snapshot_download( | |
HF_REPO_ID, repo_type="dataset", local_dir=AUDIO_ZIP_DIR, token=token | |
) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
logger.info(f"Device: {device}") | |
logger.info("Loading resnet34 vectors...") | |
resnet34_embs = np.load(EMB_ROOT / "all_embs.npy") | |
resnet34_embs_normalized = resnet34_embs / np.linalg.norm( | |
resnet34_embs, axis=1, keepdims=True | |
) | |
logger.info("Loading resnet34 model...") | |
model_resnet34 = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") | |
inference = Inference(model_resnet34, window="whole") | |
inference.to(device) | |
logger.info("Loading filelist...") | |
with open(EMB_ROOT / "all_filelist.txt", "r", encoding="utf-8") as file: | |
files = [line.strip() for line in file] | |
def get_speaker_key(file_idx: int): | |
filepath = Path(files[file_idx]) | |
game_name = filepath.parent.parent.name | |
speaker_name = filepath.parent.name | |
return f"{game_name}/{speaker_name}" # ゲーム名とスピーカー名を返す | |
# スピーカーIDの配列を取得 | |
logger.info("Getting speaker ids...") | |
all_speaker_set = set([get_speaker_key(i) for i in range(len(files))]) | |
id2speaker = {i: speaker for i, speaker in enumerate(sorted(all_speaker_set))} | |
num_speakers = len(id2speaker) | |
speaker2id = {speaker: i for i, speaker in id2speaker.items()} | |
speaker_id_array = np.array([speaker2id[get_speaker_key(i)] for i in range(len(files))]) | |
class GameInfo: | |
company: str | |
name: str | |
url: str | |
logger.info("Loading game dictionary...") | |
""" | |
[ | |
{ | |
"key": "Aino+Links_Sousaku_Kanojo_no_Ren'ai_Koushiki", | |
"company": "Aino+Links", | |
"name": "創作彼女の恋愛公式", | |
"url": "http://ainolinks.com/" | |
}, | |
... | |
] | |
""" | |
with open("game_info.json", "r", encoding="utf-8") as file: | |
game_info = json.load(file) | |
game_dict = { | |
game["key"]: GameInfo(company=game["company"], name=game["name"], url=game["url"]) | |
for game in game_info | |
} | |
def get_zip_archive_path_and_internal_path(file_path: Path) -> tuple[str, str]: | |
# 構造: audio_files/{game_name}/{speaker_name}/{audio_file} | |
game_name = file_path.parent.parent.name | |
speaker_name = file_path.parent.name | |
archive_path = AUDIO_ZIP_DIR / f"{game_name}.zip" | |
internal_path = f"{speaker_name}/{file_path.name}" | |
return str(archive_path), str(internal_path) | |
def load_audio_from_zip(file_path: Path) -> tuple[np.ndarray, int]: | |
archive_path, internal_path = get_zip_archive_path_and_internal_path(file_path) | |
with zipfile.ZipFile(archive_path, "r") as zf: | |
with zf.open(internal_path) as audio_file: | |
audio_bytes = audio_file.read() | |
# 一時ファイルに書き出してから読み込む | |
with tempfile.NamedTemporaryFile( | |
delete=False, suffix=Path(internal_path).suffix | |
) as tmp_file: | |
tmp_file.write(audio_bytes) | |
tmp_file_path = tmp_file.name | |
waveform, sample_rate = librosa.load(tmp_file_path, sr=None) | |
# 一時ファイルを削除 | |
Path(tmp_file_path).unlink() | |
return waveform, int(sample_rate) | |
def get_emb(audio_path: Path | str) -> np.ndarray: | |
emb = inference(str(audio_path)) | |
assert isinstance(emb, np.ndarray) | |
assert emb.shape == (RESNET34_DIM,) | |
return emb | |
def search_audio_files(audio_path: str): | |
# Check audio duration, require < 30s | |
logger.info(f"Getting duration of {audio_path}...") | |
waveform, sample_rate = librosa.load(audio_path, sr=None) | |
duration = librosa.get_duration(y=waveform, sr=sample_rate) | |
logger.info(f"Duration: {duration:.2f}s") | |
if duration > 30: | |
logger.error(f"Duration is too long: {duration:.2f}s") | |
return [ | |
f"音声ファイルは30秒以下である必要があります。現在のファイルの長さ: {duration:.2f}s" | |
] + [None] * 20 | |
logger.info("Computing embeddings...") | |
emb = get_emb(audio_path) # ユーザー入力の音声ファイル | |
emb = emb.reshape(1, -1) # (1, dim) | |
logger.success("Embeddings computed.") | |
# Normalize query vector | |
logger.info("Computing similarities...") | |
emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True) | |
similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten() | |
logger.success("Similarities computed.") | |
# Search max similarity files | |
top_k = 10 | |
top_k_indices = np.argsort(similarities)[::-1][:top_k] | |
top_k_files = [files[file_idx] for file_idx in top_k_indices] | |
logger.info(f"Top {top_k} files:\n{pprint.pformat(top_k_files)}") | |
top_k_scores = similarities[top_k_indices] | |
logger.info(f"Top {top_k} scores:\n{pprint.pformat(top_k_scores)}") | |
logger.info("Fetching audio files...") | |
audio_result = [] | |
info_result = [] | |
for i, (file_idx, score) in enumerate(zip(top_k_indices, top_k_scores)): | |
file_path = Path(files[file_idx]) | |
waveform_np, sample_rate = load_audio_from_zip(file_path) | |
audio_result.append( | |
gr.Audio( | |
value=(sample_rate, waveform_np), | |
label=f"Top {i+1} ({score:.4f})", | |
) | |
) | |
game_key = file_path.parent.parent.name | |
speaker_name = file_path.parent.name | |
game = game_dict[game_key] | |
game_info_md = f""" | |
## {i+1}位 (スコア{score:.4f}) | |
- ゲーム名: **{game.name}** ({game.company}) | |
- 公式サイト: {game.url} | |
- キャラクター名: **{speaker_name}**""" | |
info_result.append(gr.Markdown(game_info_md)) | |
logger.success("Audio files fetched.") | |
return ["成功"] + info_result + audio_result | |
def get_label(audio_path: str, num_top_classes_to_use: int = 10): | |
# Check audio duration, require < 30s | |
logger.info(f"Getting duration of {audio_path}...") | |
waveform, sample_rate = librosa.load(audio_path, sr=None) | |
duration = librosa.get_duration(y=waveform, sr=sample_rate) | |
logger.info(f"Duration: {duration:.2f}s") | |
if duration > 30: | |
logger.error(f"Duration is too long: {duration:.2f}s") | |
return ( | |
f"音声ファイルは30秒以下である必要があります。現在のファイルの長さ: {duration:.2f}s", | |
None, | |
) | |
logger.info("Computing embeddings...") | |
emb = get_emb(audio_path) # ユーザー入力の音声ファイル | |
emb = emb.reshape(1, -1) # (1, dim) | |
logger.success("Embeddings computed.") | |
# Normalize query vector | |
emb_normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True) | |
similarities = np.dot(resnet34_embs_normalized, emb_normalized.T).flatten() | |
logger.info("Calculating average scores...") | |
speaker_scores = {} | |
for character_id in range(num_speakers): | |
# 各キャラクターのインデックスを取得 | |
character_indices = np.where(speaker_id_array == character_id)[0] | |
# このキャラクターのトップ10の類似度を選択 | |
top_similarities = np.sort(similarities[character_indices])[::-1][ | |
:num_top_classes_to_use | |
] | |
# 平均スコアを計算 | |
average_score = np.mean(top_similarities) | |
# スピーカー名を取得 | |
speaker_key = id2speaker[character_id] | |
speaker_scores[speaker_key] = average_score | |
# スコアでソートして上位10件を返す | |
sorted_scores_list = sorted( | |
speaker_scores.items(), key=lambda x: x[1], reverse=True | |
) | |
sorted_scores_list = sorted_scores_list[:10] | |
logger.success("Average scores calculated.") | |
logger.info( | |
f"Top {num_top_classes_to_use} speakers:\n{pprint.pformat(sorted_scores_list)}" | |
) | |
results = [] | |
for i, (speaker_key, score) in enumerate(sorted_scores_list): | |
game_key, speaker_name = speaker_key.split("/") | |
result_md = f""" | |
## {i+1}位 (スコア{score:.4f}) | |
- ゲーム名: **{game_dict[game_key].name}** ({game_dict[game_key].company}) | |
- 公式サイト: {game_dict[game_key].url} | |
- キャラクター名: **{speaker_name}** | |
---""" | |
results.append(result_md) | |
all_result_md = "\n\n".join(results) | |
logger.success("Average scores calculated.") | |
return "成功", all_result_md | |
def make_game_info_md(game_key: str) -> str: | |
game = game_dict[game_key] | |
return f"[**{game.name}** ({game.company})]({game.url})" | |
def make_speaker_info_md(game_key: str, speaker_name: str) -> str: | |
game = game_dict[game_key] | |
return f"[{game.name} ({game.company})]({game.url})\n{speaker_name}" | |
initial_md = """ | |
# ギャルゲー似た声検索 (Galgame Voice Finder) | |
- 与えられた音声に対して、声が似ているような日本のギャルゲー(ビジュアルノベル・エロゲー)の音声とキャラクターを検索するアプリです | |
- **30秒未満**の音声ファイルにしか対応させていません (Only supports audio files less than 30 seconds) | |
- 「この声と似たキャラクターが出ているギャルゲーは?」「この音声AIの声に聞き覚えあるけど、学習元は誰なのかな?」といった疑問の参考になるかもしれません | |
- 次の**2つのモード**があります | |
- **セリフ音声検索**: セリフ単位でのTop 10の音声のサンプル表示 | |
- **キャラクター検索**: キャラクター単位でのTop 10のキャラクター表示 | |
- ゲームの公式サイトへのリンクもありますが、**18歳未満の方はリンク先へのアクセスを控えてください** | |
- 全てのゲームや、その中の全ての音声が網羅されているわけではありません(データについては下記詳細を参照) | |
""" | |
details_md = """ | |
### 音声データ | |
- 音声データは全て [OOPPEENN/Galgame_Dataset](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset) から取得(合計293ゲーム) | |
- 音声ファイル処理: 各キャラクターについて次を行う | |
- 総ファイル数が100未満の場合はモブキャラとして除外 | |
- 「2秒以上20秒未満」の音声ファイルのうち、時系列的に最初の100ファイルに加え、ランダムに最大200ファイル、合計最大300ファイルを選択 | |
- 22050Hz oggでリサンプリング | |
### ゲーム情報 | |
- [OOPPEENN/Galgame_Dataset](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset) ではゲームタイトルの英語表記のみが提供されているため、日本語タイトルと公式サイトURLを手動で調べて追加 | |
- 間違っている箇所があったら教えてください | |
### 音声ファイル同士の類似度計算 | |
- 話者埋め込み: [pyannote/wespeaker-voxceleb-resnet34-LM](https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM) の256次元の話者埋め込み | |
- 類似度計算: 2つの音声ファイルの話者埋め込みベクトルのコサイン類似度 | |
### キャラクター検索 | |
- 与えられた音声に対して、全ての音声ファイルとの類似度を計算 | |
- 各キャラクターについて、類似度の高い10ファイルの平均類似度を計算し、スコアとする | |
- そのスコアでソートして上位10キャラクターを表示 | |
""" | |
with gr.Blocks() as app: | |
gr.Markdown(initial_md) | |
with gr.Accordion(label="詳細", open=False): | |
gr.Markdown(details_md) | |
input_audio = gr.Audio(type="filepath", label="音声ファイルを入力") | |
gr.Markdown( | |
"「**セリフ音声検索**」と「**キャラクター検索**」の2つのモードから選択してください" | |
) | |
with gr.Tab(label="セリフ音声検索"): | |
btn_audio = gr.Button("似ているセリフ音声を検索") | |
info_audio = gr.Textbox(label="情報") | |
num_candidates = 10 | |
audio_components = [] | |
game_info_components = [] | |
for i in range(num_candidates): | |
with gr.Row(variant="panel"): | |
game_info_components.append(gr.Markdown(label=f"Top {i+1}")) | |
audio_components.append(gr.Audio(label=f"Top {i+1}")) | |
with gr.Tab(label="キャラクター検索"): | |
btn_character = gr.Button("似ているキャラクターを検索") | |
info_character = gr.Textbox(label="情報") | |
result_character = gr.Markdown("ここに結果が表示されます") | |
btn_audio.click( | |
search_audio_files, | |
inputs=[input_audio], | |
outputs=[info_audio] + game_info_components + audio_components, | |
) | |
btn_character.click( | |
get_label, inputs=[input_audio], outputs=[info_character, result_character] | |
) | |
app.launch() | |