Nathan Pruyne commited on
Commit
1dabc5e
·
1 Parent(s): b63dc56
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -23,8 +23,11 @@ model_path_orig = os.path.join('models', 'tt-orig.pt')
23
  tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
24
  #tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
25
 
26
- if torch.cuda.is_available():
27
- model = model.cuda()
 
 
 
28
 
29
  model_card = ModelCard(
30
  name='Timbre-Trap',
@@ -55,6 +58,8 @@ def process_fn(audio_path, transcribe):#, demo):
55
  # Load weights of the original model
56
  model.load_state_dict(tt_weights_orig)
57
 
 
 
58
  # Obtain transcription or reconstructed spectral coefficients
59
  coefficients = model.chunked_inference(audio, transcribe)
60
 
 
23
  tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
24
  #tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
25
 
26
+ # if torch.cuda.is_available():
27
+ # model = model.cuda()
28
+
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model = model.to(device)
31
 
32
  model_card = ModelCard(
33
  name='Timbre-Trap',
 
58
  # Load weights of the original model
59
  model.load_state_dict(tt_weights_orig)
60
 
61
+ audio = audio.to(device)
62
+
63
  # Obtain transcription or reconstructed spectral coefficients
64
  coefficients = model.chunked_inference(audio, transcribe)
65