Phoenixak99 commited on
Commit
dc480b5
·
verified ·
1 Parent(s): 30b75e1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +88 -44
handler.py CHANGED
@@ -2,62 +2,104 @@ import logging
2
  from typing import Dict, Any
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
 
5
 
6
- # Set up logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
- # Load the processor and model from the specified path
13
- self.processor = AutoProcessor.from_pretrained(path)
 
 
 
 
 
 
 
 
 
 
14
  self.model = MusicgenForConditionalGeneration.from_pretrained(
15
- path, torch_dtype=torch.float16
 
 
16
  ).to("cuda")
 
 
 
 
 
17
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def __call__(self, data: Dict[str, Any]) -> Any:
20
- """
21
- Args:
22
- data (dict): The payload with the text prompt and generation parameters.
23
- """
24
  try:
25
- # Extract inputs and parameters from the payload
26
  inputs = data.get("inputs", data)
27
  parameters = data.get("parameters", {})
28
 
29
- # Handle inputs
30
- if isinstance(inputs, str):
31
- prompt = inputs
32
- duration = 10 # Default duration
33
- elif isinstance(inputs, dict):
34
  prompt = inputs.get("text") or inputs.get("prompt")
35
  duration = inputs.get("duration", 10)
36
  else:
37
- prompt = None
38
  duration = 10
39
-
40
- # Override duration if provided in parameters
41
  if 'duration' in parameters:
42
  duration = parameters.pop('duration')
43
-
44
- # Validate the prompt
45
  if not prompt:
46
  return {"error": "No prompt provided."}
47
-
48
- # Preprocess the prompt
49
  input_ids = self.processor(
50
  text=[prompt],
51
  padding=True,
52
  return_tensors="pt",
 
 
53
  ).to("cuda")
54
-
55
- # Set generation parameters
56
  gen_kwargs = {
57
- "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
 
 
 
 
 
58
  }
59
-
60
- # Filter out unsupported parameters
61
  supported_params = [
62
  "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
63
  "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
@@ -66,24 +108,26 @@ class EndpointHandler:
66
  for param in supported_params:
67
  if param in parameters:
68
  gen_kwargs[param] = parameters[param]
69
-
70
- logger.info(f"Received prompt: {prompt}")
71
  logger.info(f"Generation parameters: {gen_kwargs}")
72
-
73
- # Generate audio
74
- with torch.autocast("cuda"):
75
  outputs = self.model.generate(**input_ids, **gen_kwargs)
76
-
77
- # Convert the output audio tensor to a list of lists (channel-wise)
78
- audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
79
- audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
80
-
81
- return [
82
- {
83
- "generated_audio": audio_list,
84
- "sample_rate": self.sampling_rate,
85
- }
86
- ]
 
 
87
  except Exception as e:
88
- logger.error(f"Exception during generation: {e}")
89
- return {"error": str(e)}
 
2
  from typing import Dict, Any
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
+ import gc
6
 
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
+ # Enable CUDA optimization
13
+ torch.backends.cuda.matmul.allow_tf32 = True
14
+ torch.backends.cudnn.benchmark = True
15
+
16
+ # Load processor with optimizations
17
+ logger.info("Loading processor...")
18
+ self.processor = AutoProcessor.from_pretrained(
19
+ path,
20
+ use_fast=True # Use faster tokenizer
21
+ )
22
+
23
+ logger.info("Loading model...")
24
  self.model = MusicgenForConditionalGeneration.from_pretrained(
25
+ path,
26
+ torch_dtype=torch.float16,
27
+ low_cpu_mem_usage=True
28
  ).to("cuda")
29
+
30
+ # Set model to eval mode
31
+ self.model.eval()
32
+
33
+ # Cache sampling rate
34
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
35
+
36
+ # Clear CUDA cache
37
+ torch.cuda.empty_cache()
38
+ gc.collect()
39
+
40
+ # Quick warmup
41
+ logger.info("Warming up model...")
42
+ self._warmup()
43
+
44
+ def _warmup(self):
45
+ """Perform a minimal forward pass to warm up the model"""
46
+ try:
47
+ with torch.no_grad():
48
+ dummy_input = self.processor(
49
+ text=["test"],
50
+ padding=True,
51
+ return_tensors="pt"
52
+ ).to("cuda")
53
+
54
+ # Minimal generation
55
+ self.model.generate(
56
+ **dummy_input,
57
+ max_new_tokens=10,
58
+ do_sample=False
59
+ )
60
+ except Exception as e:
61
+ logger.warning(f"Warmup failed (non-critical): {e}")
62
 
63
  def __call__(self, data: Dict[str, Any]) -> Any:
 
 
 
 
64
  try:
65
+ # Extract inputs and parameters
66
  inputs = data.get("inputs", data)
67
  parameters = data.get("parameters", {})
68
 
69
+ # Efficient input handling
70
+ if isinstance(inputs, dict):
 
 
 
71
  prompt = inputs.get("text") or inputs.get("prompt")
72
  duration = inputs.get("duration", 10)
73
  else:
74
+ prompt = inputs if isinstance(inputs, str) else None
75
  duration = 10
76
+
 
77
  if 'duration' in parameters:
78
  duration = parameters.pop('duration')
79
+
 
80
  if not prompt:
81
  return {"error": "No prompt provided."}
82
+
83
+ # Preprocess with optimized settings
84
  input_ids = self.processor(
85
  text=[prompt],
86
  padding=True,
87
  return_tensors="pt",
88
+ truncation=True,
89
+ max_length=512 # Limit input length
90
  ).to("cuda")
91
+
92
+ # Optimized generation settings
93
  gen_kwargs = {
94
+ "max_new_tokens": int(duration * 50),
95
+ "use_cache": True, # Enable KV-cache
96
+ "do_sample": True,
97
+ "temperature": 0.8,
98
+ "top_k": 50,
99
+ "top_p": 0.95
100
  }
101
+
102
+ # Add any custom parameters
103
  supported_params = [
104
  "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
105
  "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
 
108
  for param in supported_params:
109
  if param in parameters:
110
  gen_kwargs[param] = parameters[param]
111
+
112
+ logger.info(f"Generating with prompt: {prompt}")
113
  logger.info(f"Generation parameters: {gen_kwargs}")
114
+
115
+ # Generate with optimized settings
116
+ with torch.inference_mode(), torch.autocast("cuda"):
117
  outputs = self.model.generate(**input_ids, **gen_kwargs)
118
+
119
+ # Convert output
120
+ audio_tensor = outputs[0].cpu()
121
+ audio_list = audio_tensor.numpy().tolist()
122
+
123
+ # Clear cache
124
+ torch.cuda.empty_cache()
125
+
126
+ return [{
127
+ "generated_audio": audio_list,
128
+ "sample_rate": self.sampling_rate,
129
+ }]
130
+
131
  except Exception as e:
132
+ logger.error(f"Generation failed: {e}")
133
+ return {"error": str(e)}