File size: 3,298 Bytes
30b75e1
48e4064
 
 
 
7288895
30b75e1
 
 
48e4064
3570981
7288895
 
48e4064
7288895
48e4064
3570981
 
30b75e1
7288895
 
 
 
30b75e1
7288895
30b75e1
 
7288895
 
 
 
 
 
30b75e1
 
 
7288895
30b75e1
7288895
 
30b75e1
 
7288895
 
30b75e1
 
7288895
 
30b75e1
 
 
 
 
7288895
 
30b75e1
7288895
06c68e1
7288895
 
30b75e1
 
 
 
 
 
 
 
7288895
 
30b75e1
7288895
 
 
30b75e1
7288895
 
 
 
 
 
 
 
 
 
 
30b75e1
7288895
dc480b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import logging
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

# Set up logging.
# NOTE: basicConfig is a no-op if the hosting runtime has already configured
# the root logger; level=INFO makes the prompt/parameter logs visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference-endpoint handler wrapping a MusicGen text-to-music model.

    Accepts a JSON-style payload with a text prompt (and optional generation
    parameters) and returns raw audio samples plus the model's sampling rate.
    """

    # MusicGen's audio codec emits ~50 tokens per second of generated audio.
    TOKENS_PER_SECOND = 50
    # Default clip length (seconds) when the caller does not specify one.
    DEFAULT_DURATION_S = 10

    def __init__(self, path: str = ""):
        """Load the processor and model from *path*.

        Args:
            path: Local directory or hub id of the pretrained MusicGen model.
        """
        # Fall back to CPU so the handler can still be constructed on a
        # GPU-less host instead of crashing in .to("cuda").
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to(self.device)
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

    def _parse_inputs(self, inputs: Any):
        """Extract ``(prompt, duration)`` from the raw ``inputs`` field.

        Args:
            inputs: A bare prompt string, a dict with ``text``/``prompt`` and
                optional ``duration``, or anything else (treated as missing).

        Returns:
            Tuple of (prompt or ``None``, duration in seconds).
        """
        if isinstance(inputs, str):
            return inputs, self.DEFAULT_DURATION_S
        if isinstance(inputs, dict):
            prompt = inputs.get("text") or inputs.get("prompt")
            return prompt, inputs.get("duration", self.DEFAULT_DURATION_S)
        return None, self.DEFAULT_DURATION_S

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate audio for a text prompt.

        Args:
            data (dict): The payload with the text prompt and generation
                parameters, e.g. ``{"inputs": "...", "parameters": {...}}``.

        Returns:
            On success, a one-element list holding ``generated_audio``
            (nested lists of float samples, one list per channel) and
            ``sample_rate``. On failure, ``{"error": <message>}``.
        """
        try:
            inputs = data.get("inputs", data)
            # Copy so the caller's payload is never mutated (the original
            # code popped "duration" out of the caller's dict).
            parameters = dict(data.get("parameters", {}))

            prompt, duration = self._parse_inputs(inputs)

            # A duration in "parameters" overrides the per-input duration.
            if "duration" in parameters:
                duration = parameters.pop("duration")

            # Validate the prompt before doing any model work.
            if not prompt:
                return {"error": "No prompt provided."}

            # Tokenize the prompt and move tensors to the model's device.
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            gen_kwargs = {
                # Token budget derived from the requested clip length.
                "max_new_tokens": int(duration * self.TOKENS_PER_SECOND),
            }

            # Forward only supported generation parameters. Skip any key the
            # processor already produced (e.g. attention_mask); passing it in
            # both **model_inputs and **gen_kwargs would raise a duplicate
            # keyword-argument TypeError in generate().
            supported_params = [
                "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
                "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
                "num_return_sequences", "attention_mask"
            ]
            for param in supported_params:
                if param in parameters and param not in model_inputs:
                    gen_kwargs[param] = parameters[param]

            # Lazy %-formatting: cheap when INFO logging is disabled.
            logger.info("Received prompt: %s", prompt)
            logger.info("Generation parameters: %s", gen_kwargs)

            # Generate audio under autocast for the float16 model.
            with torch.autocast(self.device):
                outputs = self.model.generate(**model_inputs, **gen_kwargs)

            # outputs[0] has shape [num_channels, seq_len]; convert to nested
            # Python lists (one inner list per channel) for JSON transport.
            audio_list = outputs[0].cpu().numpy().tolist()

            return [
                {
                    "generated_audio": audio_list,
                    "sample_rate": self.sampling_rate,
                }
            ]
        except Exception as e:
            # Endpoint boundary: report the failure instead of crashing the
            # service; logger.exception records the full traceback.
            logger.exception("Exception during generation")
            return {"error": str(e)}