cwitkowitz committed
Commit 0806b5e · Parent: d9bbeb4

Switched to a checkbox control and added commented-out code for a future demo mode.

Files changed (1)
  1. app.py +16 -21
app.py CHANGED
@@ -17,9 +17,10 @@ model = TimbreTrap(sample_rate=22050,
 model.eval()
 
 model_path_orig = os.path.join('models', 'tt-orig.pt')
+#model_path_demo = os.path.join('models', 'tt-demo.pt')
 
 tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
-model.load_state_dict(tt_weights_orig)
+#tt_weights_demo = torch.load(model_path_demo, map_location='cpu')
 
 model_card = ModelCard(
     name='Timbre-Trap',
@@ -29,7 +30,7 @@ model_card = ModelCard(
 )
 
 
-def process_fn(audio_path, transcribe):
+def process_fn(audio_path, transcribe):#, demo):
     # Load the audio with torchaudio
     audio, fs = torchaudio.load(audio_path)
     # Average channels to obtain mono-channel
@@ -41,6 +42,15 @@ def process_fn(audio_path, transcribe):
     # Determine original number of samples
     n_samples = audio.size(-1)
 
+    """
+    if demo:
+        # Load weights of the demo version
+        model.load_state_dict(tt_weights_demo)
+    else:
+    """
+    # Load weights of the original model
+    model.load_state_dict(tt_weights_orig)
+
     # Obtain transcription or reconstructed spectral coefficients
     coefficients = model.chunked_inference(audio, transcribe)
 
@@ -73,25 +83,10 @@ def process_fn(audio_path, transcribe):
 # Build Gradio endpoint
 with gr.Blocks() as demo:
     components = [
-        #gr.Checkbox(
-        #    value=False,
-        #    label='De-Timbre'
-        #),
-        gr.Slider(
-            minimum=0,
-            maximum=1,
-            step=1,
-            value=0,
-            label='De-Timbre'
-        ),
-        #gr.Number(
-        #    value=0,
-        #    label='De-Timbre'
-        #),
-        #gr.Textbox(
-        #    value='text',
-        #    label='De-Timbre'
-        #)
+        gr.Checkbox(
+            value=False,
+            label='Remove Timbre'
+        )
     ]
 
     app = build_endpoint(model_card=model_card,
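For context, a minimal sketch of the checkbox pattern this commit switches to, wired through a plain gr.Interface rather than the repository's build_endpoint helper (whose signature is not shown here); the process_fn body below is a placeholder, not Timbre-Trap's actual inference code.

import gradio as gr

def process_fn(audio_path, transcribe):
    # Placeholder body: the real app runs Timbre-Trap inference here.
    # The checkbox value arrives as a plain Python boolean.
    mode = 'transcription' if transcribe else 'reconstruction'
    return f'Would process {audio_path} in {mode} mode'

demo = gr.Interface(
    fn=process_fn,
    inputs=[
        gr.Audio(type='filepath', label='Input Audio'),
        gr.Checkbox(value=False, label='Remove Timbre')
    ],
    outputs=gr.Textbox(label='Status')
)

if __name__ == '__main__':
    demo.launch()

Compared with the earlier gr.Slider(minimum=0, maximum=1, step=1), the checkbox maps directly to the boolean flag that process_fn expects, so no integer-to-bool conversion is needed.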