Phoenixak99 commited on
Commit
3570981
·
verified ·
1 Parent(s): 9d2438d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +41 -68
handler.py CHANGED
@@ -1,78 +1,51 @@
1
- # app.py
2
- from fastapi import FastAPI, Request
3
- from handler import EndpointHandler
4
- import json
5
-
6
- app = FastAPI()
7
- handler = None
8
-
9
- @app.on_event("startup")
10
- async def startup_event():
11
- global handler
12
- handler = EndpointHandler()
13
-
14
- @app.post("/")
15
- async def process_request(request: Request):
16
- body = await request.json()
17
- response = handler(body)
18
- return response
19
-
20
- # handler.py
21
  from typing import Dict, Any
22
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
23
  import torch
24
 
25
  class EndpointHandler:
26
- def __init__(self, path="/repository"):
27
- """Initialize the model and processor."""
28
  self.processor = AutoProcessor.from_pretrained(path)
29
  self.model = MusicgenForConditionalGeneration.from_pretrained(
30
- path,
31
- torch_dtype=torch.float16,
32
- device_map="auto"
33
  ).to("cuda")
34
-
 
35
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
36
- """Process the input data and generate audio."""
37
- try:
38
- # Extract inputs and parameters
39
- inputs = data.pop("inputs", data)
40
- parameters = data.pop("parameters", {})
41
-
42
- # Get prompt and duration
43
- prompt = inputs.get("prompt", "")
44
- duration = inputs.get("duration", 30)
45
-
46
- # Calculate max_new_tokens based on duration
47
- samples_per_token = 1024
48
- sampling_rate = 32000
49
- max_new_tokens = int((duration * sampling_rate) / samples_per_token)
50
-
51
- # Process input text
52
- model_inputs = self.processor(
53
- text=[prompt],
54
- padding=True,
55
- return_tensors="pt"
56
- ).to("cuda")
57
-
58
- # Set default generation parameters
59
- generation_params = {
60
- "do_sample": True,
61
- "guidance_scale": 3,
62
- "max_new_tokens": max_new_tokens
 
 
 
 
 
 
 
 
63
  }
64
-
65
- # Update with any user-provided parameters
66
- generation_params.update(parameters)
67
-
68
- # Generate audio with autocast for memory efficiency
69
- with torch.cuda.amp.autocast():
70
- audio_values = self.model.generate(**model_inputs, **generation_params)
71
-
72
- # Convert to list for JSON serialization
73
- audio_data = audio_values.cpu().numpy().tolist()
74
-
75
- return [{"generated_audio": audio_data}]
76
-
77
- except Exception as e:
78
- return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, Any
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
5
  class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # Load the processor and model from the specified path
8
  self.processor = AutoProcessor.from_pretrained(path)
9
  self.model = MusicgenForConditionalGeneration.from_pretrained(
10
+ path, torch_dtype=torch.float16
 
 
11
  ).to("cuda")
12
+ self.sampling_rate = self.model.config.audio_encoder.sampling_rate
13
+
14
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
15
+ """
16
+ Args:
17
+ data (dict): The payload with the text prompt and generation parameters.
18
+ """
19
+ # Extract inputs and parameters from the payload
20
+ inputs = data.get("inputs", {})
21
+ prompt = inputs.get("prompt", "")
22
+ duration = inputs.get("duration", 10)
23
+ parameters = data.get("parameters", {})
24
+
25
+ # Preprocess the prompt
26
+ input_ids = self.processor(
27
+ text=[prompt],
28
+ padding=True,
29
+ return_tensors="pt",
30
+ ).to("cuda")
31
+
32
+ # Set generation parameters
33
+ gen_kwargs = {
34
+ "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
35
+ **parameters,
36
+ }
37
+
38
+ # Generate audio
39
+ with torch.autocast("cuda"):
40
+ outputs = self.model.generate(**input_ids, **gen_kwargs)
41
+
42
+ # Convert the output audio tensor to a list of lists (channel-wise)
43
+ audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
44
+ audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
45
+
46
+ return [
47
+ {
48
+ "generated_audio": audio_list,
49
+ "sample_rate": self.sampling_rate,
50
  }
51
+ ]