Update audiosr/pipeline.py
audiosr/pipeline.py  CHANGED  (+4, -2)
@@ -111,7 +111,7 @@ def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
 def round_up_duration(duration):
     return int(round(duration / 2.5) + 1) * 2.5
 
-
+
 def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
     if device is None or device == "auto":
         if torch.cuda.is_available():
@@ -150,6 +150,7 @@ def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
 
     return latent_diffusion
 
+@spaces.GPU
 
 def super_resolution(
     latent_diffusion,
@@ -165,12 +166,13 @@ def super_resolution(
 
     batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)
 
+    with torch.no_grad():
     with torch.no_grad():
         waveform = latent_diffusion.generate_batch(
             batch,
             unconditional_guidance_scale=guidance_scale,
             ddim_steps=ddim_steps,
-            duration=duration,
         )
 
     return waveform
+
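
For context, a rough sketch of how the affected part of super_resolution reads after this commit. It is an illustration only: `import spaces` (the Hugging Face ZeroGPU helper package) is assumed to be present at the top of pipeline.py, the parameter names and defaults beyond those visible in the diff are assumptions, and the doubled `with torch.no_grad():` from the patch is collapsed into a single context manager here, since nesting two of them adds nothing.

import spaces  # assumed import: Hugging Face ZeroGPU decorator package
import torch


@spaces.GPU  # on a ZeroGPU Space, a GPU is attached only while this call runs
def super_resolution(
    latent_diffusion,
    input_file,
    waveform=None,        # assumed parameter, mirroring make_batch_for_super_resolution
    ddim_steps=200,       # assumed default
    guidance_scale=3.5,   # assumed default
):
    batch, duration = make_batch_for_super_resolution(input_file, waveform=waveform)

    # Inference only: disable autograd so the diffusion sampling loop
    # does not track gradients or hold onto intermediate activations.
    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
        )

    return waveform

The `@spaces.GPU` decorator is how ZeroGPU Spaces request a GPU for the duration of a single call, which is why it wraps the inference entry point rather than model construction, and `torch.no_grad()` keeps the sampling pass memory-lean by skipping gradient bookkeeping.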