added create_params func
Browse files- handler.py +27 -0
handler.py
CHANGED
@@ -5,6 +5,33 @@ import torch
|
|
5 |
import io
|
6 |
import base64
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class EndpointHandler:
|
9 |
def __init__(self, path="pbotsaris/musicgen-small"):
|
10 |
self.processor = AutoProcessor.from_pretrained(path)
|
|
|
5 |
import io
|
6 |
import base64
|
7 |
|
8 |
+
def create_params(params, fr):
|
9 |
+
|
10 |
+
# default
|
11 |
+
out = { "do_sample": True,
|
12 |
+
"guidance_scale": 3,
|
13 |
+
"max_new_tokens": 256
|
14 |
+
}
|
15 |
+
|
16 |
+
has_tokens = False
|
17 |
+
|
18 |
+
if params is None:
|
19 |
+
return out
|
20 |
+
|
21 |
+
if 'duration' in params:
|
22 |
+
out['max_new_tokens'] = params['duration'] * fr
|
23 |
+
has_tokens = True
|
24 |
+
|
25 |
+
for k, p in params.items():
|
26 |
+
if k in out:
|
27 |
+
if has_tokens and k == 'max_new_tokens':
|
28 |
+
continue
|
29 |
+
|
30 |
+
out[k] = p
|
31 |
+
|
32 |
+
return out
|
33 |
+
|
34 |
+
|
35 |
class EndpointHandler:
|
36 |
def __init__(self, path="pbotsaris/musicgen-small"):
|
37 |
self.processor = AutoProcessor.from_pretrained(path)
|