# whisper/app.py
import os
import io
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
# Import specific classes from transformers
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import librosa
import soundfile # Often needed by librosa/transformers for specific formats
import numpy as np
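# Note (environment assumption): decoding compressed formats such as MP3 with
# librosa typically relies on an extra backend (e.g. ffmpeg via audioread, or a
# libsndfile build with MP3 support); whether one is needed depends on the
# deployment environment.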
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
MODEL_NAME = "openai/whisper-large-v3"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Use float16 on GPU, float32 on CPU
TORCH_DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
logger.info(f"Using device: {DEVICE}")
logger.info(f"Using dtype: {TORCH_DTYPE}")
logger.info(f"Loading model and processor for: {MODEL_NAME}...")
# --- Load the Model and Processor Explicitly ---
try:
    logger.info("Loading model...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,  # Crucial for large models on limited RAM
        use_safetensors=True,    # Explicitly use safetensors
        # If using a GPU with flash-attn installed, could add:
        # attn_implementation="flash_attention_2"
    )
    # Move model to device *after* loading with low_cpu_mem_usage
    model.to(DEVICE)
    logger.info("Model loaded successfully.")

    logger.info("Loading processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    logger.info("Processor loaded successfully.")

    logger.info("Creating pipeline...")
    # Create the pipeline using the pre-loaded model and processor components
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
        # Note: chunk_length_s and batch_size are inference-time args,
        # better applied when calling pipe(), not during initialization.
    )
    logger.info("Pipeline created successfully.")
except Exception as e:
    logger.error(f"Error during model/processor loading or pipeline creation: {e}", exc_info=True)
    # Abort startup if loading fails
    raise RuntimeError(f"Failed to initialize model pipeline for {MODEL_NAME}") from e
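# Note: because loading happens at import time, the app only begins serving
# requests once the whisper-large-v3 weights (on the order of a few gigabytes)
# have been downloaded and loaded, so expect a slow cold start on first launch.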
# --- FastAPI App ---
app = FastAPI()
@app.get("/")
async def read_root():
""" Basic endpoint to check if the API is running """
return {"message": f"Whisper API using {MODEL_NAME} is running."}
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
"""
Endpoint to transcribe an uploaded audio file (e.g., MP3).
"""
if not file:
raise HTTPException(status_code=400, detail="No file uploaded.")
filename = file.filename
logger.info(f"Received file: {filename}")
logger.info(f"Content type: {file.content_type}")
# Read file content into memory
try:
contents = await file.read()
logger.info(f"File read into memory ({len(contents)} bytes).")
except Exception as e:
logger.error(f"Error reading file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error reading uploaded file: {e}")
finally:
await file.close() # Ensure file handle is closed
# Process the audio file
try:
logger.info("Processing audio...")
audio_stream = io.BytesIO(contents)
# Load audio using librosa, ensuring 16kHz mono
audio_input, sample_rate = librosa.load(audio_stream, sr=16000, mono=True)
logger.info(f"Audio loaded. Sample rate: {sample_rate}, Duration: {len(audio_input)/sample_rate:.2f}s")
if not isinstance(audio_input, np.ndarray):
audio_input = np.array(audio_input)
logger.info("Starting transcription...")
# Perform inference using the pipeline
# Apply chunking and batching here for inference
result = pipe(audio_input.copy(),
chunk_length_s=30, # Process in 30-second chunks
batch_size=4, # Adjust based on memory (start low, e.g., 1 or 2 for free tier)
return_timestamps=False # Set to True or "word" if needed
# generate_kwargs={"language": "en"} # Optional: Force language
)
transcription = result["text"]
logger.info(f"Transcription successful for {filename}.")
return JSONResponse(content={
"filename": filename,
"transcription": transcription
})
except librosa.LibrosaError as e:
logger.error(f"Error processing audio file {filename} with librosa: {e}", exc_info=True)
raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format.")
except Exception as e:
logger.error(f"Transcription failed for {filename}: {e}", exc_info=True)
# Check for common errors like Out of Memory (OOM)
if "out of memory" in str(e).lower():
logger.error("Potential Out-of-Memory error. The model might be too large for the available resources.")
raise HTTPException(status_code=507, detail=f"Transcription failed: Insufficient Memory. Try a smaller model or shorter audio.")
else:
raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
# Optional: Add health check endpoint
@app.get("/health")
async def health_check():
""" Health check endpoint """
try:
# Check if the pipeline object exists and seems valid
if pipe and pipe.model and pipe.processor:
return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
else:
return {"status": "error", "detail": "Model pipeline component missing"}
except Exception as e:
logger.error(f"Health check failed: {e}", exc_info=True)
return {"status": "error", "detail": str(e)}
# --- Run with Uvicorn (for local testing) ---
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Uvicorn server locally...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)  # reload=True for local dev
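# --- Example usage (illustrative sketch) ---
# Assuming the server is reachable on localhost:7860 (the port configured above),
# a transcription request can be sent as a multipart form upload; "file" matches
# the endpoint's parameter name, and "sample.mp3" is a placeholder path:
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.mp3"
# The response is JSON of the form {"filename": ..., "transcription": ...}.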