|
import logging |
|
from typing import Any, Dict, List, Optional |
|
|
|
import transformers |
|
|
|
|
|
|
|
from .ultravox_model import UltravoxModel |
|
from .ultravox_processing import UltravoxProcessor |
|
|
|
|
|
class UltravoxPipeline(transformers.Pipeline):
    """Hugging Face ``Pipeline`` that runs an :class:`UltravoxModel` on audio + text.

    Each call takes a dict containing ``"audio"`` (plus optional ``"prompt"`` or
    ``"turns"`` and ``"sampling_rate"``). The pipeline renders the conversation
    through the tokenizer's chat template, feeds text and audio through
    ``UltravoxProcessor``, calls ``model.generate`` and returns the decoded
    completion text.
    """

    # Call-time keyword arguments that are forwarded to ``_forward`` (generation options).
    _GENERATION_KEYS = ("temperature", "max_new_tokens", "repetition_penalty")

    def __init__(
        self,
        model: UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs,
    ):
        """Build the pipeline, loading any missing tokenizer / audio processor.

        Args:
            model: The Ultravox model used for generation.
            tokenizer: Text tokenizer. When omitted, it is loaded with
                ``AutoTokenizer.from_pretrained(model.config._name_or_path)``.
            audio_processor: Audio feature processor. When omitted, it is loaded
                from ``model.config.audio_model_id``, falling back to
                ``model.config.audio_config._name_or_path``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
        """
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )

        if audio_processor is None:
            audio_processor = transformers.AutoProcessor.from_pretrained(
                model.config.audio_model_id or model.config.audio_config._name_or_path
            )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

        # Assigned *after* the base initializer on purpose: newer transformers
        # releases have ``Pipeline.__init__`` set ``self.processor`` itself
        # (defaulting to None). Setting it beforehand would be silently overwritten.
        self.processor = UltravoxProcessor(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=model.config.stack_factor,
        )

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) parameter dicts.

        Only the generation options listed in ``_GENERATION_KEYS`` are recognised.
        They are routed to ``_forward``; any other keyword is ignored.
        """
        generation_kwargs = {
            key: kwargs[key] for key in self._GENERATION_KEYS if key in kwargs
        }
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Convert a raw input dict into model inputs via the Ultravox processor.

        Args:
            inputs: Must contain ``"audio"``. If ``"turns"`` is present, it is used
                verbatim as the chat message list. Otherwise a single user turn is
                built from ``"prompt"`` (default ``"<|audio|>"``); ``" <|audio|>"``
                is appended when the prompt lacks the placeholder. ``"sampling_rate"``
                defaults to 16000 when absent.

        Returns:
            The output of ``UltravoxProcessor`` for the rendered text and the audio.

        Raises:
            ValueError: If ``"audio"`` is missing from ``inputs``.
        """
        # Validate before doing any tokenizer work. Use a real exception rather
        # than ``assert``, which is stripped when Python runs with ``-O``.
        if "audio" not in inputs:
            raise ValueError("Audio input is required")

        if "turns" in inputs:
            turns = inputs["turns"]
        else:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]

        # ``add_generation_prompt=True`` makes the chat template append the tokens
        # that open an assistant message. This way the model replies instead of
        # continuing the user's turn.
        text = self.processor.tokenizer.apply_chat_template(
            turns, add_generation_prompt=True, tokenize=False
        )

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        return self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run ``model.generate`` and return only the newly generated token ids.

        Args:
            model_inputs: Output of ``preprocess``. It must contain ``"input_ids"``
                with a ``.shape[1]`` giving the prompt length (presumably a
                ``(batch, seq)`` tensor; confirm against UltravoxProcessor).
            temperature: Sampling temperature. ``None`` or ``0`` selects greedy
                decoding (``do_sample=False``).
            max_new_tokens: Generation length limit. ``None`` defers to the
                model's generation config.
            repetition_penalty: Passed straight to ``generate``.

        Returns:
            The first generated sequence with the prompt tokens sliced off.
        """
        # Treat 0 the same as "unset": both mean deterministic (greedy) decoding.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on the tokenizer's EOS token and also on ``<|eot_id|>`` when the
        # vocabulary defines it (an end-of-turn token used by e.g. Llama-3 templates).
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators,
        )
        # The generated sequence begins with the prompt; keep only the continuation.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids into text, skipping special tokens."""
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)
|
|
|
|
|
# Register UltravoxPipeline under the "ultravox-pipeline" task name in the
# transformers pipeline registry, with AutoModel as its PyTorch model class.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    type="multimodal",
    pt_model=transformers.AutoModel,
    pipeline_class=UltravoxPipeline,
)
|
|