Flux9665 committed
Commit ee42912 • 1 Parent(s): d763494

try to figure out how ZeroGPU works

InferenceInterfaces/ToucanTTSInterface.py CHANGED
@@ -8,7 +8,7 @@ import pyloudnorm
 import sounddevice
 import soundfile
 import torch
-
+import spaces
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")
     from speechbrain.pretrained import EncoderClassifier
@@ -127,6 +127,7 @@ class ToucanTTSInterface(torch.nn.Module):
 
         self.lang_id = get_language_id(lang_id).to(self.device)
 
+    @spaces.GPU
     def forward(self,
                 text,
                 view=False,
@@ -152,6 +153,10 @@ class ToucanTTSInterface(torch.nn.Module):
         1.0 means no scaling happens, higher values increase variance of the energy curve,
         lower values decrease variance of the energy curve.
         """
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+        self.to(device)
+
         with torch.inference_mode():
             phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
             mel, durations, pitch, energy = self.phone2mel(phones,
@@ -223,6 +228,8 @@ class ToucanTTSInterface(torch.nn.Module):
         if return_plot_as_filepath:
             plt.savefig("tmp.png")
             return wave, sr, "tmp.png"
+        self.to("cpu")
+        self.device = "cpu"
         return wave, sr
 
     def read_to_file(self,
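
For context on what this commit is experimenting with: on Hugging Face ZeroGPU Spaces, CUDA is only available inside functions decorated with spaces.GPU from the spaces package, so the model has to be moved onto the GPU at call time and handed back to the CPU before the allocation is released. Below is a minimal sketch of that pattern; SynthSketch is a made-up stand-in module, not the real ToucanTTSInterface.

import spaces
import torch


class SynthSketch(torch.nn.Module):
    # Hypothetical stand-in for a TTS wrapper (not the real ToucanTTSInterface).

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 16)
        self.device = "cpu"

    @spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
    def forward(self, features):
        # CUDA is visible only inside the decorated function, so the module
        # (and its inputs) are moved to the GPU here, not at construction time.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        self.device = device
        with torch.inference_mode():
            out = self.layer(features.to(device))
        # Hand the weights back to the CPU before the GPU is released again.
        self.to("cpu")
        self.device = "cpu"
        return out.cpu()

Calling SynthSketch()(torch.randn(4, 16)) then runs on whatever GPU ZeroGPU hands out for that request and falls back to the CPU when no GPU is available.
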
app.py CHANGED
@@ -35,7 +35,6 @@ class ControllableInterface(torch.nn.Module):
         self.model.device = "cpu"
         self.wgan.to("cpu")
         self.wgan.device = "cpu"
-        self._modules = []
 
     def read(self,
              prompt,
@@ -123,7 +122,6 @@ class ControllableInterface(torch.nn.Module):
 
 
 
-@spaces.GPU
 def read(prompt,
          language,
          voice_seed,
@@ -133,13 +131,7 @@ def read(prompt,
          emb1,
          emb2
          ):
-    if torch.cuda.is_available():
-        controllable_ui.to("cuda")
-        controllable_ui.device = "cuda"
-        controllable_ui.model.device = "cuda"
-        controllable_ui.wgan.device = "cuda"
-    try:
-        sr, wav, fig = controllable_ui.read(prompt,
+    sr, wav, fig = controllable_ui.read(prompt,
                                         language.split(" ")[-1].split("(")[1].split(")")[0],
                                         language.split(" ")[-1].split("(")[1].split(")")[0],
                                         voice_seed,
@@ -154,11 +146,6 @@ def read(prompt,
                                         0.,
                                         0.,
                                         -24.)
-    finally:
-        controllable_ui.to("cpu")
-        controllable_ui.device = "cpu"
-        controllable_ui.model.device = "cpu"
-        controllable_ui.wgan.device = "cpu"
     return (sr, float2pcm(wav)), fig
 
 
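
With the device handling folded into the decorated forward, the Gradio callback in app.py no longer needs the to("cuda") / try-finally bookkeeping shown in the removed lines. The following self-contained sketch illustrates that arrangement; TinySynth, the placeholder sine tone, and all other names are invented for illustration and do not come from the repository.

import math

import gradio as gr
import numpy as np
import spaces
import torch


class TinySynth(torch.nn.Module):
    # Hypothetical stand-in for the TTS interface; produces a placeholder tone.

    def __init__(self):
        super().__init__()
        self.gain = torch.nn.Parameter(torch.tensor(0.1))

    @spaces.GPU  # the GPU work lives in the decorated method, not in the Gradio callback
    def forward(self, prompt):
        # prompt is ignored here; a real Space would synthesize speech from it.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        sr = 24000
        t = torch.linspace(0, 1, sr, device=device)
        wave = (self.gain * torch.sin(2 * math.pi * 220 * t)).detach().cpu().numpy()
        self.to("cpu")
        return sr, wave


model = TinySynth()


def read(prompt):
    # The callback stays device-agnostic, mirroring the slimmed-down app.py.
    sr, wave = model(prompt)
    return (sr, (wave * 32767).astype(np.int16))


demo = gr.Interface(fn=read,
                    inputs=gr.Textbox(label="Text to read"),
                    outputs=gr.Audio(label="Synthesized audio"))

if __name__ == "__main__":
    demo.launch()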