cwitkowitz committed
Commit b7e4485 · 1 Parent(s): 8a8ea06

Forgot to add audio to current device.

Files changed (1)
  1. app.py +8 -3
app.py CHANGED
@@ -22,8 +22,8 @@ model_path_orig = os.path.join('models', 'tt-orig.pt')
 tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
 #tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
 
-if torch.cuda.is_available():
-    model = model.cuda()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
 
 model_card = ModelCard(
     name='Timbre-Trap',
@@ -32,7 +32,6 @@ model_card = ModelCard(
     tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
 )
 
-
 def process_fn(audio_path, transcribe):#, demo):
     # Load the audio with torchaudio
     audio, fs = torchaudio.load(audio_path)
@@ -54,6 +53,9 @@ def process_fn(audio_path, transcribe):#, demo):
     # Load weights of the original model
     model.load_state_dict(tt_weights_orig)
 
+    # Add audio to current device
+    audio = audio.to(device)
+
     # Obtain transcription or reconstructed spectral coefficients
     coefficients = model.chunked_inference(audio, transcribe)
 
@@ -70,6 +72,9 @@ def process_fn(audio_path, transcribe):#, demo):
     # Resample audio back to the original sampling rate
     audio = torchaudio.functional.resample(audio, 22050, fs)
 
+    # Bring audio back to CPU
+    audio = audio.cpu()
+
     # Create a temporary directory for output
     os.makedirs('_outputs', exist_ok=True)
     # Create a path for saving the audio
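For context, the change follows the standard PyTorch device-handling idiom: pick a device once, move both the model and its input tensors to it before inference, and move results back to the CPU before CPU-only post-processing. Below is a minimal sketch of that idiom; the Linear layer and random tensor are placeholders standing in for the Timbre-Trap model and the loaded audio, not code from app.py.

import torch

# Choose the compute device once (CUDA if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical stand-in for the model used in app.py
model = torch.nn.Linear(16, 16).to(device)
model.eval()

# Stand-in for audio loaded on the CPU (e.g. via torchaudio.load)
audio = torch.randn(1, 16)

# Move the input to the same device as the model before inference;
# feeding CPU tensors to a CUDA model raises a runtime error
audio = audio.to(device)

with torch.no_grad():
    output = model(audio)

# Bring the result back to the CPU for saving / post-processing
output = output.cpu()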