import pprint

import gradio as gr
import librosa
import plotly.graph_objects as go
import spaces
import torch
from loguru import logger
from transformers import AutoFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput

from model import EmotionModel

repo_id = "my_model"

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"device: {device}")

model = EmotionModel.from_pretrained(repo_id, device_map=device)
model.eval()
processor = AutoFeatureExtractor.from_pretrained(repo_id)

# Mapping from the model's English label names to the Japanese display labels.
label_map = {
    "Angry": "😠 怒り",
    "Disgusted": "😒 嫌悪",
    "Embarrassed": "😳 戸惑い",
    "Fearful": "😨 恐怖",
    "Happy": "😊 幸せ",
    "Sad": "😢 悲しみ",
    "Surprised": "😲 驚き",
    "Neutral": "😐 中立",
    "Sexual1": "🥰 NSFW1",
    "Sexual2": "🍭 NSFW2",
}


@spaces.GPU
def pipe(filename: str) -> tuple[dict[str, float], go.Figure]:
    # Load the audio at 16 kHz, the sampling rate the feature extractor expects.
    audio, sr = librosa.load(filename, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    logger.info(f"filename: {filename}, duration: {duration}")
    if duration > 30.0:
        return (
            {f"Error: 音声ファイルの長さが長すぎます: {duration}秒": 0.0},
            go.Figure(),
        )

    inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs: SequenceClassifierOutput = model(**inputs)
        logits = outputs.logits  # shape: (batch_size, num_labels)

    # Extract the logits for the single input example.
    logits = logits[0].cpu().numpy()
    labels = [label_map[label] for label in model.config.id2label.values()]

    # Sort by logit value so the horizontal bar chart is ordered bottom-to-top.
    sorted_pairs = sorted(zip(logits, labels), key=lambda x: x[0])
    sorted_logits, sorted_labels = zip(*sorted_pairs)
    logger.info(f"Result:\n{pprint.pformat(sorted_pairs)}")

    # Convert logits to probabilities for the label component.
    probabilities = outputs.logits.softmax(dim=-1)
    scores_dict = {label: prob.item() for label, prob in zip(labels, probabilities[0])}

    fig = go.Figure([go.Bar(x=sorted_logits, y=sorted_labels, orientation="h")])
    return scores_dict, fig


md = """
# 音声からの感情認識 ver 0.1

- 音声ファイルから感情を予測して、確率とlogits (softmax前の値) を表示します
- 30秒以上の音声ファイルは受け付けません
"""

with gr.Blocks() as app:
    gr.Markdown(md)
    audio = gr.Audio(type="filepath")
    btn = gr.Button("感情を予測")
    with gr.Row():
        result = gr.Label(label="結果")
        plot = gr.Plot(label="Logits")
    btn.click(pipe, inputs=audio, outputs=[result, plot])

app.launch()
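
# A minimal sketch of how pipe() could be exercised without the Gradio UI, e.g. for a
# quick local check. "sample.wav" is a hypothetical file path, not part of this repo;
# run these lines instead of app.launch() if you only want the raw scores and figure.
#
#     scores, fig = pipe("sample.wav")
#     pprint.pprint(scores)
#     fig.show()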