Luigi committed
Commit 4991207 · 1 Parent(s): 8c6c70c

add app.py & requirements.txt

Files changed (2)
  1. app.py +157 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,157 @@
+ import streamlit as st
+ import streamlit.components.v1 as components
+ import tempfile
+ import base64
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+ from torch.nn.functional import softmax
+ import torch
+
+ # Load Whisper (fine-tuned zh_tw)
+ whisper_model = pipeline("automatic-speech-recognition", model="Jingmiao/whisper-small-zh_tw")
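+ # Note: for file-path inputs (such as the .webm recordings handled below),
+ # the ASR pipeline decodes audio via ffmpeg, which must be installed on the host.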
+
+ # Intent classifier models
+ available_models = {
+     "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
+     "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
+ }
+
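+ # st.cache_resource caches the loaded tokenizer/model per model_id, so each
+ # checkpoint is downloaded and initialised only once across reruns.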
+ @st.cache_resource
+ def load_model(model_id):
+     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+     model = AutoModelForSequenceClassification.from_pretrained(model_id)
+     return tokenizer, model
+
+ def predict_intent(text, model_id):
+     tokenizer, model = load_model(model_id)
+     inputs = tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         logits = model(**inputs).logits
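+     # Assumption: label index 1 is the positive (reservation) class; the 0.7
+     # cut-off applied below is an app-level threshold chosen in this script.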
+     probs = softmax(logits, dim=-1)
+     confidence = probs[0, 1].item()
+     if confidence >= 0.7:
+         label = "📞 訂位意圖 (Reservation intent)"
+     else:
+         label = "❌ 無訂位意圖 (No intent)"
+     return f"{label}(信心度 Confidence: {confidence:.2%})"
+
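+ # For example, predict_intent("我想訂位", model_id) yields a string shaped like
+ # "📞 訂位意圖 (Reservation intent)(信心度 Confidence: 85.00%)" (figures illustrative).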
+ # UI
+ st.title("🍽️ 餐廳訂位意圖識別")
+ st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")
+
+ model_label = st.selectbox("選擇模型", list(available_models.keys()))
+ model_id = available_models[model_label]
+
+ # JS-based mic recorder
+ st.markdown("### 🎙️ 點擊錄音(支援瀏覽器)")
+ components.html("""
+ <script>
+ let mediaRecorder;
+ let audioChunks = [];
+ let stream;
+
+ async function startRecording() {
+     stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+     mediaRecorder = new MediaRecorder(stream);
+     mediaRecorder.ondataavailable = e => {
+         audioChunks.push(e.data);
+     };
+     mediaRecorder.onstop = e => {
+         const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
+         audioChunks = [];
+         const reader = new FileReader();
+         reader.onloadend = () => {
+             const base64Audio = reader.result.split(',')[1];
+             const streamlitEvent = new CustomEvent("streamlit:recordedAudio", {
+                 detail: base64Audio
+             });
+             window.dispatchEvent(streamlitEvent);
+         };
+         reader.readAsDataURL(audioBlob);
+     };
+     mediaRecorder.start();
+     document.getElementById("status").innerText = "🎙️ 錄音中... 按下停止結束錄音";
+ }
+
+ function stopRecording() {
+     mediaRecorder.stop();
+     stream.getTracks().forEach(track => track.stop());
+     document.getElementById("status").innerText = "✅ 錄音完成,請稍候...";
+ }
+
+ function setup() {
+     const startBtn = document.getElementById("startBtn");
+     const stopBtn = document.getElementById("stopBtn");
+     startBtn.onclick = startRecording;
+     stopBtn.onclick = stopRecording;
+ }
+
+ window.addEventListener("DOMContentLoaded", setup);
+ </script>
+ <div>
+     <button id="startBtn">▶️ 開始錄音</button>
+     <button id="stopBtn">⏹️ 停止錄音</button>
+     <p id="status">等待開始錄音...</p>
+ </div>
+ """, height=180)
+
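+ # Caveat: each components.html call renders in its own iframe, so the
+ # "streamlit:recordedAudio" event dispatched above is visible only within
+ # that iframe, not to listeners injected by later components.html calls.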
+ # Handle base64 audio passed back through the page's query string.
+ # (st.experimental_get_query_params is deprecated in newer Streamlit
+ # releases in favour of st.query_params.)
+ audio_data = st.experimental_get_query_params().get("audio_data", [None])[0]
+
+ if '_RECORDING_AUDIO_' not in st.session_state:
+     st.session_state._RECORDING_AUDIO_ = None
+
+ def _handle_audio_recorder():
+     from streamlit.runtime.scriptrunner import get_script_run_ctx
+
+     # Bail out when not running inside a live Streamlit script run.
+     ctx = get_script_run_ctx()
+     if ctx is None:
+         return
+
+     # Attach a JS callback that forwards the recorded audio back to the app.
+     components.html("""
+     <script>
+     window.addEventListener("streamlit:recordedAudio", function(e) {
+         const audioData = e.detail;
+         const form = document.createElement("form");
+         // Submit via GET so the payload lands in the query string read by
+         // st.experimental_get_query_params on the next script run (a POST
+         // body would never appear in the query params).
+         form.method = "GET";
+         form.action = window.location.href.split("?")[0];
+         form.innerHTML = `<input type="hidden" name="audio_data" value="${audioData}">`;
+         document.body.appendChild(form);
+         form.submit();
+     });
+     </script>
+     """, height=0)
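+     # Note: base64-encoded audio can exceed browser URL-length limits, so this
+     # query-string round-trip is best-effort; the text box below is the fallback.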
+
+     if audio_data:
+         # Decode and save to temp file
+         audio_bytes = base64.b64decode(audio_data)
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as f:
+             f.write(audio_bytes)
+             f.flush()
+         st.session_state._RECORDING_AUDIO_ = f.name
+         st.success("✅ 錄音完成!")
+
+ _handle_audio_recorder()
+
+ # Use audio file if recorded
+ text_input = ""
+ if st.session_state._RECORDING_AUDIO_:
+     st.audio(st.session_state._RECORDING_AUDIO_)
+     with st.spinner("🧠 Whisper 處理語音..."):
+         transcription = whisper_model(st.session_state._RECORDING_AUDIO_)["text"]
+     text_input = transcription
+     st.success(f"📝 語音轉文字:{transcription}")
+
+ # Manual fallback
+ text_input = st.text_input("或手動輸入語句", value=text_input)
+
+ if text_input and st.button("🚀 送出"):
+     with st.spinner("預測中..."):
+         result = predict_intent(text_input, model_id)
+     st.success(result)
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ sentencepiece
+ transformers>=4.30.0
+ torch
+ streamlit
+ safetensors
+ huggingface_hub
+ torchaudio