import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.nn.functional import softmax
import numpy as np
import outlines  # Qwen integration via constrained generation
import kokoro    # TTS synthesis
from functools import lru_cache
import warnings

# Suppress FutureWarnings (e.g. about using `inputs` vs. `input_features`)
warnings.filterwarnings("ignore", category=FutureWarning)
# ------------------- Model Identifiers -------------------
whisper_model_id = "Jingmiao/whisper-small-zh_tw"
qwen_model_id = "Qwen/Qwen2.5-0.5B-Instruct"

available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
    "Qwen (via Transformers - outlines)": "qwen",
}
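# NOTE: "qwen" is a routing sentinel rather than a checkpoint id. classify_intent()
# sends it to the generative outlines classifier below, while the ALBERT entries are
# fine-tuned sequence-classification checkpoints loaded through transformers.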
# ------------------- Caching and Loading Functions -------------------
# lru_cache keeps each model resident after its first use, so repeated requests
# do not reload weights from disk.
@lru_cache(maxsize=1)
def load_whisper_pipeline():
    # Passing `device` keeps input placement and model weights on the same device;
    # moving the model to CUDA after construction can cause a device mismatch.
    device = 0 if torch.cuda.is_available() else -1
    return pipeline("automatic-speech-recognition", model=whisper_model_id, device=device)

@lru_cache(maxsize=None)
def load_transformers_model(model_id: str):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    if torch.cuda.is_available():
        model.to("cuda")
    return tokenizer, model

@lru_cache(maxsize=1)
def load_qwen_model():
    return outlines.models.transformers(qwen_model_id)

@lru_cache(maxsize=1)
def get_tts_pipeline():
    # "z" selects Kokoro's Mandarin Chinese pipeline.
    return kokoro.KPipeline(lang_code="z")
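# Because the loaders above are cached, an optional warm-up call at startup moves the
# first-request latency to launch time, e.g.:
#
#   load_whisper_pipeline()
#   load_transformers_model(available_models["ALBERT-tiny (Chinese)"])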
# ------------------- Inference Functions -------------------
def predict_with_qwen(text: str):
    model = load_qwen_model()
    prompt = f"""
<|im_start|>system
You are an expert in classification of restaurant customers' messages.
You must decide between the following two intents:
RESERVATION: Inquiries and requests highly related to table reservations and seating.
NOT_RESERVATION: All other messages.
Respond with *only* the intent label in a JSON object, like: {{"result": "RESERVATION"}}.
<|im_end|>
<|im_start|>user
Classify the following message: "{text}"
<|im_end|>
<|im_start|>assistant
"""
    # Constrained decoding: the generator can only emit one of the two labels
    # verbatim, so no parsing of free-form model output is needed.
    generator = outlines.generate.choice(model, ["RESERVATION", "NOT_RESERVATION"])
    prediction = generator(prompt)
    if prediction == "RESERVATION":
        return "📞 訂位意圖 (Reservation intent)"
    elif prediction == "NOT_RESERVATION":
        return "❌ 無訂位意圖 (Not Reservation intent)"
    else:
        # Defensive fallback; unreachable under choice-constrained decoding.
        return f"未知回應 (Unexpected response): {prediction}"
def predict_intent(text: str, model_id: str):
    tokenizer, model = load_transformers_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    # Index 1 corresponds to the RESERVATION class in these fine-tuned checkpoints.
    confidence = probs[0, 1].item()
    if confidence >= 0.7:
        return f"📞 訂位意圖 (Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"
    else:
        return f"❌ 無訂位意圖 (Not Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"
def get_tts_message(intent_result: str):
    # "訂位意圖" also appears inside "無訂位意圖", so the extra "無" check is what
    # actually distinguishes a positive result from a negative one.
    if intent_result and "訂位意圖" in intent_result and "無" not in intent_result:
        return "稍後您將會從簡訊收到訂位連結"
    elif intent_result:
        return "我們將會將您的回饋傳達給負責人,謝謝您"
    else:
        return "未能判斷意圖"
def tts_audio_output(message: str, voice: str = "zf_xiaobei"):
    # Default to a Mandarin voice to match the lang_code="z" pipeline; an English
    # voice such as "af_heart" would mispronounce the Chinese message.
    pipeline_tts = get_tts_pipeline()
    generator = pipeline_tts(message, voice=voice)
    # Kokoro yields (graphemes, phonemes, audio) chunks; collect and join the audio.
    audio_chunks = [audio for _, _, audio in generator]
    if audio_chunks:
        # gr.Audio accepts a (sample_rate, numpy_array) tuple; Kokoro outputs 24 kHz audio.
        return (24000, np.concatenate(audio_chunks))
    return None
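# To persist the synthesized reply instead of streaming it through gr.Audio, the same
# array can be written to disk (a sketch, assuming the optional soundfile package):
#
#   import soundfile as sf
#   sr, audio = tts_audio_output("您好")
#   sf.write("reply.wav", audio, sr)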
def transcribe_audio(audio_file):
    whisper_pipe = load_whisper_pipeline()
    # audio_file is a file path from gr.Audio (type="filepath"); the pipeline decodes
    # and resamples the file internally (ffmpeg must be available).
    result = whisper_pipe(audio_file)
    return result["text"]
# ------------------- Main Processing Function -------------------
def classify_intent(mode, audio_file, text_input, model_choice):
    # Determine the input text based on the explicitly selected mode.
    if mode == "Microphone" and audio_file is not None:
        transcription = transcribe_audio(audio_file)
    elif mode == "Text" and text_input:
        transcription = text_input
    else:
        return "請提供語音或文字輸入", "", None
    # Classify the transcribed or typed text.
    if available_models[model_choice] == "qwen":
        classification = predict_with_qwen(transcription)
    else:
        classification = predict_intent(transcription, available_models[model_choice])
    # Generate the TTS reply message and synthesize audio for it.
    tts_msg = get_tts_message(classification)
    tts_audio = tts_audio_output(tts_msg)
    return transcription, classification, tts_audio
# ------------------- Gradio Blocks Interface Setup -------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🍽️ 餐廳訂位意圖識別")
    gr.Markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")
    with gr.Row():
        # Input mode selector
        mode = gr.Radio(choices=["Microphone", "Text"], label="選擇輸入模式", value="Microphone")
    with gr.Row():
        # Audio and text inputs; only one is visible at a time, toggled by the mode selector.
        audio_input = gr.Audio(sources=["microphone"], type="filepath", label="語音輸入 (點擊錄音)")
        # Start hidden; setting visibility in the constructor is the reliable way to do this.
        text_input = gr.Textbox(lines=2, placeholder="請輸入文字", label="文字輸入", visible=False)
    # Toggle input visibility when the mode selection changes.
    def update_visibility(selected_mode):
        if selected_mode == "Microphone":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    mode.change(fn=update_visibility, inputs=mode, outputs=[audio_input, text_input])
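    # The two gr.update(...) values returned by update_visibility map positionally onto
    # `outputs`: the first controls audio_input, the second text_input.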
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(available_models.keys()),
                                     value="ALBERT-tiny (Chinese)", label="選擇模型")
    with gr.Row():
        classify_btn = gr.Button("執行辨識")
    with gr.Row():
        transcription_output = gr.Textbox(label="轉換文字")
    with gr.Row():
        classification_output = gr.Textbox(label="意圖判斷結果")
    with gr.Row():
        tts_output = gr.Audio(type="numpy", label="TTS 語音輸出")

    # The button click runs the full pipeline; Gradio shows a spinner during processing.
    classify_btn.click(fn=classify_intent,
                       inputs=[mode, audio_input, text_input, model_dropdown],
                       outputs=[transcription_output, classification_output, tts_output])
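# For concurrent users, calling demo.queue() before launch keeps long-running
# classification requests from blocking one another.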
demo.launch()