Text-to-Audio
Inference Endpoints
hungchiayu committed (verified)
Commit 83283ea · 1 Parent(s): 609bc9a

Create handler.py

Files changed (1): handler.py (+44, -0)
handler.py ADDED
@@ -0,0 +1,44 @@
+ from typing import Dict, List, Any
+ from tangoflux import TangoFluxInference
+ import torchaudio
+ from huggingface_inference_toolkit.logging import logger
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Preload the model once at startup so each request only runs generation.
+         self.model = TangoFluxInference(name="declare-lab/TangoFlux", device="cuda")
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`str`): text prompt describing the audio to generate
+             parameters (:obj:`dict`, optional): generation parameters
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
+         """
+         logger.info(f"Received incoming request with {data=}")
+
+         # Accept the prompt under either the `inputs` or the `prompt` key.
+         if "inputs" in data and isinstance(data["inputs"], str):
+             prompt = data.pop("inputs")
+         elif "prompt" in data and isinstance(data["prompt"], str):
+             prompt = data.pop("prompt")
+         else:
+             raise ValueError(
+                 "Provided input body must contain either the key `inputs` or `prompt` with the"
+                 " prompt to use for the audio generation, and it needs to be a non-empty string."
+             )
+
+         # Optional generation parameters with defaults.
+         parameters = data.pop("parameters", {})
+         num_inference_steps = parameters.get("num_inference_steps", 50)
+         duration = parameters.get("duration", 10)
+         guidance_scale = parameters.get("guidance_scale", 3.5)
+
+         # Generate the audio; the result is serialized and returned by the toolkit.
+         return self.model.generate(
+             prompt,
+             steps=num_inference_steps,
+             duration=duration,
+             guidance_scale=guidance_scale,
+         )
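
A minimal sketch (not part of this commit) of how a deployed Inference Endpoint backed by this handler could be queried: the endpoint URL and token below are placeholders, and the payload keys mirror what __call__ above expects (`inputs` plus an optional `parameters` dict).

import requests

# Placeholder values -- substitute your own endpoint URL and access token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

payload = {
    "inputs": "Hammer slowly hitting the wooden table",
    "parameters": {
        "num_inference_steps": 50,
        "duration": 10,
        "guidance_scale": 3.5,
    },
}

response = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json=payload,
)
response.raise_for_status()
# The response body holds the serialized generation result produced by the handler.
print(response.status_code, len(response.content))

For local testing, the handler can also be exercised directly by instantiating EndpointHandler() on a CUDA machine and calling it with {"inputs": "...", "parameters": {...}}.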