kgout commited on
Commit
6f64722
·
verified ·
1 Parent(s): fa90792

Update audiosr/pipeline.py

Browse files
Files changed (1) hide show
  1. audiosr/pipeline.py +6 -1
audiosr/pipeline.py CHANGED
@@ -80,6 +80,9 @@ def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
80
  return {"ta_kaldi_fbank": fbank} # [1024, 128]
81
 
82
 
 
 
 
83
  def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
84
  log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
85
 
@@ -111,10 +114,12 @@ def round_up_duration(duration):
111
  return int(round(duration / 2.5) + 1) * 2.5
112
 
113
 
 
114
  def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
115
  if device is None or device == "auto":
116
  if torch.cuda.is_available():
117
- device = torch.device("cuda:0")
 
118
  elif torch.backends.mps.is_available():
119
  device = torch.device("mps")
120
  else:
 
80
  return {"ta_kaldi_fbank": fbank} # [1024, 128]
81
 
82
 
83
+
84
+
85
+
86
  def make_batch_for_super_resolution(input_file, waveform=None, fbank=None):
87
  log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(input_file)
88
 
 
114
  return int(round(duration / 2.5) + 1) * 2.5
115
 
116
 
117
+ @spaces.GPU
118
  def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
119
  if device is None or device == "auto":
120
  if torch.cuda.is_available():
121
+ device = torch.Tensor([0]).cuda()
122
+ # device = torch.device("cuda:0")
123
  elif torch.backends.mps.is_available():
124
  device = torch.device("mps")
125
  else: