import os
import io
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
# Import specific classes from transformers
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import librosa
import soundfile # Often needed by librosa/transformers for specific formats
import numpy as np
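
# Dependency note (an approximation, not taken from a requirements file): this app
# needs fastapi, uvicorn, torch, transformers, librosa, soundfile and numpy, plus
# python-multipart, which FastAPI requires for UploadFile/form handling.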
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
MODEL_NAME = "openai/whisper-large-v3"
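# Note: whisper-large-v3 is a heavy checkpoint; on constrained hardware a smaller
# one (e.g. "openai/whisper-small") could be substituted here, in line with the
# out-of-memory handling further down.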
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Ensure float32 on CPU, float16 on GPU if available
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() and DEVICE != "cpu" else torch.float32
logger.info(f"Using device: {DEVICE}")
logger.info(f"Using dtype: {TORCH_DTYPE}")
logger.info(f"Loading model and processor for: {MODEL_NAME}...")
# --- Load the Model and Processor Explicitly ---
try:
    logger.info("Loading model...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,  # Crucial for large models on limited RAM
        use_safetensors=True,    # Explicitly use safetensors
        # If using a GPU with flash-attn installed, could add:
        # attn_implementation="flash_attention_2"
    )
    # Move the model to the device *after* loading with low_cpu_mem_usage
    model.to(DEVICE)
    logger.info("Model loaded successfully.")

    logger.info("Loading processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    logger.info("Processor loaded successfully.")

    logger.info("Creating pipeline...")
    # Create the pipeline using the pre-loaded model and processor components
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
        # Note: chunk_length_s and batch_size are inference-time arguments,
        # better applied when calling pipe(), not during initialization.
    )
    logger.info("Pipeline created successfully.")
except Exception as e:
    logger.error(f"Error during model/processor loading or pipeline creation: {e}", exc_info=True)
    # Fail fast if loading fails
    raise RuntimeError(f"Failed to initialize model pipeline for {MODEL_NAME}") from e
# --- FastAPI App ---
app = FastAPI()
@app.get("/")
async def read_root():
""" Basic endpoint to check if the API is running """
return {"message": f"Whisper API using {MODEL_NAME} is running."}
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
"""
Endpoint to transcribe an uploaded audio file (e.g., MP3).
"""
if not file:
raise HTTPException(status_code=400, detail="No file uploaded.")
filename = file.filename
logger.info(f"Received file: {filename}")
logger.info(f"Content type: {file.content_type}")
# Read file content into memory
try:
contents = await file.read()
logger.info(f"File read into memory ({len(contents)} bytes).")
except Exception as e:
logger.error(f"Error reading file: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error reading uploaded file: {e}")
finally:
await file.close() # Ensure file handle is closed
# Process the audio file
try:
logger.info("Processing audio...")
audio_stream = io.BytesIO(contents)
# Load audio using librosa, ensuring 16kHz mono
audio_input, sample_rate = librosa.load(audio_stream, sr=16000, mono=True)
logger.info(f"Audio loaded. Sample rate: {sample_rate}, Duration: {len(audio_input)/sample_rate:.2f}s")
if not isinstance(audio_input, np.ndarray):
audio_input = np.array(audio_input)
logger.info("Starting transcription...")
# Perform inference using the pipeline
# Apply chunking and batching here for inference
result = pipe(audio_input.copy(),
chunk_length_s=30, # Process in 30-second chunks
batch_size=4, # Adjust based on memory (start low, e.g., 1 or 2 for free tier)
return_timestamps=False # Set to True or "word" if needed
# generate_kwargs={"language": "en"} # Optional: Force language
)
transcription = result["text"]
logger.info(f"Transcription successful for {filename}.")
return JSONResponse(content={
"filename": filename,
"transcription": transcription
})
except librosa.LibrosaError as e:
logger.error(f"Error processing audio file {filename} with librosa: {e}", exc_info=True)
raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format.")
except Exception as e:
logger.error(f"Transcription failed for {filename}: {e}", exc_info=True)
# Check for common errors like Out of Memory (OOM)
if "out of memory" in str(e).lower():
logger.error("Potential Out-of-Memory error. The model might be too large for the available resources.")
raise HTTPException(status_code=507, detail=f"Transcription failed: Insufficient Memory. Try a smaller model or shorter audio.")
else:
raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
# Optional: Add health check endpoint
@app.get("/health")
async def health_check():
""" Health check endpoint """
try:
# Check if the pipeline object exists and seems valid
if pipe and pipe.model and pipe.processor:
return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
else:
return {"status": "error", "detail": "Model pipeline component missing"}
except Exception as e:
logger.error(f"Health check failed: {e}", exc_info=True)
return {"status": "error", "detail": str(e)}
# --- Run with Uvicorn (for local testing) ---
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Uvicorn server locally...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)  # reload=True for local development
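
# --- Example usage (a sketch; assumes the server is reachable on localhost:7860
# and that "sample.mp3" is a placeholder for a real audio file) ---
#
#   curl -X POST "http://localhost:7860/transcribe" \
#        -F "file=@sample.mp3;type=audio/mpeg"
#
# Expected response shape: {"filename": "sample.mp3", "transcription": "..."}
# The health check can be queried with:
#
#   curl "http://localhost:7860/health"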