import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from outlines.generate import choice
from outlines.models import llamacpp as outlines_llama_cpp
import torch
from torch.nn.functional import softmax
import tempfile
import re
from pathlib import Path
from faster_whisper import WhisperModel
from huggingface_hub import hf_hub_download
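
# Model assets pulled from the Hugging Face Hub:
# - a GGUF-quantized Qwen2.5-1.5B-Instruct for LLM-based intent classification (outlines + llama.cpp)
# - Whisper-small fine-tuned for zh_tw, converted to CTranslate2 for faster-whisper ASR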
llm_repo_id = "Qwen/Qwen2.5-1.5B-Instruct-GGUF" | |
llm_filename="qwen2.5-1.5b-instruct-q8_0.gguf" | |
asr_repo_id = "Luigi/whisper-small-zh_tw-ct2" | |
llm_model_path = hf_hub_download(repo_id=llm_repo_id, filename=llm_filename) | |

# Load Whisper fine-tuned for zh_tw
def load_whisper_model():
    return WhisperModel(asr_repo_id, device="cpu", compute_type="int8", cpu_threads=2)

whisper_model = load_whisper_model()

# Available models for text classification (UI label -> model identifier)
available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
    llm_repo_id: "llm",  # sentinel value that routes to the constrained-LLM classifier
}

# Load HuggingFace/Transformers model
def load_transformers_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    return tokenizer, model
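
# Build an outlines model backed by llama.cpp from the downloaded GGUF checkpoint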
def load_outlines_model():
    model = outlines_llama_cpp(
        model_path=llm_model_path,
        n_ctx=1024,
        n_threads=2,
        n_threads_batch=2,
        n_batch=4,
        n_gpu_layers=0,  # CPU-only: no layers offloaded to a GPU
        use_mlock=False,
        use_mmap=True,
        verbose=False,
    )
    return model
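
# LLM-based classifier: prompt Qwen2.5 and constrain decoding to the two intent labels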
def predict_with_llm(text):
    model = load_outlines_model()
    prompt = f"""
You are an expert in classifying restaurant customers' messages.
I am going to provide you with a message from a restaurant customer.
You have to classify it into one of the following two intents:
RESERVATION: Inquiries and requests highly related to table reservations and seating
NOT_RESERVATION: All other messages that do not involve table booking or reservations
Please reply with *only* the name of the intent label in a JSON object like:
{{"result": "RESERVATION"}} or {{"result": "NOT_RESERVATION"}}
Here is the message to classify: {text}
""".strip()
    classifier = choice(model, ["RESERVATION", "NOT_RESERVATION"])
    prediction = classifier(prompt)
    if prediction == "RESERVATION":
        return "📞 訂位意圖 (Reservation intent)"
    elif prediction == "NOT_RESERVATION":
        return "❌ 無訂位意圖 (Not Reservation intent)"

# Standard Transformers classifier
def predict_intent(text, model_id):
    tokenizer, model = load_transformers_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    confidence = probs[0, 1].item()  # probability of the RESERVATION class
    if confidence >= 0.7:
        return f"📞 訂位意圖 (Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"
    else:
        return f"❌ 無訂位意圖 (Not Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"

# Clean README: drop the YAML front matter and the top-level title before rendering
def load_clean_readme(path="README.md"):
    text = Path(path).read_text(encoding="utf-8")
    text = re.sub(r"(?s)^---.*?---", "", text).strip()
    text = re.sub(r"^# .*?\n+", "", text)
    return text

# App UI
st.title("🍽️ 餐廳訂位意圖識別")
st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")

model_label = st.selectbox("選擇模型", list(available_models.keys()))
model_id = available_models[model_label]

st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")
audio = mic_recorder(
    start_prompt="開始錄音",
    stop_prompt="停止錄音",
    just_once=True,
    use_container_width=True,
    format="wav",
    key="recorder",
)
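
# mic_recorder returns a dict once recording stops; the raw WAV data is under "bytes"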
if audio:
    st.success("錄音完成!")
    st.audio(audio["bytes"], format="audio/wav")
    # Persist the recording to a temporary WAV file and pass its path to faster-whisper
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
        tmpfile.write(audio["bytes"])
        tmpfile_path = tmpfile.name
    with st.spinner("🧠 Whisper 處理語音中..."):
        try:
            segments, _ = whisper_model.transcribe(tmpfile_path, beam_size=5)
            transcription = "".join([seg.text for seg in segments])
            st.success(f"📝 語音轉文字:{transcription}")
        except Exception as e:
            st.error(f"❌ Whisper 錯誤:{str(e)}")
            transcription = ""
    if transcription:
        with st.spinner("預測中..."):
            if model_id == "llm":
                result = predict_with_llm(transcription)
            else:
                result = predict_intent(transcription, model_id)
        st.success(result)
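
# Manual text entry as an alternative to voice input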
text_input = st.text_input("✍️ 或手動輸入語句")
if text_input and st.button("🚀 送出"):
    with st.spinner("預測中..."):
        if model_id == "llm":
            result = predict_with_llm(text_input)
        else:
            result = predict_intent(text_input, model_id)
    st.success(result)
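
# Render the Space's README (front matter and title removed) in a collapsible panel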
with st.expander("ℹ️ 說明文件 / 使用說明 (README)", expanded=False):
    readme_md = load_clean_readme()
    st.markdown(readme_md, unsafe_allow_html=True)