from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
import numpy as np
import torch
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_microphone_live

device = "cuda:0" if torch.cuda.is_available() else "cpu"
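
# Two audio-classification pipelines: a keyword-spotting model (fine-tuned on
# Speech Commands v2) for wake-word detection, and an XLS-R model fine-tuned
# on MINDS-14 for intent classification.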
classifier = pipeline(
    "audio-classification", model="MIT/ast-finetuned-speech-commands-v2", device=device
)
intent_class_pipe = pipeline(
    "audio-classification", model="anton-l/xtreme_s_xlsr_minds14", device=device
)


async def launch_fn(
    wake_word="marvin",
    prob_threshold=0.5,
    chunk_length_s=2.0,
    stream_chunk_s=0.25,
    debug=False,
):
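    """Block until the wake word is detected on the microphone stream.

    Feeds live microphone chunks through the keyword-spotting classifier and
    returns True once the wake word is predicted above prob_threshold.
    """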
    if wake_word not in classifier.model.config.label2id:
        raise ValueError(
            f"Wake word '{wake_word}' is not a valid class label; pick a wake word "
            f"from {list(classifier.model.config.label2id.keys())}."
        )

    sampling_rate = classifier.feature_extractor.sampling_rate
    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )
print("Listening for wake word...")
for prediction in classifier(mic):
prediction = prediction[0]
if debug:
print(prediction)
if prediction["label"] == wake_word:
if prediction["score"] > prob_threshold:
return True


async def listen(websocket: WebSocket, chunk_length_s=2.0, stream_chunk_s=2.0):
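    """Record up to four chunks of microphone audio and classify the intent.

    Streams a per-chunk prediction to the client as each chunk arrives, and
    stops recording early if a chunk is classified as silence.
    """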
    sampling_rate = intent_class_pipe.feature_extractor.sampling_rate
    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )

    audio_buffer = []
    print("Listening")
    for i in range(4):
        audio_chunk = next(mic)
        audio_buffer.append(audio_chunk["raw"])
        prediction = intent_class_pipe(audio_chunk["raw"])
        await websocket.send_text(f"chunk: {prediction[0]['label']} | {i + 1} / 4")
        if await is_silence(audio_chunk["raw"], threshold=0.7):
            print("Silence detected, processing audio.")
            break
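
    # Classify the full utterance: concatenate the buffered chunks and send the
    # top three intent predictions back to the client.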
    combined_audio = np.concatenate(audio_buffer)
    prediction = intent_class_pipe(combined_audio)
    top_3_predictions = prediction[:3]
    formatted_predictions = "\n".join(
        f"{pred['label']}: {pred['score'] * 100:.2f}%" for pred in top_3_predictions
    )
    await websocket.send_text(f"classes: \n{formatted_predictions}")


async def is_silence(audio_chunk, threshold):
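    """Return True if the top prediction for the chunk is "silence" with a
    score above threshold."""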
    silence = intent_class_pipe(audio_chunk)
    return silence[0]["label"] == "silence" and silence[0]["score"] > threshold


# Initialize FastAPI app
app = FastAPI()

# Set up static file directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 templates for HTML rendering
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
async def get_home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
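    # Minimal text protocol: the client sends "start" to kick off wake-word
    # detection followed by intent classification, and "stop" to end the session.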
    try:
        process_active = False  # Flag to track the state of the process
        while True:
            message = await websocket.receive_text()
            if message == "start" and not process_active:
                process_active = True
                await websocket.send_text("Listening for wake word...")
                wake_word_detected = await launch_fn(debug=True)
                if wake_word_detected:
                    await websocket.send_text("Wake word detected. Listening for your query...")
                    await listen(websocket)
                process_active = False  # Reset the process flag
            elif message == "stop":
                if process_active:
                    # TODO: actually interrupt the ongoing process, e.g. by setting
                    # a flag that launch_fn and listen check between chunks.
                    process_active = False
                await websocket.send_text("Process stopped. Ready to restart.")
                break  # Or keep the loop running to allow restarting without reconnecting
    except WebSocketDisconnect:
        print("Client disconnected.")