dthomas84 committed on
Commit
348df64
1 Parent(s): 3cc2d46

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -8
handler.py CHANGED
@@ -2,11 +2,18 @@ from typing import Dict, List, Any
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  # load model and processor from path
8
  self.processor = AutoProcessor.from_pretrained(path)
9
- self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
 
 
 
 
 
 
10
 
11
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
12
  """
@@ -17,22 +24,31 @@ class EndpointHandler:
17
  # process input
18
  inputs = data.pop("inputs", data)
19
  parameters = data.pop("parameters", None)
 
 
 
 
 
 
 
20
 
21
  # preprocess
22
  inputs = self.processor(
23
  text=[inputs],
24
  padding=True,
25
- return_tensors="pt",).to("cuda")
 
 
 
 
26
 
27
  # pass inputs with all kwargs in data
28
  if parameters is not None:
29
- with torch.autocast("cuda"):
30
- outputs = self.model.generate(**inputs, **parameters, do_sample=True, guidance_scale=3)
31
  else:
32
- with torch.autocast("cuda"):
33
- outputs = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=450)
34
 
35
  # postprocess the prediction
36
- prediction = outputs[0].cpu().numpy().tolist()
37
 
38
- return [{"generated_audio": prediction}]
 
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
5
+
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load model and processor from path
9
  self.processor = AutoProcessor.from_pretrained(path)
10
+
11
+ # Check if CUDA is available, and set the device accordingly
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ # Load the model to the device
15
+ self.model = MusicgenForConditionalGeneration.from_pretrained(path)
16
+ self.model.to(self.device) # Correcting this line
17
 
18
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
19
  """
 
24
  # process input
25
  inputs = data.pop("inputs", data)
26
  parameters = data.pop("parameters", None)
27
+ duration = parameters.pop("duration", None)
28
+
29
+ if duration is not None:
30
+ # Calculate max new tokens based on duration, this is a placeholder, replace with actual logic
31
+ max_new_tokens = int(duration * 50)
32
+ else:
33
+ max_new_tokens = 256 # Default value if duration is not provided
34
 
35
  # preprocess
36
  inputs = self.processor(
37
  text=[inputs],
38
  padding=True,
39
+ return_tensors="pt",).to(self.device)
40
+
41
+ # If 'duration' is inside 'parameters', remove it
42
+ if parameters is not None and 'duration' in parameters:
43
+ parameters.pop('duration')
44
 
45
  # pass inputs with all kwargs in data
46
  if parameters is not None:
47
+ outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, **parameters)
 
48
  else:
49
+ outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
 
50
 
51
  # postprocess the prediction
52
+ prediction = outputs[0].cpu().numpy()
53
 
54
+ return [{"generated_text": prediction}]