from typing import Dict, List, Any
import json
import numpy as np
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch


class EndpointHandler:
    """Inference-endpoint handler that generates audio with a MusicGen model.

    The processor and ``MusicgenForConditionalGeneration`` model are loaded
    once from ``path``. Each call turns a text prompt, optionally with a
    conditioning audio clip, into generated audio.
    """

    # Fallback audio-token rate (tokens per second of audio). Used only when the
    # loaded model config does not expose ``audio_encoder.frame_rate``.
    _DEFAULT_FRAME_RATE = 50
    # Generation length when neither ``max_new_tokens`` nor ``duration`` is given.
    _DEFAULT_MAX_NEW_TOKENS = 256

    def __init__(self, path=""):
        """Load processor and model from ``path`` and place the model on a device.

        Args:
            path: Directory or hub id passed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        # Prefer the GPU when one is visible; otherwise run on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request.

        Args:
            data: The request payload. ``data["inputs"]`` is the text prompt
                (the whole payload is used if the key is absent).
                ``data["parameters"]`` is optional. It may hold these
                handler-specific keys:

                * ``duration``: seconds of audio to generate, converted to
                  tokens.
                * ``audio``: conditioning audio, given as a JSON-encoded list
                  or as an already-decoded list/array.
                * ``sampling_rate``: sampling rate of ``audio``.
                * ``max_new_tokens``: explicit token budget. It takes
                  precedence over ``duration``.

                Any remaining keys are forwarded to ``model.generate``.

        Returns:
            A one-element list ``[{"generated_text": ndarray}]`` holding the
            first generated sequence, moved to the CPU.
        """
        inputs = data.pop("inputs", data)
        # Copy so that popping handler-only keys never mutates the caller's
        # dict. A missing or None "parameters" entry means "no extra kwargs".
        parameters = dict(data.pop("parameters", None) or {})

        duration = parameters.pop("duration", None)
        audio = self._decode_audio(parameters.pop("audio", None))
        sampling_rate = parameters.pop("sampling_rate", None)
        # Pop max_new_tokens too, so it is never passed to generate() twice.
        max_new_tokens = self._resolve_max_new_tokens(
            duration, parameters.pop("max_new_tokens", None)
        )

        model_inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
            audio=audio,
            sampling_rate=sampling_rate,
        ).to(self.device)

        # Every handler-specific key has been removed above; what remains are
        # plain generation kwargs (possibly none).
        outputs = self.model.generate(
            **model_inputs, max_new_tokens=max_new_tokens, **parameters
        )

        prediction = outputs[0].cpu().numpy()
        return [{"generated_text": prediction}]

    @staticmethod
    def _decode_audio(audio):
        """Convert the optional ``audio`` parameter to a NumPy array.

        A str/bytes value is parsed as JSON. A list or array is used as is.
        ``None`` is passed through unchanged.
        """
        if audio is None:
            return None
        if isinstance(audio, (str, bytes, bytearray)):
            audio = json.loads(audio)
        return np.array(audio)

    def _resolve_max_new_tokens(self, duration, max_new_tokens):
        """Pick the generation length in tokens.

        Precedence is: an explicit ``max_new_tokens``, then ``duration``
        seconds times the audio encoder's frame rate, then
        ``_DEFAULT_MAX_NEW_TOKENS``. Both values may arrive as strings or
        numbers from the JSON payload, so they are coerced numerically. This
        avoids string repetition such as ``"10" * 50``.
        """
        if max_new_tokens is not None:
            return int(max_new_tokens)
        if duration is None:
            return self._DEFAULT_MAX_NEW_TOKENS
        audio_encoder_config = getattr(self.model.config, "audio_encoder", None)
        frame_rate = (
            getattr(audio_encoder_config, "frame_rate", None)
            or self._DEFAULT_FRAME_RATE
        )
        return int(float(duration) * frame_rate)