Phoenixak99
commited on
Update handler.py
Browse files- handler.py +88 -44
handler.py
CHANGED
@@ -2,62 +2,104 @@ import logging
|
|
2 |
from typing import Dict, Any
|
3 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
4 |
import torch
|
|
|
5 |
|
6 |
-
# Set up logging
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
-
#
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
15 |
-
path,
|
|
|
|
|
16 |
).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
17 |
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
20 |
-
"""
|
21 |
-
Args:
|
22 |
-
data (dict): The payload with the text prompt and generation parameters.
|
23 |
-
"""
|
24 |
try:
|
25 |
-
# Extract inputs and parameters
|
26 |
inputs = data.get("inputs", data)
|
27 |
parameters = data.get("parameters", {})
|
28 |
|
29 |
-
#
|
30 |
-
if isinstance(inputs,
|
31 |
-
prompt = inputs
|
32 |
-
duration = 10 # Default duration
|
33 |
-
elif isinstance(inputs, dict):
|
34 |
prompt = inputs.get("text") or inputs.get("prompt")
|
35 |
duration = inputs.get("duration", 10)
|
36 |
else:
|
37 |
-
prompt = None
|
38 |
duration = 10
|
39 |
-
|
40 |
-
# Override duration if provided in parameters
|
41 |
if 'duration' in parameters:
|
42 |
duration = parameters.pop('duration')
|
43 |
-
|
44 |
-
# Validate the prompt
|
45 |
if not prompt:
|
46 |
return {"error": "No prompt provided."}
|
47 |
-
|
48 |
-
# Preprocess
|
49 |
input_ids = self.processor(
|
50 |
text=[prompt],
|
51 |
padding=True,
|
52 |
return_tensors="pt",
|
|
|
|
|
53 |
).to("cuda")
|
54 |
-
|
55 |
-
#
|
56 |
gen_kwargs = {
|
57 |
-
"max_new_tokens": int(duration * 50),
|
|
|
|
|
|
|
|
|
|
|
58 |
}
|
59 |
-
|
60 |
-
#
|
61 |
supported_params = [
|
62 |
"max_length", "min_length", "do_sample", "early_stopping", "num_beams",
|
63 |
"temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
|
@@ -66,24 +108,26 @@ class EndpointHandler:
|
|
66 |
for param in supported_params:
|
67 |
if param in parameters:
|
68 |
gen_kwargs[param] = parameters[param]
|
69 |
-
|
70 |
-
logger.info(f"
|
71 |
logger.info(f"Generation parameters: {gen_kwargs}")
|
72 |
-
|
73 |
-
# Generate
|
74 |
-
with torch.autocast("cuda"):
|
75 |
outputs = self.model.generate(**input_ids, **gen_kwargs)
|
76 |
-
|
77 |
-
# Convert
|
78 |
-
audio_tensor = outputs[0].cpu()
|
79 |
-
audio_list = audio_tensor.numpy().tolist()
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
87 |
except Exception as e:
|
88 |
-
logger.error(f"
|
89 |
-
return {"error": str(e)}
|
|
|
2 |
from typing import Dict, Any
|
3 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
4 |
import torch
|
5 |
+
import gc
|
6 |
|
|
|
7 |
logging.basicConfig(level=logging.INFO)
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
+
# Enable CUDA optimization
|
13 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
14 |
+
torch.backends.cudnn.benchmark = True
|
15 |
+
|
16 |
+
# Load processor with optimizations
|
17 |
+
logger.info("Loading processor...")
|
18 |
+
self.processor = AutoProcessor.from_pretrained(
|
19 |
+
path,
|
20 |
+
use_fast=True # Use faster tokenizer
|
21 |
+
)
|
22 |
+
|
23 |
+
logger.info("Loading model...")
|
24 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
25 |
+
path,
|
26 |
+
torch_dtype=torch.float16,
|
27 |
+
low_cpu_mem_usage=True
|
28 |
).to("cuda")
|
29 |
+
|
30 |
+
# Set model to eval mode
|
31 |
+
self.model.eval()
|
32 |
+
|
33 |
+
# Cache sampling rate
|
34 |
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
35 |
+
|
36 |
+
# Clear CUDA cache
|
37 |
+
torch.cuda.empty_cache()
|
38 |
+
gc.collect()
|
39 |
+
|
40 |
+
# Quick warmup
|
41 |
+
logger.info("Warming up model...")
|
42 |
+
self._warmup()
|
43 |
+
|
44 |
+
def _warmup(self):
|
45 |
+
"""Perform a minimal forward pass to warm up the model"""
|
46 |
+
try:
|
47 |
+
with torch.no_grad():
|
48 |
+
dummy_input = self.processor(
|
49 |
+
text=["test"],
|
50 |
+
padding=True,
|
51 |
+
return_tensors="pt"
|
52 |
+
).to("cuda")
|
53 |
+
|
54 |
+
# Minimal generation
|
55 |
+
self.model.generate(
|
56 |
+
**dummy_input,
|
57 |
+
max_new_tokens=10,
|
58 |
+
do_sample=False
|
59 |
+
)
|
60 |
+
except Exception as e:
|
61 |
+
logger.warning(f"Warmup failed (non-critical): {e}")
|
62 |
|
63 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
|
|
|
|
|
|
|
|
64 |
try:
|
65 |
+
# Extract inputs and parameters
|
66 |
inputs = data.get("inputs", data)
|
67 |
parameters = data.get("parameters", {})
|
68 |
|
69 |
+
# Efficient input handling
|
70 |
+
if isinstance(inputs, dict):
|
|
|
|
|
|
|
71 |
prompt = inputs.get("text") or inputs.get("prompt")
|
72 |
duration = inputs.get("duration", 10)
|
73 |
else:
|
74 |
+
prompt = inputs if isinstance(inputs, str) else None
|
75 |
duration = 10
|
76 |
+
|
|
|
77 |
if 'duration' in parameters:
|
78 |
duration = parameters.pop('duration')
|
79 |
+
|
|
|
80 |
if not prompt:
|
81 |
return {"error": "No prompt provided."}
|
82 |
+
|
83 |
+
# Preprocess with optimized settings
|
84 |
input_ids = self.processor(
|
85 |
text=[prompt],
|
86 |
padding=True,
|
87 |
return_tensors="pt",
|
88 |
+
truncation=True,
|
89 |
+
max_length=512 # Limit input length
|
90 |
).to("cuda")
|
91 |
+
|
92 |
+
# Optimized generation settings
|
93 |
gen_kwargs = {
|
94 |
+
"max_new_tokens": int(duration * 50),
|
95 |
+
"use_cache": True, # Enable KV-cache
|
96 |
+
"do_sample": True,
|
97 |
+
"temperature": 0.8,
|
98 |
+
"top_k": 50,
|
99 |
+
"top_p": 0.95
|
100 |
}
|
101 |
+
|
102 |
+
# Add any custom parameters
|
103 |
supported_params = [
|
104 |
"max_length", "min_length", "do_sample", "early_stopping", "num_beams",
|
105 |
"temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
|
|
|
108 |
for param in supported_params:
|
109 |
if param in parameters:
|
110 |
gen_kwargs[param] = parameters[param]
|
111 |
+
|
112 |
+
logger.info(f"Generating with prompt: {prompt}")
|
113 |
logger.info(f"Generation parameters: {gen_kwargs}")
|
114 |
+
|
115 |
+
# Generate with optimized settings
|
116 |
+
with torch.inference_mode(), torch.autocast("cuda"):
|
117 |
outputs = self.model.generate(**input_ids, **gen_kwargs)
|
118 |
+
|
119 |
+
# Convert output
|
120 |
+
audio_tensor = outputs[0].cpu()
|
121 |
+
audio_list = audio_tensor.numpy().tolist()
|
122 |
+
|
123 |
+
# Clear cache
|
124 |
+
torch.cuda.empty_cache()
|
125 |
+
|
126 |
+
return [{
|
127 |
+
"generated_audio": audio_list,
|
128 |
+
"sample_rate": self.sampling_rate,
|
129 |
+
}]
|
130 |
+
|
131 |
except Exception as e:
|
132 |
+
logger.error(f"Generation failed: {e}")
|
133 |
+
return {"error": str(e)}
|