import os
import io
import logging
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
# Import specific classes from transformers
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import librosa
import soundfile # Decoding backend used by librosa; its SoundFileError is caught during audio loading below
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
MODEL_NAME = "openai/whisper-large-v3"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Ensure float32 on CPU, float16 on GPU if available
TORCH_DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
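
# (Optional) The model name could be made configurable instead of hard-coded.
# A minimal sketch, commented out -- the WHISPER_MODEL environment variable is
# an assumption for illustration, not something this app defines:
# MODEL_NAME = os.environ.get("WHISPER_MODEL", MODEL_NAME)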

logger.info(f"Using device: {DEVICE}")
logger.info(f"Using dtype: {TORCH_DTYPE}")
logger.info(f"Loading model and processor for: {MODEL_NAME}...")

# --- Load the Model and Processor Explicitly ---
try:
    logger.info("Loading model...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True, # Crucial for large models on limited RAM
        use_safetensors=True    # Explicitly use safetensors
        # If using GPU and have flash-attn installed, could add:
        # attn_implementation="flash_attention_2"
    )
    # Move model to device *after* loading with low_cpu_mem_usage
    model.to(DEVICE)
    logger.info("Model loaded successfully.")

    logger.info("Loading processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    logger.info("Processor loaded successfully.")

    logger.info("Creating pipeline...")
    # Create the pipeline using the pre-loaded model and processor components
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
        # Note: chunk_length_s and batch_size are inference-time args,
        # better applied when calling pipe(), not during initialization.
    )
    logger.info("Pipeline created successfully.")

except Exception as e:
    logger.error(f"Error during model/processor loading or pipeline creation: {e}", exc_info=True)
    # Exit if loading fails
    raise RuntimeError(f"Failed to initialize model pipeline for {MODEL_NAME}") from e
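
# Optional startup sanity check -- a sketch, not part of the original flow.
# Uncommenting the two lines below runs a tiny warm-up inference on one second
# of silence so loading problems surface before the first real request:
# _warmup_audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
# logger.info(f"Warm-up transcription: {pipe(_warmup_audio)['text']!r}")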

# --- FastAPI App ---
app = FastAPI()

@app.get("/")
async def read_root():
    """ Basic endpoint to check if the API is running """
    return {"message": f"Whisper API using {MODEL_NAME} is running."}

@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Endpoint to transcribe an uploaded audio file (e.g., MP3).
    """
    if not file:
        raise HTTPException(status_code=400, detail="No file uploaded.")

    filename = file.filename
    logger.info(f"Received file: {filename}")
    logger.info(f"Content type: {file.content_type}")

    # Read file content into memory
    try:
        contents = await file.read()
        logger.info(f"File read into memory ({len(contents)} bytes).")
    except Exception as e:
        logger.error(f"Error reading file: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error reading uploaded file: {e}")
    finally:
        await file.close() # Ensure file handle is closed

    # Process the audio file
    try:
        logger.info("Processing audio...")
        audio_stream = io.BytesIO(contents)
        # Load audio using librosa, ensuring 16kHz mono
        audio_input, sample_rate = librosa.load(audio_stream, sr=16000, mono=True)
        logger.info(f"Audio loaded. Sample rate: {sample_rate}, Duration: {len(audio_input)/sample_rate:.2f}s")

        if not isinstance(audio_input, np.ndarray):
            audio_input = np.array(audio_input)

        logger.info("Starting transcription...")
        # Perform inference using the pipeline
        # Apply chunking and batching here for inference
        result = pipe(audio_input.copy(),
                      chunk_length_s=30, # Process in 30-second chunks
                      batch_size=4,      # Adjust based on memory (start low, e.g., 1 or 2 for free tier)
                      return_timestamps=False # Set to True or "word" if needed
                      # generate_kwargs={"language": "en"} # Optional: Force language
                     )

        transcription = result["text"]
        logger.info(f"Transcription successful for {filename}.")

        return JSONResponse(content={
            "filename": filename,
            "transcription": transcription
        })

    except (librosa.util.exceptions.LibrosaError, soundfile.SoundFileError) as e:
        logger.error(f"Error processing audio file {filename} with librosa: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}. Ensure it's a valid audio format.")
    except Exception as e:
        logger.error(f"Transcription failed for {filename}: {e}", exc_info=True)
        # Check for common errors like Out of Memory (OOM)
        if "out of memory" in str(e).lower():
             logger.error("Potential Out-of-Memory error. The model might be too large for the available resources.")
             raise HTTPException(status_code=507, detail=f"Transcription failed: Insufficient Memory. Try a smaller model or shorter audio.")
        else:
             raise HTTPException(status_code=500, detail=f"Transcription failed: {e}")
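
# Example client call for this endpoint -- a sketch for local testing. It
# assumes the `requests` package is installed and a local "sample.mp3" exists;
# neither is required by the app itself:
#
#   import requests
#   with open("sample.mp3", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/transcribe",
#           files={"file": ("sample.mp3", f, "audio/mpeg")},
#       )
#   resp.raise_for_status()
#   print(resp.json()["transcription"])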

# Optional: Add health check endpoint
@app.get("/health")
async def health_check():
    """ Health check endpoint """
    try:
        # Check if the pipeline object exists and seems valid
        if pipe and pipe.model and pipe.tokenizer and pipe.feature_extractor:
            return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
        else:
            return {"status": "error", "detail": "Model pipeline component missing"}
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        return {"status": "error", "detail": str(e)}

# --- Run with Uvicorn (for local testing) ---
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Uvicorn server locally...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) # Added reload for local dev
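
# For a quick manual test once the server is up (host and port match the
# defaults above; "sample.mp3" is illustrative):
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.mp3;type=audio/mpeg"
# The uvicorn CLI works as well, assuming this file is saved as app.py:
#   uvicorn app:app --host 0.0.0.0 --port 7860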