# Phoenixak99's picture
# Update handler.py
# 7288895 verified
import logging
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
# Module logger, plus a root configuration that lets INFO-and-above records
# through so the request/parameter logs emitted by the handler are visible.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
class EndpointHandler:
    """Inference Endpoint handler that turns a text prompt into MusicGen audio.

    The processor and a float16 copy of the model are loaded once in
    ``__init__``; each call runs a single generation request.
    """

    # Decoder steps generated per second of audio (MusicGen uses 50 tokens per
    # second); converts the requested duration into ``max_new_tokens``.
    TOKENS_PER_SECOND = 50

    # Clip length in seconds used when the request does not specify one.
    DEFAULT_DURATION = 10

    # Keys of ``parameters`` forwarded to ``model.generate``; anything else is
    # ignored. ``attention_mask`` is intentionally excluded: the processor
    # output is unpacked into ``generate`` and already carries one, so a
    # client-supplied mask would pass the keyword twice.
    SUPPORTED_PARAMS = (
        "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
        "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
        "num_return_sequences",
    )

    def __init__(self, path: str = ""):
        """Load the processor and the float16 model from *path* onto the GPU.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        self.device = "cuda"
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to(self.device)
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate audio for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is either the prompt
                string or a dict with ``"text"``/``"prompt"`` and an optional
                ``"duration"``; when ``"inputs"`` is absent, ``data`` itself is
                read as that dict. ``data["parameters"]`` may contain
                ``"duration"`` (seconds; overrides the value in ``inputs``) and
                any key listed in ``SUPPORTED_PARAMS``. The payload is never
                modified.

        Returns:
            On success, a list with one dict per generated sequence, each
            holding ``"generated_audio"`` (one list of samples per channel) and
            ``"sample_rate"``. On failure, ``{"error": <message>}``.
        """
        try:
            prompt, duration, parameters = self._parse_request(data)
            if not prompt:
                return {"error": "No prompt provided."}
            try:
                max_new_tokens = self._max_new_tokens(duration)
            except ValueError as err:
                # Bad client input rather than a server fault: no traceback.
                return {"error": str(err)}

            gen_kwargs = {"max_new_tokens": max_new_tokens}
            gen_kwargs.update(
                {key: parameters[key] for key in self.SUPPORTED_PARAMS if key in parameters}
            )

            logger.info("Received prompt: %s", prompt)
            logger.info("Generation parameters: %s", gen_kwargs)

            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to(self.device)
            with torch.autocast(self.device):
                outputs = self.model.generate(**model_inputs, **gen_kwargs)

            # Each element of ``outputs`` is one clip shaped
            # [num_channels, seq_len]. Iterating returns every sequence when
            # ``num_return_sequences`` > 1 instead of silently keeping only the
            # first (presumably the leading dim is prompts x num_return_sequences
            # -- confirm against the transformers version in use).
            return [
                {
                    "generated_audio": clip.cpu().numpy().tolist(),
                    "sample_rate": self.sampling_rate,
                }
                for clip in outputs
            ]
        except Exception as e:
            # Top-level endpoint boundary: keep the traceback in the logs and
            # hand the client an error payload instead of crashing the worker.
            logger.exception("Exception during generation: %s", e)
            return {"error": str(e)}

    @classmethod
    def _parse_request(cls, data: Dict[str, Any]):
        """Split a payload into ``(prompt, raw_duration, parameters)``.

        The duration is returned unvalidated; ``parameters["duration"]`` takes
        precedence over ``inputs["duration"]``, which beats the default.
        """
        inputs = data.get("inputs", data)
        # Treat an explicit ``"parameters": null`` the same as a missing key.
        parameters = data.get("parameters") or {}

        if isinstance(inputs, str):
            prompt, duration = inputs, cls.DEFAULT_DURATION
        elif isinstance(inputs, dict):
            prompt = inputs.get("text") or inputs.get("prompt")
            duration = inputs.get("duration", cls.DEFAULT_DURATION)
        else:
            prompt, duration = None, cls.DEFAULT_DURATION

        # Read (not pop) so the caller's payload is left untouched.
        duration = parameters.get("duration", duration)
        return prompt, duration, parameters

    @classmethod
    def _max_new_tokens(cls, duration: Any) -> int:
        """Convert *duration* in seconds into a ``max_new_tokens`` budget.

        Raises:
            ValueError: if *duration* is not a finite, positive number.
        """
        message = f"Invalid duration {duration!r}: expected a positive number of seconds."
        try:
            # float() parses numeric strings coming from JSON ("10"); applying
            # ``* 50`` to the raw string would repeat it instead of scaling it.
            seconds = float(duration)
        except (TypeError, ValueError):
            raise ValueError(message) from None
        # The chained comparison also rejects NaN and infinity.
        if not (0 < seconds < float("inf")):
            raise ValueError(message)
        # Very short positive durations still request at least one token.
        return max(1, int(seconds * cls.TOKENS_PER_SECOND))