musicgen-medium / handler.py
dthomas84's picture
Update handler.py
cc7877d verified
raw
history blame
2.58 kB
from typing import Dict, List, Any
import json
import numpy as np
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
class EndpointHandler:
    """Inference Endpoints handler that runs MusicGen on a text prompt (optionally with audio)."""

    # Fallback number of generation steps per second of audio, used to turn a
    # requested ``duration`` into ``max_new_tokens`` when the model config does
    # not expose ``audio_encoder.frame_rate``.
    _DEFAULT_FRAME_RATE = 50

    def __init__(self, path=""):
        """Load the processor and model from ``path``.

        Args:
            path: Directory or hub id passed to ``from_pretrained``.

        The model is moved to CUDA when available, otherwise it stays on CPU.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MusicgenForConditionalGeneration.from_pretrained(path)
        self.model.to(self.device)

    def _frame_rate(self) -> float:
        """Return generation steps per second, preferring the model's audio-encoder config."""
        audio_config = getattr(getattr(self.model, "config", None), "audio_encoder", None)
        return getattr(audio_config, "frame_rate", None) or self._DEFAULT_FRAME_RATE

    @staticmethod
    def _decode_audio(audio: Any) -> np.ndarray:
        """Convert a JSON-encoded string, or an already-decoded (nested) list, to a NumPy array."""
        if isinstance(audio, (str, bytes, bytearray)):
            audio = json.loads(audio)
        return np.array(audio)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt (if the
                key is missing, the payload itself is used as the prompt).
                ``data["parameters"]`` is optional and may contain:

                * ``duration``: requested length in seconds. It is converted to
                  ``max_new_tokens`` using the model frame rate. Numeric strings
                  are accepted.
                * ``audio``: conditioning samples, either as a JSON string or a
                  (nested) list.
                * ``sampling_rate``: sampling rate of ``audio``, forwarded to
                  the processor.
                * any other key: forwarded unchanged to ``model.generate``. An
                  explicit ``max_new_tokens`` here overrides the value derived
                  from ``duration`` or the default of 256.

        Returns:
            A one-element list ``[{"generated_text": array}]``. ``array`` is
            the first batch item of ``generate``'s output as a NumPy array
            (presumably the decoded waveform — confirm against the model's
            generate config).
        """
        inputs = data.pop("inputs", data)
        # Copy so the caller's payload is not mutated. A missing or null
        # "parameters" entry is treated as empty instead of crashing on .pop().
        parameters = dict(data.pop("parameters", None) or {})
        duration = parameters.pop("duration", None)
        audio = parameters.pop("audio", None)
        sampling_rate = parameters.pop("sampling_rate", None)

        if audio is not None:
            audio = self._decode_audio(audio)

        if duration is not None:
            # float() matters: with a string duration, "10" * 50 would be
            # string repetition rather than multiplication.
            max_new_tokens = int(float(duration) * self._frame_rate())
        else:
            max_new_tokens = 256  # default when no duration is requested

        model_inputs = self.processor(
            text=[inputs],
            padding=True,
            return_tensors="pt",
            audio=audio,
            sampling_rate=sampling_rate,
        ).to(self.device)

        # Leftover parameters go to generate(). Merging them after our default
        # lets an explicit max_new_tokens win instead of raising
        # "got multiple values for keyword argument".
        generate_kwargs = {"max_new_tokens": max_new_tokens, **parameters}
        outputs = self.model.generate(**model_inputs, **generate_kwargs)

        prediction = outputs[0].cpu().numpy()
        return [{"generated_text": prediction}]