import logging
from typing import Any, Dict, Literal, Optional, Tuple

import torch
from pydantic import field_validator
from pydantic_settings import BaseSettings
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Run on GPU in half precision when available, otherwise fall back to CPU / float32.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        model_id = "openai/whisper-large-v3-turbo"

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="sdpa",
        )
        model.to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        # Wrap model, tokenizer, and feature extractor in a single ASR pipeline.
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`str` | :obj:`bytes`): audio to transcribe.
        Return:
            A :obj:`dict` with the transcription (and `chunks` when timestamps
            are requested) that will be serialized and returned.

        Generation parameters are not read from the request payload; they come
        from `WHISPER_KWARGS_*` environment variables via `WhisperParameterHandler`.
        """
        inputs = data.pop("inputs", data)

        whisper_parameter_handler = WhisperParameterHandler()
        # Everything except `return_timestamps` is forwarded as generate_kwargs.
        generate_kwargs = whisper_parameter_handler.model_dump(
            exclude_none=True, exclude={"return_timestamps"}
        )
        logging.info(generate_kwargs)

        prediction = self.pipe(
            inputs,
            return_timestamps=whisper_parameter_handler.return_timestamps,
            generate_kwargs=generate_kwargs,
        )
        logging.info(prediction)
        # `chunks` is only present when timestamps were requested.
        logging.info(prediction.get("chunks"))
        return prediction
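

# Note on EndpointHandler.__call__ output (illustrative shape, not captured from a
# real run): with return_timestamps enabled the pipeline returns roughly
#   {"text": "...", "chunks": [{"timestamp": (0.0, 3.2), "text": "..."}, ...]}
# whereas without timestamps only the "text" key is guaranteed.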


class WhisperParameterHandler(BaseSettings):
    """Whisper generation options, populated from `WHISPER_KWARGS_*` environment variables."""

    language: Optional[str] = None
    max_new_tokens: Optional[int] = None
    num_beams: Optional[int] = None
    condition_on_prev_tokens: Optional[bool] = None
    compression_ratio_threshold: Optional[float] = None
    temperature: Optional[Tuple[float, ...]] = None
    logprob_threshold: Optional[float] = None
    no_speech_threshold: Optional[float] = None
    return_timestamps: Optional[Literal["word", True]] = None

    @field_validator("return_timestamps", mode="before")
    @classmethod
    def canonicalize_timestamps(cls, value: Any):
        # Environment variables arrive as strings; map "true" (any casing) to the
        # boolean True expected by the pipeline, and pass "word" (or None) through.
        if value is None:
            return None
        if isinstance(value, str) and value.lower() == "true":
            logging.info("return_timestamps == 'True'")
            return True
        return value

    model_config = {
        "env_prefix": "WHISPER_KWARGS_",
        "case_sensitive": False,
    }
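
    # Illustrative only (the values are examples, not defaults of this handler):
    # with the "WHISPER_KWARGS_" prefix above, options are supplied to the
    # endpoint as environment variables, e.g.
    #   WHISPER_KWARGS_LANGUAGE=de
    #   WHISPER_KWARGS_NUM_BEAMS=5
    #   WHISPER_KWARGS_RETURN_TIMESTAMPS=word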

    def to_kwargs(self):
        """Convert object attributes to kwargs dict, excluding None values."""
        return {
            key: value
            for key, value in self.model_dump().items()
            if value is not None
        }
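

# Minimal local smoke test, a sketch only: on Hugging Face Inference Endpoints the
# toolkit instantiates EndpointHandler and calls it with the deserialized request,
# so this block is purely for manual experimentation. The file name "sample.flac"
# and the payload shape are assumptions, not part of the deployed handler.
if __name__ == "__main__":
    with open("sample.flac", "rb") as f:
        audio_bytes = f.read()

    handler = EndpointHandler()
    result = handler({"inputs": audio_bytes})
    print(result["text"])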