litagin's picture
init
3e40110
raw
history blame
2.62 kB
import pprint
import gradio as gr
import librosa
import plotly.graph_objects as go
import spaces
import torch
from loguru import logger
from transformers import AutoFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput
from model import EmotionModel
# Checkpoint identifier used for both the model and its feature extractor.
repo_id = "my_model"

# Pick the execution device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"device: {device}")

# Load the classifier directly onto the chosen device and switch it to
# evaluation mode for inference.
model = EmotionModel.from_pretrained(repo_id, device_map=device)
model.eval()

# Feature extractor that turns raw waveforms into model inputs.
processor = AutoFeatureExtractor.from_pretrained(repo_id)
# Display strings for the model's class labels: emoji + Japanese name
# (NSFW classes keep an English tag). pipe() looks up the names from
# model.config.id2label in this mapping; key order is kept stable.
label_map = dict(
    Angry="😠 怒り",
    Disgusted="😒 嫌悪",
    Embarrassed="😳 戸惑い",
    Fearful="😨 恐怖",
    Happy="😊 幸せ",
    Sad="😢 悲しみ",
    Surprised="😲 驚き",
    Neutral="😐 中立",
    Sexual1="🥰 NSFW1",
    Sexual2="🍭 NSFW2",
)
@spaces.GPU
def pipe(filename: str) -> tuple[dict[str, float], go.Figure]:
    """Predict emotion scores for a single audio file.

    Args:
        filename: Path to the audio file, as delivered by
            ``gr.Audio(type="filepath")``. Gradio passes ``None`` when no
            audio has been provided; that case is reported as an error.

    Returns:
        ``(scores, figure)``. ``scores`` maps each display label (from
        ``label_map``, falling back to the raw model label) to its softmax
        probability. ``figure`` is a horizontal bar chart of the raw logits,
        sorted ascending. On error (no file, or a clip longer than 30 s),
        ``scores`` contains a single error-message key with value 0.0 and
        ``figure`` is empty.
    """
    # No upload/recording: librosa.load(None) would raise, so report it instead.
    if filename is None:
        return {"Error: 音声ファイルが指定されていません": 0.0}, go.Figure()

    # Resample to 16 kHz; the same rate is handed to the feature extractor.
    audio, sr = librosa.load(filename, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    logger.info(f"filename: {filename}, duration: {duration}")
    if duration > 30.0:
        return (
            {f"Error: 音声ファイルの長さが長すぎます: {duration}秒": 0.0},
            go.Figure(),
        )

    inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs: SequenceClassifierOutput = model(**inputs)

    # Exactly one clip was fed in, so take the first (only) row of the batch.
    row = outputs.logits[0]
    logits = row.cpu().numpy()
    probabilities = row.softmax(dim=-1).cpu().tolist()

    # Build labels by class index so they line up with the logit positions,
    # instead of relying on the iteration order of the id2label dict.
    # Unknown model labels are shown as-is rather than raising KeyError.
    id2label = model.config.id2label
    labels = [label_map.get(id2label[i], id2label[i]) for i in range(len(logits))]

    # Ascending order, presumably so Plotly's horizontal bars show the
    # largest logit at the top.
    sorted_pairs = sorted(zip(logits, labels), key=lambda pair: pair[0])
    sorted_logits, sorted_labels = zip(*sorted_pairs)
    logger.info(f"Result:\n{pprint.pformat(sorted_pairs)}")

    scores_dict = dict(zip(labels, probabilities))
    fig = go.Figure([go.Bar(x=sorted_logits, y=sorted_labels, orientation="h")])
    return scores_dict, fig
md = """
# 音声からの感情認識 ver 0.1
- 音声ファイルから感情を予測して、確率とlogits (softmax前の値) を表示します
- 30秒以上の音声ファイルは受け付けません
"""
# Build the UI: description, audio input, a predict button, and a row with the
# probability label and the logit bar chart; then start the server.
with gr.Blocks() as app:
    gr.Markdown(md)
    audio = gr.Audio(type="filepath")
    btn = gr.Button("感情を予測")
    with gr.Row():
        result = gr.Label(label="結果")
        plot = gr.Plot(label="Logits")
    # pipe() returns (scores, figure), routed to the label and the plot.
    btn.click(fn=pipe, inputs=[audio], outputs=[result, plot])

app.launch()