Luigi's picture
adjust prompt
2ec8dc9
raw
history blame
5.36 kB
import streamlit as st
from streamlit_mic_recorder import mic_recorder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from outlines.generate import choice
from outlines.models import llamacpp as outlines_llama_cpp
import torch
from torch.nn.functional import softmax
import tempfile
import re
from pathlib import Path
from faster_whisper import WhisperModel
from huggingface_hub import hf_hub_download
# ---- Model locations on the Hugging Face Hub ----
llm_repo_id = "Qwen/Qwen2.5-1.5B-Instruct-GGUF"
llm_filename = "qwen2.5-1.5b-instruct-q8_0.gguf"
asr_repo_id = "Luigi/whisper-small-zh_tw-ct2"

# Fetch the quantized GGUF weights at import time (hf_hub_download reuses the
# local Hub cache when the file is already present); the returned local path
# is what load_outlines_model hands to llama.cpp.
llm_model_path = hf_hub_download(repo_id=llm_repo_id, filename=llm_filename)
# Load Whisper fine-tuned for zh_tw
@st.cache_resource
def load_whisper_model():
return WhisperModel(asr_repo_id, device="cpu", compute_type="int8", cpu_threads=2)
whisper_model = load_whisper_model()
# Selector label -> classifier identifier. The two ALBERT entries map to Hub
# repo ids loaded through transformers; the LLM entry (whose label is its own
# repo id) maps to the sentinel "llm", which the UI routes to predict_with_llm.
available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
    llm_repo_id: "llm",
}
# ---- Hugging Face / Transformers classifiers ----
@st.cache_resource
def load_transformers_model(model_id):
    """Return a cached ``(tokenizer, model)`` pair for a sequence classifier.

    Args:
        model_id: Hugging Face Hub repo id of the fine-tuned classifier.

    Returns:
        Tuple of the fast tokenizer and the sequence-classification model,
        cached per ``model_id`` by ``st.cache_resource``.
    """
    return (
        AutoTokenizer.from_pretrained(model_id, use_fast=True),
        AutoModelForSequenceClassification.from_pretrained(model_id),
    )
@st.cache_resource
def load_outlines_model():
    """Wrap the downloaded GGUF file in an Outlines llama.cpp model (cached).

    Settings: CPU only (no GPU layers), 1024-token context, two threads for
    both generation and batch processing, a batch size of 4, memory-mapped
    weights without mlock, and verbose output disabled.
    """
    llama_kwargs = dict(
        n_ctx=1024,
        n_threads=2,
        n_threads_batch=2,
        n_batch=4,
        n_gpu_layers=0,
        use_mlock=False,
        use_mmap=True,
        verbose=False,
    )
    return outlines_llama_cpp(model_path=llm_model_path, **llama_kwargs)
def predict_with_llm(text):
    """Classify *text* as a reservation / non-reservation message with the LLM.

    Generation is constrained by ``outlines.generate.choice`` so the model can
    only emit one of the two label strings.

    Args:
        text: The customer message (typed or transcribed).

    Returns:
        A user-facing result string for ``st.success``. If the generator ever
        returns something other than a known label, a warning string is
        returned instead of ``None``.
    """
    model = load_outlines_model()
    # NOTE(review): the prompt asks for a JSON object, but choice() constrains
    # the output to the bare label text; consider aligning the wording.
    prompt = "\n".join([
        "You are an expert in classification of restaurant customers' messages.",
        "I'm going to provide you with a message from a restaurant customer.",
        "You have to classify it in one of the following two intents:",
        "RESERVATION: Inquiries and requests highly related to table reservations and seating",
        "NOT_RESERVATION: All other messages that do not involve table booking or reservations",
        "Please reply with *only* the name of the intent labels in a JSON object like:",
        '{"result": "RESERVATION"} or {"result": "NOT_RESERVATION"}',
        f"Here is the message to classify: {text}",
    ]).strip()

    # Label -> display string; insertion order gives the same choice list
    # ["RESERVATION", "NOT_RESERVATION"] as before.
    labels = {
        "RESERVATION": "📞 訂位意圖 (Reservation intent)",
        "NOT_RESERVATION": "❌ 無訂位意圖 (Not Reservation intent)",
    }
    classifier = choice(model, list(labels))
    prediction = classifier(prompt)
    # Explicit fallback so the UI never renders "None".
    return labels.get(
        prediction, f"⚠️ 無法判斷 (Unexpected LLM output: {prediction!r})"
    )
# Standard Transformers classifier
def predict_intent(text, model_id, threshold=0.7):
    """Classify *text* with a fine-tuned sequence-classification model.

    Args:
        text: The customer message to classify.
        model_id: Hub repo id passed to ``load_transformers_model``.
        threshold: Minimum reservation probability required to report a
            reservation intent (default 0.7, the previous hard-coded value).

    Returns:
        A user-facing string with the decision and the reservation
        probability formatted as a percentage.
    """
    tokenizer, model = load_transformers_model(model_id)
    # truncation=True clips over-long input to the tokenizer's max length so
    # the encoder does not fail on sequences longer than its position table.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    # Single input -> batch row 0. Assumes class index 1 is the reservation
    # label — TODO confirm against the model config's id2label.
    confidence = probs[0, 1].item()
    if confidence >= threshold:
        label = "📞 訂位意圖 (Reservation intent)"
    else:
        label = "❌ 無訂位意圖 (Not Reservation intent)"
    return f"{label}(訂位信心度 Confidence: {confidence:.2%})"
# Clean README
def load_clean_readme(path="README.md"):
    """Return README markdown without YAML front matter or the leading H1.

    Args:
        path: Path of the markdown file to read (UTF-8).

    Returns:
        The remaining markdown. Whitespace around the front-matter cut is
        stripped.

    Raises:
        OSError: If the file cannot be read (e.g. FileNotFoundError).
    """
    text = Path(path).read_text(encoding="utf-8")
    # Front matter opens with a "---" line at the very start of the file and
    # closes at the next line that is exactly "---". Anchoring both delimiters
    # to whole lines keeps a "---" inside a value (e.g. "title: a---b") from
    # ending the block early. The lazy optional group also allows an empty
    # front matter ("---\n---").
    text = re.sub(
        r"\A---[ \t]*\n(?:.*?\n)??---[ \t]*(?:\n|\Z)",
        "",
        text,
        flags=re.DOTALL,
    ).strip()
    # Drop a leading "# Title" line plus the blank lines after it, presumably
    # because the app already shows its own st.title. This also handles a
    # title that is the last line with no trailing newline.
    text = re.sub(r"\A# [^\n]*(?:\n+|\Z)", "", text)
    return text
# ---- App UI: header, model picker, and browser microphone ----
st.title("🍽️ 餐廳訂位意圖識別")
st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")

# The selectbox shows the dict keys; model_id is either a Hub repo id or "llm".
model_label = st.selectbox("選擇模型", list(available_models))
model_id = available_models[model_label]

st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")
audio = mic_recorder(
    start_prompt="開始錄音",
    stop_prompt="停止錄音",
    just_once=True,
    use_container_width=True,
    format="wav",
    key="recorder",
)
# Voice path: play back the recording, transcribe it with Whisper, then
# classify the transcription with the selected model.
if audio:
    st.success("錄音完成!")
    st.audio(audio["bytes"], format="audio/wav")
    # delete=False keeps the file after the handle closes so Whisper can open
    # it by path; it is removed explicitly in the finally clause below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
        tmpfile.write(audio["bytes"])
        tmpfile_path = tmpfile.name
    with st.spinner("🧠 Whisper 處理語音中..."):
        try:
            segments, _ = whisper_model.transcribe(tmpfile_path, beam_size=5)
            # strip() so a whitespace-only result counts as "no transcription".
            transcription = "".join(seg.text for seg in segments).strip()
            st.success(f"📝 語音轉文字:{transcription}")
        except Exception as e:
            st.error(f"❌ Whisper 錯誤:{str(e)}")
            transcription = ""
        finally:
            # Without this, every recording leaves a .wav in the temp dir.
            try:
                Path(tmpfile_path).unlink()
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless
    if transcription:
        with st.spinner("預測中..."):
            if model_id == "llm":
                result = predict_with_llm(transcription)
            else:
                result = predict_intent(transcription, model_id)
        st.success(result)
# Typed path: the submit button is only rendered once some text is entered.
text_input = st.text_input("✍️ 或手動輸入語句")
if text_input and st.button("🚀 送出"):
    with st.spinner("預測中..."):
        result = (
            predict_with_llm(text_input)
            if model_id == "llm"
            else predict_intent(text_input, model_id)
        )
    st.success(result)
# Collapsed-by-default panel that renders README.md after load_clean_readme
# has trimmed it; HTML is allowed because the README is part of the repo.
with st.expander("ℹ️ 說明文件 / 使用說明 (README)", expanded=False):
    readme_md = load_clean_readme()
    st.markdown(body=readme_md, unsafe_allow_html=True)