import base64
import tempfile

import streamlit as st
import streamlit.components.v1 as components
import torch
from torch.nn.functional import softmax
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Load Whisper fine-tuned for Taiwanese Mandarin (zh_tw); cached so the model
# is not reloaded on every Streamlit rerun.
@st.cache_resource
def load_whisper():
    return pipeline("automatic-speech-recognition", model="Jingmiao/whisper-small-zh_tw")

whisper_model = load_whisper()
# Intent classifier models
available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
}
@st.cache_resource
def load_model(model_id):
    # Cache one tokenizer/model pair per model_id so switching models stays fast
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    return tokenizer, model
def predict_intent(text, model_id):
    tokenizer, model = load_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    # Index 1 is the positive ("reservation intent") class
    confidence = probs[0, 1].item()
    if confidence >= 0.7:
        label = "📞 訂位意圖 (Reservation intent)"
    else:
        label = "❌ 無訂位意圖 (No intent)"
    return f"{label}(信心度 Confidence: {confidence:.2%})"
# UI
st.title("🍽️ 餐廳訂位意圖識別")  # "Restaurant reservation intent recognition"
st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")  # "Record audio or type text; reservation intent is detected automatically."
model_label = st.selectbox("選擇模型", list(available_models.keys()))  # "Choose a model"
model_id = available_models[model_label]

# JS-based mic recorder (runs entirely in the browser)
st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")  # "Click to record (browser-based)"
components.html("""
<script>
let mediaRecorder;
let audioChunks = [];
let stream;

async function startRecording() {
    stream = await navigator.mediaDevices.getUserMedia({ audio: true });
    mediaRecorder = new MediaRecorder(stream);
    mediaRecorder.ondataavailable = e => {
        audioChunks.push(e.data);
    };
    mediaRecorder.onstop = e => {
        const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
        audioChunks = [];
        const reader = new FileReader();
        reader.onloadend = () => {
            const base64Audio = reader.result.split(',')[1];
            // Caveat: components.html renders in a sandboxed iframe, so this
            // event is only visible inside this iframe; it cannot reach the
            // listener registered by the second components.html call below.
            const streamlitEvent = new CustomEvent("streamlit:recordedAudio", {
                detail: base64Audio
            });
            window.dispatchEvent(streamlitEvent);
        };
        reader.readAsDataURL(audioBlob);
    };
    mediaRecorder.start();
    document.getElementById("status").innerText = "🎙️ 錄音中... 按下停止結束錄音";
}

function stopRecording() {
    mediaRecorder.stop();
    stream.getTracks().forEach(track => track.stop());
    document.getElementById("status").innerText = "✅ 錄音完成,請稍候...";
}

function setup() {
    const startBtn = document.getElementById("startBtn");
    const stopBtn = document.getElementById("stopBtn");
    startBtn.onclick = startRecording;
    stopBtn.onclick = stopRecording;
}

window.addEventListener("DOMContentLoaded", setup);
</script>
<div>
    <button id="startBtn">▶️ 開始錄音</button>
    <button id="stopBtn">⏹️ 停止錄音</button>
    <p id="status">等待開始錄音...</p>
</div>
""", height=180)
# Handle base64 audio passed back via query parameters. Note: because of the
# iframe caveat above, the recorder cannot actually populate these parameters;
# in practice this branch only fires if ?audio_data=... is supplied directly.
audio_data = st.experimental_get_query_params().get("audio_data", [None])[0]

if "_RECORDING_AUDIO_" not in st.session_state:
    st.session_state._RECORDING_AUDIO_ = None
def _handle_audio_recorder():
    # Register (in a second sandboxed iframe) a listener for the recorded-audio
    # event and POST the base64 payload back to the app URL. Because each
    # components.html call gets its own iframe, the event dispatched by the
    # recorder above never reaches this listener; a bidirectional custom
    # component would be needed for this hand-off to work reliably.
    components.html("""
    <script>
    window.addEventListener("streamlit:recordedAudio", function(e) {
        const audioData = e.detail;
        const form = document.createElement("form");
        form.method = "POST";
        form.action = window.location.href.split("?")[0];
        form.innerHTML = `<input type="hidden" name="audio_data" value="${audioData}">`;
        document.body.appendChild(form);
        form.submit();
    });
    </script>
    """, height=0)

    if audio_data:
        # Decode the base64 payload and persist it to a temp file for Whisper
        audio_bytes = base64.b64decode(audio_data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as f:
            f.write(audio_bytes)
            f.flush()
        st.session_state._RECORDING_AUDIO_ = f.name
        st.success("✅ 錄音完成!")  # "Recording complete!"

_handle_audio_recorder()
# Transcribe the recorded audio file if one exists
text_input = ""
if st.session_state._RECORDING_AUDIO_:
    st.audio(st.session_state._RECORDING_AUDIO_)
    with st.spinner("🧠 Whisper 處理語音..."):  # "Whisper is processing the audio..."
        transcription = whisper_model(st.session_state._RECORDING_AUDIO_)["text"]
    text_input = transcription
    st.success(f"📝 語音轉文字:{transcription}")  # "Speech-to-text:"

# Manual text entry, pre-filled with the transcription when available
text_input = st.text_input("或手動輸入語句", value=text_input)  # "Or type a sentence manually"
if text_input and st.button("🚀 送出"):  # "Submit"
    with st.spinner("預測中..."):  # "Predicting..."
        result = predict_intent(text_input, model_id)
    st.success(result)
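
# To run locally (assuming the repo's requirements.txt, which presumably pins
# at minimum streamlit, transformers, and torch, is installed):
#
#   streamlit run app.py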