gaunernst commited on
Commit
bcc0935
·
1 Parent(s): cac3ec7

fix preprocessing. add examples

Browse files
LS_female_1462-170138-0008.flac ADDED
Binary file (122 kB). View file
 
LS_male_3170-137482-0005.flac ADDED
Binary file (155 kB). View file
 
app.py CHANGED
@@ -17,6 +17,8 @@ LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/au
17
  AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())
18
 
19
  SAMPLING_RATE = 16_000
 
 
20
 
21
 
22
  def resample(x: np.ndarray, sr: int):
@@ -26,25 +28,34 @@ def resample(x: np.ndarray, sr: int):
26
 
27
 
28
  def preprocess(x: torch.Tensor):
 
29
  melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
30
  if melspec.shape[0] < 1024:
31
  melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
32
  else:
33
  melspec = melspec[:1024]
 
34
  return melspec.view(1, 1, 1024, 128)
35
 
36
 
37
- def predict(audio):
38
  sr, x = audio
39
- x = resample(x, sr)
 
 
 
40
  x = torch.from_numpy(x)
41
 
42
  with torch.inference_mode():
43
  logits = MODEL(preprocess(x)).squeeze(0)
44
 
45
- topk_probs, topk_classes = logits.softmax(dim=-1).topk(5)
46
  return [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
47
 
48
 
49
- iface = gr.Interface(fn=predict, inputs="audio", outputs="dataframe")
50
- iface.launch()
 
 
 
 
 
17
  AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())
18
 
19
  SAMPLING_RATE = 16_000
20
+ MEAN = -4.2677393
21
+ STD = 4.5689974
22
 
23
 
24
  def resample(x: np.ndarray, sr: int):
 
28
 
29
 
30
  def preprocess(x: torch.Tensor):
31
+ x = x - x.mean()
32
  melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
33
  if melspec.shape[0] < 1024:
34
  melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
35
  else:
36
  melspec = melspec[:1024]
37
+ melspec = (melspec - MEAN) / (STD * 2)
38
  return melspec.view(1, 1, 1024, 128)
39
 
40
 
41
+ def predict(audio, start):
42
  sr, x = audio
43
+ if x.shape[0] < start * sr:
44
+ raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")
45
+
46
+ x = resample(x[int(start * sr) :], sr)
47
  x = torch.from_numpy(x)
48
 
49
  with torch.inference_mode():
50
  logits = MODEL(preprocess(x)).squeeze(0)
51
 
52
+ topk_probs, topk_classes = logits.sigmoid().topk(10)
53
  return [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
54
 
55
 
56
+ gr.Interface(
57
+ fn=predict,
58
+ inputs=["audio", "number"],
59
+ outputs="dataframe",
60
+ examples=[["LS_female_1462-170138-0008.flac", 0], ["LS_male_3170-137482-0005.flac", 0]],
61
+ ).launch()