# kb-whisper-demo / app.py
import streamlit as st
import torch
import base64
import tempfile
import os
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import streamlit.components.v1 as components
# Device / dtype setup: GPU with fp16 when available, otherwise CPU with fp32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "KBLab/kb-whisper-tiny"
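# Note (assumption): other KB-Whisper checkpoints published under the same naming scheme
# (e.g. KBLab/kb-whisper-base or KBLab/kb-whisper-small) should be drop-in replacements;
# only model_id needs to change, trading download size and inference speed for accuracy.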
@st.cache_resource
def load_model():
    # Cached with st.cache_resource so the model and processor are loaded only once
    # per server process instead of on every Streamlit rerun.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, use_safetensors=True, cache_dir="cache"
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
asr_pipeline = load_model()
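# The returned object is a standard transformers ASR pipeline; calling it with a file path
# (as done below), raw bytes, or a NumPy array yields a dict with a "text" key, e.g.
#   result = asr_pipeline("example.wav")  # {"text": "..."}  (path is illustrative)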
st.title("Swedish Speech-to-Text Demo")
# Audio Upload Option
uploaded_file = st.file_uploader("Ladda upp en ljudfil", type=["wav", "mp3", "flac"])
# In-browser recorder: plain JavaScript using the MediaRecorder API. The recorded audio is
# POSTed to a /save_audio endpoint, which a bare Streamlit app does not provide; recording
# only works if a separate backend handles that route and stores the result where the
# Python side can read it (see the session_state branch further down).
audio_recorder_js = """
<script>
let mediaRecorder;
let audioChunks = [];
let isRecording = false;

function startRecording() {
    if (!isRecording) {
        isRecording = true;
        navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
            mediaRecorder = new MediaRecorder(stream);
            audioChunks = [];
            mediaRecorder.ondataavailable = event => {
                audioChunks.push(event.data);
            };
            mediaRecorder.onstop = () => {
                // Note: MediaRecorder typically emits audio/webm or audio/ogg; labelling
                // the blob as audio/wav does not change the actual encoding.
                const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
                const reader = new FileReader();
                reader.readAsDataURL(audioBlob);
                reader.onloadend = () => {
                    const base64Audio = reader.result.split(',')[1];
                    fetch('/save_audio', {
                        method: 'POST',
                        headers: { 'Content-Type': 'application/json' },
                        body: JSON.stringify({ audio: base64Audio })
                    }).then(response => response.json()).then(data => {
                        console.log(data);
                        window.location.reload();
                    });
                };
            };
            mediaRecorder.start();
        });
    }
}

function stopRecording() {
    if (isRecording) {
        isRecording = false;
        mediaRecorder.stop();
    }
}
</script>
<button onclick="startRecording()">🎤 Starta inspelning</button>
<button onclick="stopRecording()">⏹️ Stoppa inspelning</button>
"""
components.html(audio_recorder_js)
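# Alternative sketch (assumes a recent Streamlit, roughly 1.39+, where the built-in
# st.audio_input widget exists): microphone recording without custom JavaScript or a
# separate /save_audio backend. Untested here; shown only as a possible simplification.
#
#     recorded = st.audio_input("Spela in ljud")
#     if recorded is not None:
#         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
#             temp_audio.write(recorded.read())
#             audio_path = temp_audio.name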
# Processing audio input (uploaded file or recorded)
audio_path = None

if uploaded_file is not None:
    # Save the uploaded file to a temporary location so the pipeline can read it by path
    with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[-1]) as temp_audio:
        temp_audio.write(uploaded_file.read())
        audio_path = temp_audio.name
elif "audio_data" in st.session_state and st.session_state["audio_data"]:
    # Decode base64 audio from the JavaScript recorder. Nothing in this file populates
    # st.session_state["audio_data"]; it is expected to be filled by whatever backend
    # handles the /save_audio POST above.
    audio_bytes = base64.b64decode(st.session_state["audio_data"])
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_audio.write(audio_bytes)
        audio_path = temp_audio.name
# Transcribe if we have audio
if audio_path:
    st.audio(audio_path, format="audio/wav")
    with st.spinner("Transkriberar..."):
        transcription = asr_pipeline(audio_path)["text"]
    st.subheader("📜 Transkription:")
    st.write(transcription)
    # Cleanup temp file
    os.remove(audio_path)
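# Optional, untested here (recent transformers versions): the ASR pipeline accepts
# generation options at call time, so the output language/task could be pinned instead
# of auto-detected, and long recordings can be split into chunks, e.g.
#   transcription = asr_pipeline(
#       audio_path,
#       chunk_length_s=30,
#       generate_kwargs={"language": "sv", "task": "transcribe"},
#   )["text"]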