|
import pprint |
|
from pathlib import Path |
|
|
|
import gradio as gr |
|
import librosa |
|
import plotly.graph_objects as go |
|
import spaces |
|
import torch |
|
from loguru import logger |
|
from transformers import AutoFeatureExtractor |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
from model import EmotionModel |
|
|
|
# Repository id / local directory holding both the model weights and the
# feature-extractor configuration.
repo_id = "my_model"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.info(f"device: {device}")

# Load the emotion classifier onto the selected device and put it in
# evaluation (inference) mode.
model = EmotionModel.from_pretrained(repo_id, device_map=device)
model.eval()

# Converts raw waveforms into the tensors the model expects.
processor = AutoFeatureExtractor.from_pretrained(repo_id)
|
|
|
# Translates the model's raw class names (the values of
# model.config.id2label) into the emoji + Japanese strings shown in the UI.
label_map = dict(
    Angry="😠 怒り",
    Disgusted="😒 嫌悪",
    Embarrassed="😳 戸惑い",
    Fearful="😨 恐怖",
    Happy="😊 幸せ",
    Sad="😢 悲しみ",
    Surprised="😲 驚き",
    Neutral="😐 中立",
    Sexual1="🥰 NSFW1",
    Sexual2="🍭 NSFW2",
)
|
|
|
|
|
def _load_audio(filename: str, sr: int = 16000):
    """Load ``filename`` as a mono waveform resampled to ``sr`` Hz.

    librosa is tried first. If it raises, the file is decoded with pydub,
    written to a WAV file inside a private temporary directory, and loaded
    from there. Each call gets its own directory, so concurrent requests
    cannot overwrite each other's file. The directory is removed even when
    decoding fails.

    Returns:
        The ``(waveform, sample_rate)`` pair returned by ``librosa.load``.
    """
    import tempfile

    try:
        return librosa.load(filename, mono=True, sr=sr)
    except Exception as e:  # broad on purpose: any decode failure -> pydub fallback
        logger.warning(f"Error reading file: {e}")

    # Imported lazily: pydub is only needed on the fallback path.
    from pydub import AudioSegment

    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = str(Path(tmp_dir) / "converted.wav")
        # pydub's export() returns the file object it wrote; close it so the
        # data is flushed and the temporary directory can be removed.
        AudioSegment.from_file(filename).export(wav_path, format="wav").close()
        return librosa.load(wav_path, mono=True, sr=sr)


@spaces.GPU  # presumably requests a (Zero)GPU on HF Spaces for this call — confirm
def pipe(filename: str) -> tuple[dict[str, float], go.Figure]:
    """Predict the emotion expressed in an audio file.

    Args:
        filename: Path to the uploaded audio file (the Gradio Audio component
            uses ``type="filepath"``). It may be empty/None when nothing was
            uploaded.

    Returns:
        A pair ``(scores, figure)``:

        - ``scores`` maps each display label (taken from ``label_map``, or
          the raw model label if unmapped) to its softmax probability. On
          invalid input it instead holds a single error message mapped to
          ``0.0``.
        - ``figure`` is a horizontal bar chart of the raw logits sorted in
          ascending order, or an empty figure on invalid input.
    """
    if not filename:
        return {"Error: ファイルが指定されていません": 0.0}, go.Figure()
    logger.info(f"filename: {Path(filename).name}")

    y, sr = _load_audio(filename)

    duration = librosa.get_duration(y=y, sr=sr)
    logger.info(f"Duration: {duration:.2f}s")
    if duration > 30.0:
        return (
            {f"Error: 音声ファイルの長さが長すぎます: {duration:.2f}s": 0.0},
            go.Figure(),
        )

    inputs = processor(y, sampling_rate=sr, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs: SequenceClassifierOutput = model(**inputs)

    # A single clip was processed, so take batch row 0. Upcast to float32 on
    # the CPU so reduced-precision outputs (e.g. bfloat16) convert cleanly.
    logits_row = outputs.logits[0].float().cpu()
    logits = logits_row.tolist()
    probabilities = logits_row.softmax(dim=-1).tolist()

    # Order labels by class index so they line up with the logits, and fall
    # back to the raw model label when label_map has no display string.
    labels = [
        label_map.get(name, name)
        for _, name in sorted(
            model.config.id2label.items(), key=lambda item: int(item[0])
        )
    ]

    sorted_pairs = sorted(zip(logits, labels), key=lambda pair: pair[0])
    sorted_logits, sorted_labels = zip(*sorted_pairs)
    logger.info(f"Result:\n{pprint.pformat(sorted_pairs)}")

    scores_dict = dict(zip(labels, probabilities))

    # plotly draws the first category at the bottom of a horizontal bar
    # chart, so ascending order puts the largest logit on top.
    fig = go.Figure([go.Bar(x=sorted_logits, y=sorted_labels, orientation="h")])
    return scores_dict, fig
|
|
|
|
|
# Markdown header rendered at the top of the UI. The length limit stated here
# matches the check in pipe(), which rejects clips with duration > 30.0 s.
md = """
# 音声からの感情認識 ver 0.1

- 音声ファイルから感情を予測して、確率とlogits (softmax前の値) を表示します
- 30秒を超える音声ファイルは受け付けません
- 声優による演技セリフ音声以外には微妙な結果になるかもしれません
"""
|
|
|
# Build the UI: description, audio upload, a predict button, and a row
# holding the probability label next to the logits bar chart.
with gr.Blocks() as app:
    gr.Markdown(md)
    audio = gr.Audio(type="filepath")
    btn = gr.Button("感情を予測")

    with gr.Row():
        result = gr.Label(label="結果")
        plot = gr.Plot(label="Logits")

    # pipe() receives the uploaded file path and returns (scores, figure).
    btn.click(fn=pipe, inputs=[audio], outputs=[result, plot])

app.launch()
|
|