# asrnersbx / app.py
# Author: MikeTangoEcho — initial commit 48b9b5d (2.24 kB)
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from transformers import pipeline
# Pipelines
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = 0 if torch.cuda.is_available() else "cpu"
## Automatic Speech Recognition
## https://huggingface.co/docs/transformers/task_summary#automatic-speech-recognition
## Requires ffmpeg to be installed (used by the pipeline to decode file/bytes inputs)
asr_model = "openai/whisper-tiny"
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    # Whisper works on 30 s windows; chunking lets the pipeline split longer
    # recordings into such windows and stitch the partial transcripts together.
    chunk_length_s=30,
    # NOTE: torch_dtype=torch.float16 could be added here to reduce GPU memory use.
    device=device,
)
## Token Classification / Named Entity Recognition
## https://huggingface.co/docs/transformers/task_summary#token-classification
tc_model = "dslim/distilbert-NER"
tc = pipeline(
    "token-classification",  # also available under the task alias "ner"
    model=tc_model,
    device=device,
)
# ---
# Transformers
# https://www.gradio.app/main/docs/gradio/audio#behavior
# As output component: expects audio data in any of these formats:
# - a str or pathlib.Path filepath
# - or URL to an audio file,
# - or a bytes object (recommended for streaming),
# - or a tuple of (sample rate in Hz, audio data as numpy array)
def transcribe(audio: str | Path | bytes | tuple[int, np.ndarray] | None):
if audio is None:
return "..."
# TODO Manage str/Path
text = ""
# https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
# Whisper input format for tuple differ from output provided by gradio audio component
if asr_model.startswith("openai/whisper"):
inputs = {"sampling_rate": audio[0], "raw": audio[1]} if type(audio) is tuple and else audio
transcript = asr(inputs)
text = transcript['text']
entities = tc(text)
# TODO Add Text Classification for sentiment analysis
return {"text": text, "entities": entities}
# ---
# Gradio
## Interfaces
# https://www.gradio.app/main/docs/gradio/audio
# Audio input that accepts both file uploads and microphone recordings,
# with the component's share button hidden.
input_audio = gr.Audio(
    show_share_button=False,
    sources=["upload", "microphone"],
)
## App
# Wire transcribe() to the audio input and render its {"text", "entities"}
# result as highlighted text.
gradio_app = gr.Interface(
    transcribe,
    inputs=[input_audio],
    outputs=[gr.HighlightedText()],
    # NOTE(review): "huggingface" looks like a legacy Gradio theme name;
    # confirm it still resolves with the installed Gradio version.
    theme="huggingface",
    title="ASRNERSBX",
    description="Transcribe, Tokenize, Classify",
    allow_flagging="never",
)
## Start!
gradio_app.launch()