from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    """Custom inference handler that generates audio from text prompts with MusicGen.

    The processor and a float16 copy of the model are loaded once at construction
    time and moved to ``device``. Each call tokenizes the prompt(s), runs
    ``model.generate`` and returns the generated samples as nested Python lists
    together with the model's sampling rate.
    """

    def __init__(self, path: str = "", device: str = "cuda"):
        """Load processor and model.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
            device: Torch device for the model and the tokenized inputs.
                Defaults to ``"cuda"`` (the previously hard-coded device).
                NOTE(review): the model is loaded in float16, which is
                presumably only practical on a GPU — confirm before using "cpu".
        """
        self.device = torch.device(device)
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the prompt(s) in the request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is a prompt string or a
                sequence of prompt strings. The optional ``data["parameters"]``
                dict is forwarded verbatim as keyword arguments to
                ``model.generate``. The payload itself is not modified.

        Returns:
            One ``{"audio": ..., "sr": ...}`` dict per generated sample, in the
            order returned by ``generate``. ``"audio"`` is the sample's tensor
            converted to nested lists, and ``"sr"`` is the model's sampling rate.
            A single prompt string yields a one-element list, as before.

        Raises:
            ValueError: if ``"inputs"`` is missing or contains no prompts.
        """
        if "inputs" not in data:
            raise ValueError("payload must contain an 'inputs' field with the text prompt(s)")
        raw_inputs = data["inputs"]
        # Accept a single prompt or a batch of prompts. Wrapping a list in
        # another list would make the tokenizer misinterpret it.
        prompts = [raw_inputs] if isinstance(raw_inputs, str) else list(raw_inputs)
        if not prompts:
            raise ValueError("'inputs' must contain at least one prompt")
        params = data.get("parameters") or {}

        inputs = self.processor(
            text=prompts,
            padding=True,  # needed so prompts of different lengths can share one batch
            return_tensors="pt",
        ).to(self.device)

        # torch.autocast supersedes the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=self.device.type):
            outputs = self.model.generate(**inputs, **params)

        sampling_rate = self.model.config.sampling_rate
        return [
            {"audio": sample.cpu().numpy().tolist(), "sr": sampling_rate}
            for sample in outputs
        ]