# ido-whisper-turbo / handler.py
# Author: IdoMachlev
# NOTE: changed attention implementation to "sdpa" from default "eager" (commit a49a698)
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from typing import Dict, List, Any, Literal, Optional, Tuple
import torch
import logging
from pydantic_settings import BaseSettings
from pydantic import field_validator
class EndpointHandler():
    """Inference Endpoints handler serving OpenAI Whisper large-v3-turbo ASR.

    Loads the model once at startup and exposes a callable that transcribes
    the audio payload, with generation parameters pulled from
    WHISPER_KWARGS_* environment variables via WhisperParameterHandler.
    """

    def __init__(self, path=""):
        # Prefer GPU with fp16; fall back to CPU with fp32.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        model_id = "openai/whisper-large-v3-turbo"

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype,
            low_cpu_mem_usage=True, use_safetensors=True,
            # "sdpa" (PyTorch scaled dot-product attention) instead of the
            # default "eager" implementation — faster, same outputs.
            attn_implementation="sdpa"
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`)
            parameters (:obj: `Any`)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # get inputs; fall back to the whole payload when "inputs" is absent
        inputs = data.pop("inputs", data)

        # Generation parameters come from WHISPER_KWARGS_* environment variables.
        whisper_parameter_handler = WhisperParameterHandler()
        # Build the kwargs once; `return_timestamps` is a pipeline argument,
        # not a generate() kwarg, so it is excluded here (exclude takes a set).
        generate_kwargs = whisper_parameter_handler.model_dump(
            exclude_none=True, exclude={"return_timestamps"}
        )
        logging.info(generate_kwargs)

        # run normal prediction
        prediction = self.pipe(
            inputs,
            return_timestamps=whisper_parameter_handler.return_timestamps,
            generate_kwargs=generate_kwargs,
        )

        logging.info(prediction)
        # "chunks" is only present when return_timestamps is enabled;
        # guard the log line so plain transcriptions don't raise KeyError.
        if "chunks" in prediction:
            logging.info(prediction["chunks"])
        return prediction
class WhisperParameterHandler(BaseSettings):
    """Whisper generation parameters sourced from WHISPER_KWARGS_* env vars.

    Every field defaults to None, meaning "use the model's default"; None
    fields are dropped when the settings are converted to kwargs.
    """

    language: Optional[str] = None  # Optional fields default to None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    temperature: Optional[Tuple[float, ...]] = None  # temperature fallback schedule
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    # True -> segment-level timestamps, "word" -> word-level timestamps.
    return_timestamps: Optional[Literal["word", True]] = None

    @field_validator("return_timestamps", mode="before")
    def cannonize_timestamps(cls, value):
        """Map the env string "true" (any case) to the boolean True.

        Env values arrive as strings; guard with isinstance so passing the
        boolean True programmatically does not crash on .lower().
        """
        if value is None:
            return None
        if isinstance(value, str) and value.lower() == "true":
            logging.info("return_timestamps == 'True'")
            return True
        return value

    model_config = {
        "env_prefix": "WHISPER_KWARGS_",
        "case_sensitive": False,
    }

    def to_kwargs(self):
        """Convert object attributes to kwargs dict, excluding None values."""
        # model_dump(exclude_none=True) is exactly "dump, drop the Nones".
        return self.model_dump(exclude_none=True)