# NOTE: removed extraction residue (file-size line, git-blame hashes, and a
# line-number dump) that was not part of the original Python source.
import logging
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
# Set up logging
# basicConfig is a no-op if the hosting process already configured the root
# logger; the module-level logger is named after this module per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference endpoint handler for MusicGen text-to-music generation.

    Loads a MusicGen checkpoint once at construction time and serves
    generation requests through ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model from ``path``.

        Args:
            path: Local directory or hub id of the MusicGen checkpoint.
        """
        # Fall back to CPU instead of crashing on hosts without a GPU.
        # fp16 is only used on CUDA; fp16 inference on CPU is not supported.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        # Output sample rate is fixed by the checkpoint's audio encoder config.
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate audio for a text prompt.

        Args:
            data (dict): The payload with the text prompt and generation
                parameters. Accepted shapes:
                  * {"inputs": "<prompt>", "parameters": {...}}
                  * {"inputs": {"text"|"prompt": str, "duration": int}, ...}
                  * a payload used directly when no "inputs" key is present.

        Returns:
            A one-element list of
            ``{"generated_audio": [[...], ...], "sample_rate": int}``
            on success, or ``{"error": str}`` on failure.
        """
        try:
            # Extract inputs and parameters from the payload
            inputs = data.get("inputs", data)
            # Copy so popping "duration" below never mutates the caller's dict.
            parameters = dict(data.get("parameters") or {})

            prompt, duration = self._parse_inputs(inputs)

            # Override duration if provided in parameters
            if 'duration' in parameters:
                duration = parameters.pop('duration')

            # Validate the prompt
            if not prompt:
                return {"error": "No prompt provided."}

            # Preprocess the prompt; the processor returns a BatchEncoding
            # that is moved to the model's device as a whole.
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            gen_kwargs = self._build_gen_kwargs(duration, parameters)

            logger.info(f"Received prompt: {prompt}")
            logger.info(f"Generation parameters: {gen_kwargs}")

            # autocast is CUDA-only here; on the CPU fallback run generation
            # directly in full precision.
            if self.device == "cuda":
                with torch.autocast("cuda"):
                    outputs = self.model.generate(**model_inputs, **gen_kwargs)
            else:
                outputs = self.model.generate(**model_inputs, **gen_kwargs)

            # outputs[0] is the single batch item: [num_channels, seq_len].
            # Upcast to float32 before serializing so fp16 round-off doesn't
            # leak into the JSON payload.
            audio_list = outputs[0].cpu().float().numpy().tolist()
            return [
                {
                    "generated_audio": audio_list,
                    "sample_rate": self.sampling_rate,
                }
            ]
        except Exception as e:
            # Boundary handler: log the full traceback, return a JSON-able error.
            logger.exception("Exception during generation")
            return {"error": str(e)}

    @staticmethod
    def _parse_inputs(inputs: Any):
        """Return (prompt, duration_seconds) extracted from the payload.

        Unknown payload shapes yield ``(None, 10)`` so the caller can reject
        them with the standard "No prompt provided." error.
        """
        if isinstance(inputs, str):
            return inputs, 10  # Default duration
        if isinstance(inputs, dict):
            return (
                inputs.get("text") or inputs.get("prompt"),
                inputs.get("duration", 10),
            )
        return None, 10

    @staticmethod
    def _build_gen_kwargs(duration: float, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Map duration plus whitelisted user parameters to generate() kwargs."""
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": int(duration * 50),  # MusicGen uses 50 tokens per second
        }
        # Filter out unsupported parameters
        supported_params = (
            "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
            "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
            "num_return_sequences", "attention_mask",
        )
        for param in supported_params:
            if param in parameters:
                gen_kwargs[param] = parameters[param]
        return gen_kwargs