add app.py & requirements.txt
- app.py +157 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,157 @@
import streamlit as st
import streamlit.components.v1 as components
import tempfile
import base64
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
import torch

# Load Whisper (fine-tuned zh_tw)
whisper_model = pipeline("automatic-speech-recognition", model="Jingmiao/whisper-small-zh_tw")

# Intent classifier models
available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
}

@st.cache_resource
def load_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    return tokenizer, model

def predict_intent(text, model_id):
    tokenizer, model = load_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=-1)
    confidence = probs[0, 1].item()  # label index 1 = reservation class
    if confidence >= 0.7:
        label = "📞 訂位意圖 (Reservation intent)"
    else:
        label = "❌ 無訂位意圖 (No intent)"
    return f"{label}(信心度 Confidence: {confidence:.2%})"

# UI
st.title("🍽️ 餐廳訂位意圖識別")  # "Restaurant reservation intent recognition"
st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")  # "Record audio or type text; the app detects reservation intent."

model_label = st.selectbox("選擇模型", list(available_models.keys()))  # "Choose a model"
model_id = available_models[model_label]

# JS-based mic recorder
st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")  # "Click to record (browser-based)"
components.html("""
<script>
let mediaRecorder;
let audioChunks = [];
let stream;

async function startRecording() {
  stream = await navigator.mediaDevices.getUserMedia({ audio: true });
  mediaRecorder = new MediaRecorder(stream);
  mediaRecorder.ondataavailable = e => {
    audioChunks.push(e.data);
  };
  mediaRecorder.onstop = e => {
    const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
    audioChunks = [];
    const reader = new FileReader();
    reader.onloadend = () => {
      const base64Audio = reader.result.split(',')[1];
      const streamlitEvent = new CustomEvent("streamlit:recordedAudio", {
        detail: base64Audio
      });
      window.dispatchEvent(streamlitEvent);
    };
    reader.readAsDataURL(audioBlob);
  };
  mediaRecorder.start();
  document.getElementById("status").innerText = "🎙️ 錄音中... 按下停止結束錄音";
}

function stopRecording() {
  mediaRecorder.stop();
  stream.getTracks().forEach(track => track.stop());
  document.getElementById("status").innerText = "✅ 錄音完成,請稍候...";
}

function setup() {
  const startBtn = document.getElementById("startBtn");
  const stopBtn = document.getElementById("stopBtn");
  startBtn.onclick = startRecording;
  stopBtn.onclick = stopRecording;
}

window.addEventListener("DOMContentLoaded", setup);
</script>
<div>
  <button id="startBtn">▶️ 開始錄音</button>
  <button id="stopBtn">⏹️ 停止錄音</button>
  <p id="status">等待開始錄音...</p>
</div>
""", height=180)

# Handle base64 audio handed back from the recorder (legacy query-params API)
audio_data = st.experimental_get_query_params().get("audio_data", [None])[0]

if '_RECORDING_AUDIO_' not in st.session_state:
    st.session_state._RECORDING_AUDIO_ = None

def _handle_audio_recorder():
    from streamlit.runtime.scriptrunner import get_script_run_ctx

    ctx = get_script_run_ctx()
    if ctx is None:
        return

    # Attach JS callback.
    # NOTE: each components.html() call renders in its own sandboxed iframe, so the
    # "streamlit:recordedAudio" event dispatched by the recorder above does not reach
    # this listener, and Streamlit does not turn a POSTed form into query params;
    # a custom bidirectional component would be needed for a reliable handoff.
    components.html("""
    <script>
    window.addEventListener("streamlit:recordedAudio", function(e) {
      const audioData = e.detail;
      const form = document.createElement("form");
      form.method = "POST";
      form.action = window.location.href.split("?")[0];
      form.innerHTML = `<input type="hidden" name="audio_data" value="${audioData}">`;
      document.body.appendChild(form);
      form.submit();
    });
    </script>
    """, height=0)

    if audio_data:
        # Decode and save to a temp file
        audio_bytes = base64.b64decode(audio_data)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as f:
            f.write(audio_bytes)
            f.flush()
            st.session_state._RECORDING_AUDIO_ = f.name
        st.success("✅ 錄音完成!")  # "Recording complete!"

_handle_audio_recorder()

# Use the audio file if one was recorded
text_input = ""
if st.session_state._RECORDING_AUDIO_:
    st.audio(st.session_state._RECORDING_AUDIO_)
    with st.spinner("🧠 Whisper 處理語音..."):  # "Whisper is processing the audio..."
        transcription = whisper_model(st.session_state._RECORDING_AUDIO_)["text"]
        text_input = transcription
    st.success(f"📝 語音轉文字:{transcription}")  # "Speech-to-text: ..."

# Manual fallback
text_input = st.text_input("或手動輸入語句", value=text_input)  # "Or type a sentence manually"

if text_input and st.button("🚀 送出"):  # "Submit"
    with st.spinner("預測中..."):  # "Predicting..."
        result = predict_intent(text_input, model_id)
        st.success(result)
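For reference, the classification path can be exercised outside the Streamlit UI. The sketch below is illustrative only: the test sentence is made up, and it assumes (as predict_intent above does) that label index 1 of the checkpoint is the reservation class.

import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "Luigi/albert-tiny-chinese-dinercall-intent"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Hypothetical utterance: "Hi, I'd like to book a table for four at 7pm tomorrow."
inputs = tokenizer("你好,我想訂明天晚上七點,四位。", return_tensors="pt")
with torch.no_grad():
    probs = softmax(model(**inputs).logits, dim=-1)
print(f"P(reservation) = {probs[0, 1].item():.2%}")  # the app flags intent at >= 0.7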
requirements.txt
ADDED
@@ -0,0 +1,8 @@
sentencepiece
transformers>=4.30.0
torch
streamlit
safetensors
huggingface_hub
torchaudio
gradio>=3.30.0
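To try the Space locally, the usual Streamlit workflow should apply (assuming ffmpeg is installed, which the transformers ASR pipeline needs to decode webm audio):

pip install -r requirements.txt
streamlit run app.py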