cwitkowitz commited on
Commit
8a8ea06
·
1 Parent(s): 0806b5e

Added line for CUDA processing.

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -22,6 +22,9 @@ model_path_orig = os.path.join('models', 'tt-orig.pt')
22
  tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
23
  #tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
24
 
 
 
 
25
  model_card = ModelCard(
26
  name='Timbre-Trap',
27
  description='De-timbre your audio!',
 
22
  tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
23
  #tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
24
 
25
+ if torch.cuda.is_available():
26
+ model = model.cuda()
27
+
28
  model_card = ModelCard(
29
  name='Timbre-Trap',
30
  description='De-timbre your audio!',