import streamlit as st
import streamlit.components.v1 as components
import tempfile
import base64
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
import torch
# Load Whisper ASR pipeline (fine-tuned for zh_tw)
whisper_model = pipeline("automatic-speech-recognition", model="Jingmiao/whisper-small-zh_tw")
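# Note: this pipeline can also be wrapped in @st.cache_resource (as load_model() is
# below) so the weights are loaded once per process instead of on every rerun.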
# Intent classifier models
available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
}
@st.cache_resource
def load_model(model_id):
    # Cache tokenizer and model per model_id so they are loaded once,
    # not on every prediction
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    return tokenizer, model
def predict_intent(text, model_id):
    tokenizer, model = load_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    # Index 1 is treated as the reservation-intent class
    confidence = probs[0, 1].item()
    if confidence >= 0.7:
        label = "📞 訂位意圖 (Reservation intent)"
    else:
        label = "❌ 無訂位意圖 (No intent)"
    return f"{label}(信心度 Confidence: {confidence:.2%})"
# UI
st.title("🍽️ 餐廳訂位意圖識別")
st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")
model_label = st.selectbox("選擇模型", list(available_models.keys()))
model_id = available_models[model_label]

# JS-based mic recorder
st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")
components.html("""
<script>
let mediaRecorder;
let audioChunks = [];
let stream;

async function startRecording() {
    stream = await navigator.mediaDevices.getUserMedia({ audio: true });
    mediaRecorder = new MediaRecorder(stream);
    mediaRecorder.ondataavailable = e => {
        audioChunks.push(e.data);
    };
    mediaRecorder.onstop = e => {
        const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
        audioChunks = [];
        const reader = new FileReader();
        reader.onloadend = () => {
            const base64Audio = reader.result.split(',')[1];
            const streamlitEvent = new CustomEvent("streamlit:recordedAudio", {
                detail: base64Audio
            });
            window.dispatchEvent(streamlitEvent);
        };
        reader.readAsDataURL(audioBlob);
    };
    mediaRecorder.start();
    document.getElementById("status").innerText = "🎙️ 錄音中... 按下停止結束錄音";
}

function stopRecording() {
    mediaRecorder.stop();
    stream.getTracks().forEach(track => track.stop());
    document.getElementById("status").innerText = "✅ 錄音完成,請稍候...";
}

function setup() {
    const startBtn = document.getElementById("startBtn");
    const stopBtn = document.getElementById("stopBtn");
    startBtn.onclick = startRecording;
    stopBtn.onclick = stopRecording;
}

window.addEventListener("DOMContentLoaded", setup);
</script>
<div>
    <button id="startBtn">▶️ 開始錄音</button>
    <button id="stopBtn">⏹️ 停止錄音</button>
    <p id="status">等待開始錄音...</p>
</div>
""", height=180)
# Handle base64 audio passed back via the page's query parameters
audio_data = st.experimental_get_query_params().get("audio_data", [None])[0]

if '_RECORDING_AUDIO_' not in st.session_state:
    st.session_state._RECORDING_AUDIO_ = None
def _handle_audio_recorder():
    from streamlit.runtime.scriptrunner import get_script_run_ctx
    ctx = get_script_run_ctx()
    if ctx is None:
        return
    # Attach JS callback: a hidden component that listens for the recorder's
    # custom event and reloads the page with the audio attached as a query
    # parameter. Each components.html() call runs in its own iframe, so the
    # manual text box below remains the fallback if the event never arrives.
    components.html("""
    <script>
    window.addEventListener("streamlit:recordedAudio", function(e) {
        const audioData = e.detail;
        const form = document.createElement("form");
        form.method = "GET";  // GET so the payload shows up as ?audio_data=...
        form.action = window.location.href.split("?")[0];
        form.innerHTML = `<input type="hidden" name="audio_data" value="${audioData}">`;
        document.body.appendChild(form);
        form.submit();
    });
    </script>
    """, height=0)
    if audio_data:
        # Decode and save to a temp file for Whisper
        audio_bytes = base64.b64decode(audio_data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as f:
            f.write(audio_bytes)
            f.flush()
            st.session_state._RECORDING_AUDIO_ = f.name
        st.success("✅ 錄音完成!")

_handle_audio_recorder()
# Use audio file if recorded
text_input = ""
if st.session_state._RECORDING_AUDIO_:
    st.audio(st.session_state._RECORDING_AUDIO_)
    with st.spinner("🧠 Whisper 處理語音..."):
        transcription = whisper_model(st.session_state._RECORDING_AUDIO_)["text"]
    text_input = transcription
    st.success(f"📝 語音轉文字:{transcription}")

# Manual fallback
text_input = st.text_input("或手動輸入語句", value=text_input)
if text_input and st.button("🚀 送出"):
    with st.spinner("預測中..."):
        result = predict_intent(text_input, model_id)
    st.success(result)
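# To run this app locally (assuming the file is saved as app.py, with streamlit,
# transformers and torch installed, plus ffmpeg for decoding the .webm recordings):
#   streamlit run app.py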