Spaces: Running on T4

Commit: try to figure out how ZeroGPU works

Files changed:
- InferenceInterfaces/ToucanTTSInterface.py (+8 -1)
- app.py (+1 -14)
InferenceInterfaces/ToucanTTSInterface.py

@@ -8,7 +8,7 @@ import pyloudnorm
 import sounddevice
 import soundfile
 import torch
-
+import spaces
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
     from speechbrain.pretrained import EncoderClassifier
@@ -127,6 +127,7 @@ class ToucanTTSInterface(torch.nn.Module):

         self.lang_id = get_language_id(lang_id).to(self.device)

+    @spaces.GPU
     def forward(self,
                 text,
                 view=False,
@@ -152,6 +153,10 @@ class ToucanTTSInterface(torch.nn.Module):
         1.0 means no scaling happens, higher values increase variance of the energy curve,
         lower values decrease variance of the energy curve.
         """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+        self.to(device)
+
         with torch.inference_mode():
             phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
             mel, durations, pitch, energy = self.phone2mel(phones,
@@ -223,6 +228,8 @@ class ToucanTTSInterface(torch.nn.Module):
             if return_plot_as_filepath:
                 plt.savefig("tmp.png")
                 return wave, sr, "tmp.png"
+        self.to("cpu")
+        self.device = "cpu"
         return wave, sr

     def read_to_file(self,
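Taken together, the changes in this file are the standard ZeroGPU recipe: import the `spaces` package, decorate the GPU-bound entry point with `@spaces.GPU`, move the module to CUDA only inside the decorated call, and park it back on the CPU before returning, since ZeroGPU attaches a GPU only for the duration of that call. A minimal sketch of the same pattern in isolation (`TinyModel` is a made-up stand-in for ToucanTTSInterface, not code from this repo):

import spaces  # Hugging Face ZeroGPU helper; intended to be a pass-through off Spaces
import torch

class TinyModel(torch.nn.Module):
    """Illustrative stand-in for a model whose forward pass needs a GPU."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)

    @spaces.GPU  # a GPU is attached only while this call is running
    def forward(self, x):
        # CUDA is unavailable outside the decorated call, so the move
        # to the GPU has to happen here rather than in __init__.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        out = self.layer(x.to(device)).cpu()
        self.to("cpu")  # hand the GPU back before the allocation ends
        return out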
app.py

@@ -35,7 +35,6 @@ class ControllableInterface(torch.nn.Module):
         self.model.device = "cpu"
         self.wgan.to("cpu")
         self.wgan.device = "cpu"
-        self._modules = []

     def read(self,
              prompt,
@@ -123,7 +122,6 @@ class ControllableInterface(torch.nn.Module):



-@spaces.GPU
 def read(prompt,
          language,
          voice_seed,
@@ -133,13 +131,7 @@ def read(prompt,
          emb1,
          emb2
          ):
-
-    controllable_ui.to("cuda")
-    controllable_ui.device = "cuda"
-    controllable_ui.model.device = "cuda"
-    controllable_ui.wgan.device = "cuda"
-    try:
-        sr, wav, fig = controllable_ui.read(prompt,
+    sr, wav, fig = controllable_ui.read(prompt,
                                         language.split(" ")[-1].split("(")[1].split(")")[0],
                                         language.split(" ")[-1].split("(")[1].split(")")[0],
                                         voice_seed,
@@ -154,11 +146,6 @@ def read(prompt,
                                         0.,
                                         0.,
                                         -24.)
-    finally:
-        controllable_ui.to("cpu")
-        controllable_ui.device = "cpu"
-        controllable_ui.model.device = "cpu"
-        controllable_ui.wgan.device = "cpu"
     return (sr, float2pcm(wav)), fig
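The app.py side is the mirror image: with `@spaces.GPU` now living on `ToucanTTSInterface.forward`, the Gradio handler no longer needs the decorator, the manual `.to("cuda")` / `.to("cpu")` shuffling, or the try/finally guard. After this commit the handler reduces to roughly the following (signature trimmed to three arguments for brevity; `controllable_ui` and `float2pcm` are defined elsewhere in app.py):

def read(prompt, language, voice_seed):
    # No device bookkeeping here any more: the @spaces.GPU decorator on
    # ToucanTTSInterface.forward acquires and releases the GPU per call.
    sr, wav, fig = controllable_ui.read(prompt,
                                        language.split(" ")[-1].split("(")[1].split(")")[0],
                                        language.split(" ")[-1].split("(")[1].split(")")[0],
                                        voice_seed)
    return (sr, float2pcm(wav)), fig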