Update audiosr/pipeline.py
Browse files- audiosr/pipeline.py +6 -1
audiosr/pipeline.py
CHANGED
@@ -80,6 +80,9 @@ def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
|
|
80 |
return {"ta_kaldi_fbank": fbank} # [1024, 128]
|
81 |
|
82 |
|
|
|
|
|
|
|
83 |
def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
|
84 |
log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
|
85 |
|
@@ -111,10 +114,12 @@ def round_up_duration(duration):
|
|
111 |
return int(round(duration / 2.5) + 1) * 2.5
|
112 |
|
113 |
|
|
|
114 |
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
|
115 |
if device is None or device == "auto":
|
116 |
if torch.cuda.is_available():
|
117 |
-
device = torch.
|
|
|
118 |
elif torch.backends.mps.is_available():
|
119 |
device = torch.device("mps")
|
120 |
else:
|
|
|
80 |
return {"ta_kaldi_fbank": fbank} # [1024, 128]
|
81 |
|
82 |
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
|
87 |
log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
|
88 |
|
|
|
114 |
return int(round(duration / 2.5) + 1) * 2.5
|
115 |
|
116 |
|
117 |
+
@spaces.GPU
|
118 |
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
|
119 |
if device is None or device == "auto":
|
120 |
if torch.cuda.is_available():
|
121 |
+
device = torch.Tensor([0]).cuda()
|
122 |
+
# device = torch.device("cuda:0")
|
123 |
elif torch.backends.mps.is_available():
|
124 |
device = torch.device("mps")
|
125 |
else:
|