Phoenixak99
committed on
Update handler.py
Browse files - handler.py +44 -88
handler.py
CHANGED
@@ -2,104 +2,62 @@ import logging
|
|
2 |
from typing import Dict, Any
|
3 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
4 |
import torch
|
5 |
-
import gc
|
6 |
|
|
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
-
#
|
13 |
-
|
14 |
-
torch.backends.cudnn.benchmark = True
|
15 |
-
|
16 |
-
# Load processor with optimizations
|
17 |
-
logger.info("Loading processor...")
|
18 |
-
self.processor = AutoProcessor.from_pretrained(
|
19 |
-
path,
|
20 |
-
use_fast=True # Use faster tokenizer
|
21 |
-
)
|
22 |
-
|
23 |
-
logger.info("Loading model...")
|
24 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
25 |
-
path,
|
26 |
-
torch_dtype=torch.float16,
|
27 |
-
low_cpu_mem_usage=True
|
28 |
).to("cuda")
|
29 |
-
|
30 |
-
# Set model to eval mode
|
31 |
-
self.model.eval()
|
32 |
-
|
33 |
-
# Cache sampling rate
|
34 |
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
35 |
-
|
36 |
-
# Clear CUDA cache
|
37 |
-
torch.cuda.empty_cache()
|
38 |
-
gc.collect()
|
39 |
-
|
40 |
-
# Quick warmup
|
41 |
-
logger.info("Warming up model...")
|
42 |
-
self._warmup()
|
43 |
-
|
44 |
-
def _warmup(self):
|
45 |
-
"""Perform a minimal forward pass to warm up the model"""
|
46 |
-
try:
|
47 |
-
with torch.no_grad():
|
48 |
-
dummy_input = self.processor(
|
49 |
-
text=["test"],
|
50 |
-
padding=True,
|
51 |
-
return_tensors="pt"
|
52 |
-
).to("cuda")
|
53 |
-
|
54 |
-
# Minimal generation
|
55 |
-
self.model.generate(
|
56 |
-
**dummy_input,
|
57 |
-
max_new_tokens=10,
|
58 |
-
do_sample=False
|
59 |
-
)
|
60 |
-
except Exception as e:
|
61 |
-
logger.warning(f"Warmup failed (non-critical): {e}")
|
62 |
|
63 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
|
|
|
|
|
|
|
|
64 |
try:
|
65 |
-
# Extract inputs and parameters
|
66 |
inputs = data.get("inputs", data)
|
67 |
parameters = data.get("parameters", {})
|
68 |
-
|
69 |
-
#
|
70 |
-
if isinstance(inputs,
|
|
|
|
|
|
|
71 |
prompt = inputs.get("text") or inputs.get("prompt")
|
72 |
duration = inputs.get("duration", 10)
|
73 |
else:
|
74 |
-
prompt =
|
75 |
duration = 10
|
76 |
-
|
|
|
77 |
if 'duration' in parameters:
|
78 |
duration = parameters.pop('duration')
|
79 |
-
|
|
|
80 |
if not prompt:
|
81 |
return {"error": "No prompt provided."}
|
82 |
-
|
83 |
-
# Preprocess
|
84 |
input_ids = self.processor(
|
85 |
text=[prompt],
|
86 |
padding=True,
|
87 |
return_tensors="pt",
|
88 |
-
truncation=True,
|
89 |
-
max_length=512 # Limit input length
|
90 |
).to("cuda")
|
91 |
-
|
92 |
-
#
|
93 |
gen_kwargs = {
|
94 |
-
"max_new_tokens": int(duration * 50),
|
95 |
-
"use_cache": True, # Enable KV-cache
|
96 |
-
"do_sample": True,
|
97 |
-
"temperature": 0.8,
|
98 |
-
"top_k": 50,
|
99 |
-
"top_p": 0.95
|
100 |
}
|
101 |
-
|
102 |
-
#
|
103 |
supported_params = [
|
104 |
"max_length", "min_length", "do_sample", "early_stopping", "num_beams",
|
105 |
"temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
|
@@ -108,26 +66,24 @@ class EndpointHandler:
|
|
108 |
for param in supported_params:
|
109 |
if param in parameters:
|
110 |
gen_kwargs[param] = parameters[param]
|
111 |
-
|
112 |
-
logger.info(f"
|
113 |
logger.info(f"Generation parameters: {gen_kwargs}")
|
114 |
-
|
115 |
-
# Generate
|
116 |
-
with torch.
|
117 |
outputs = self.model.generate(**input_ids, **gen_kwargs)
|
118 |
-
|
119 |
-
# Convert output
|
120 |
-
audio_tensor = outputs[0].cpu()
|
121 |
-
audio_list = audio_tensor.numpy().tolist()
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
}]
|
130 |
-
|
131 |
except Exception as e:
|
132 |
-
logger.error(f"
|
133 |
return {"error": str(e)}
|
|
|
2 |
from typing import Dict, Any
|
3 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
4 |
import torch
|
|
|
5 |
|
6 |
+
# Set up logging
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
+
# Load the processor and model from the specified path
|
13 |
+
self.processor = AutoProcessor.from_pretrained(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
15 |
+
path, torch_dtype=torch.float16
|
|
|
|
|
16 |
).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
17 |
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
data (dict): The payload with the text prompt and generation parameters.
|
23 |
+
"""
|
24 |
try:
|
25 |
+
# Extract inputs and parameters from the payload
|
26 |
inputs = data.get("inputs", data)
|
27 |
parameters = data.get("parameters", {})
|
28 |
+
|
29 |
+
# Handle inputs
|
30 |
+
if isinstance(inputs, str):
|
31 |
+
prompt = inputs
|
32 |
+
duration = 10 # Default duration
|
33 |
+
elif isinstance(inputs, dict):
|
34 |
prompt = inputs.get("text") or inputs.get("prompt")
|
35 |
duration = inputs.get("duration", 10)
|
36 |
else:
|
37 |
+
prompt = None
|
38 |
duration = 10
|
39 |
+
|
40 |
+
# Override duration if provided in parameters
|
41 |
if 'duration' in parameters:
|
42 |
duration = parameters.pop('duration')
|
43 |
+
|
44 |
+
# Validate the prompt
|
45 |
if not prompt:
|
46 |
return {"error": "No prompt provided."}
|
47 |
+
|
48 |
+
# Preprocess the prompt
|
49 |
input_ids = self.processor(
|
50 |
text=[prompt],
|
51 |
padding=True,
|
52 |
return_tensors="pt",
|
|
|
|
|
53 |
).to("cuda")
|
54 |
+
|
55 |
+
# Set generation parameters
|
56 |
gen_kwargs = {
|
57 |
+
"max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
|
|
|
|
|
|
|
|
|
|
|
58 |
}
|
59 |
+
|
60 |
+
# Filter out unsupported parameters
|
61 |
supported_params = [
|
62 |
"max_length", "min_length", "do_sample", "early_stopping", "num_beams",
|
63 |
"temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
|
|
|
66 |
for param in supported_params:
|
67 |
if param in parameters:
|
68 |
gen_kwargs[param] = parameters[param]
|
69 |
+
|
70 |
+
logger.info(f"Received prompt: {prompt}")
|
71 |
logger.info(f"Generation parameters: {gen_kwargs}")
|
72 |
+
|
73 |
+
# Generate audio
|
74 |
+
with torch.autocast("cuda"):
|
75 |
outputs = self.model.generate(**input_ids, **gen_kwargs)
|
76 |
+
|
77 |
+
# Convert the output audio tensor to a list of lists (channel-wise)
|
78 |
+
audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
|
79 |
+
audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
|
80 |
+
|
81 |
+
return [
|
82 |
+
{
|
83 |
+
"generated_audio": audio_list,
|
84 |
+
"sample_rate": self.sampling_rate,
|
85 |
+
}
|
86 |
+
]
|
|
|
|
|
87 |
except Exception as e:
|
88 |
+
logger.error(f"Exception during generation: {e}")
|
89 |
return {"error": str(e)}
|