|
import io
import logging

import librosa
import soundfile
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "openai/whisper-large-v3"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# DEVICE already encodes CUDA availability, so only DEVICE needs checking;
# float16 halves GPU memory use, while CPU inference stays in float32.
TORCH_DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32

logger.info(f"Using device: {DEVICE}")
logger.info(f"Using dtype: {TORCH_DTYPE}")
logger.info(f"Loading model and processor for: {MODEL_NAME}...")

try:
    logger.info("Loading model...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    model.to(DEVICE)
    logger.info("Model loaded successfully.")

    logger.info("Loading processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    logger.info("Processor loaded successfully.")

    logger.info("Creating pipeline...")
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
    )
    logger.info("Pipeline created successfully.")
except Exception as e:
    logger.error(f"Error during model/processor loading or pipeline creation: {e}", exc_info=True)
    raise RuntimeError(f"Failed to initialize model pipeline for {MODEL_NAME}") from e
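

# Quick sanity check of the loaded pipeline from a Python shell
# (hypothetical file path, not part of the API):
#
#   audio, _ = librosa.load("sample.wav", sr=16000, mono=True)
#   print(pipe(audio)["text"])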


app = FastAPI()


@app.get("/")
async def read_root():
    """Basic endpoint to check if the API is running."""
    return {"message": f"Whisper API using {MODEL_NAME} is running."}


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file (e.g., MP3)."""
    if not file:
        raise HTTPException(status_code=400, detail="No file uploaded.")

    filename = file.filename
    logger.info(f"Received file: {filename}")
    logger.info(f"Content type: {file.content_type}")

    try:
        contents = await file.read()
        logger.info(f"File read into memory ({len(contents)} bytes).")
    except Exception as e:
        logger.error(f"Error reading file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error reading uploaded file: {e}")
    finally:
        await file.close()

    try:
        logger.info("Processing audio...")
        audio_stream = io.BytesIO(contents)

        # librosa decodes the bytes, downmixes to mono, and resamples to the
        # 16 kHz float32 numpy array that Whisper's feature extractor expects,
        # so no further type conversion is needed.
        audio_input, sample_rate = librosa.load(audio_stream, sr=16000, mono=True)
        logger.info(f"Audio loaded. Sample rate: {sample_rate}, Duration: {len(audio_input)/sample_rate:.2f}s")

        logger.info("Starting transcription...")
        # .copy() hands the pipeline a writable buffer of its own.
        result = pipe(
            audio_input.copy(),
            chunk_length_s=30,    # split long audio into 30-second windows
            batch_size=4,         # chunks decoded per forward pass
            return_timestamps=False,
        )

        transcription = result["text"]
        logger.info(f"Transcription successful for {filename}.")

        return JSONResponse(content={
            "filename": filename,
            "transcription": transcription,
        })

    except (librosa.util.exceptions.LibrosaError, soundfile.LibsndfileError) as e:
        # librosa does not export `LibrosaError` at the top level; decode
        # failures can also surface as soundfile.LibsndfileError (soundfile >= 0.11).
        logger.error(f"Error processing audio file {filename}: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format.")
    except Exception as e:
        logger.error(f"Transcription failed for {filename}: {e}", exc_info=True)
        if "out of memory" in str(e).lower():
            logger.error("Potential out-of-memory error. The model may be too large for the available resources.")
            raise HTTPException(status_code=507, detail="Transcription failed: insufficient memory. Try a smaller model or shorter audio.")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
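

# Example request against a running server (hypothetical host and file name):
#
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.mp3"
#
# On success the endpoint returns JSON shaped like:
#
#   {"filename": "sample.mp3", "transcription": "..."}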


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    try:
        # The ASR pipeline exposes `tokenizer` and `feature_extractor`
        # attributes rather than a combined `processor`, so check those.
        if pipe and pipe.model and pipe.tokenizer and pipe.feature_extractor:
            return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
        return {"status": "error", "detail": "Model pipeline component missing"}
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        return {"status": "error", "detail": str(e)}
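

# Example (the reported device depends on the host):
#
#   curl "http://localhost:7860/health"
#   -> {"status": "ok", "model": "openai/whisper-large-v3", "device": "cuda:0"}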


if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Uvicorn server locally...")
    # reload=True requires passing the app as an import string ("app:app"),
    # which assumes this file is saved as app.py.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
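
# Alternative: launch with the uvicorn CLI instead of the reloader above
# (again assuming this file is named app.py):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860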