MiXaiLL76 commited on
Commit
1d05ae5
1 Parent(s): 73271d4
Files changed (1) hide show
  1. tools/extract_embedding.py +8 -23
tools/extract_embedding.py CHANGED
@@ -21,11 +21,10 @@ import torch
21
  import torchaudio
22
  import torchaudio.compliance.kaldi as kaldi
23
  from tqdm import tqdm
 
24
 
25
 
26
- def extract_embedding(input_list):
27
- utt, wav_file, ort_session = input_list
28
-
29
  audio, sample_rate = torchaudio.load(wav_file)
30
  if sample_rate != 16000:
31
  audio = torchaudio.transforms.Resample(
@@ -33,19 +32,7 @@ def extract_embedding(input_list):
33
  )(audio)
34
  feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
35
  feat = feat - feat.mean(dim=0, keepdim=True)
36
- embedding = (
37
- ort_session.run(
38
- None,
39
- {
40
- ort_session.get_inputs()[0]
41
- .name: feat.unsqueeze(dim=0)
42
- .cpu()
43
- .numpy()
44
- },
45
- )[0]
46
- .flatten()
47
- .tolist()
48
- )
49
  return (utt, embedding)
50
 
51
 
@@ -72,16 +59,14 @@ def main(args):
72
  args.onnx_path, sess_options=option, providers=providers
73
  )
74
 
75
- inputs = [
76
- (utt, utt2wav[utt], ort_session)
77
- for utt in tqdm(utt2wav.keys(), desc="Load data")
78
- ]
79
  with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
80
  results = list(
81
  tqdm(
82
- executor.map(extract_embedding, inputs),
83
- total=len(inputs),
84
- desc="Process data: ",
85
  )
86
  )
87
 
 
21
  import torchaudio
22
  import torchaudio.compliance.kaldi as kaldi
23
  from tqdm import tqdm
24
+ from itertools import repeat
25
 
26
 
27
+ def extract_embedding(utt: str, wav_file: str, ort_session: onnxruntime.InferenceSession):
 
 
28
  audio, sample_rate = torchaudio.load(wav_file)
29
  if sample_rate != 16000:
30
  audio = torchaudio.transforms.Resample(
 
32
  )(audio)
33
  feat = kaldi.fbank(audio, num_mel_bins=80, dither=0, sample_frequency=16000)
34
  feat = feat - feat.mean(dim=0, keepdim=True)
35
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
36
  return (utt, embedding)
37
 
38
 
 
59
  args.onnx_path, sess_options=option, providers=providers
60
  )
61
 
62
+ all_utt = utt2wav.keys()
63
+
 
 
64
  with ThreadPoolExecutor(max_workers=args.num_thread) as executor:
65
  results = list(
66
  tqdm(
67
+ executor.map(extract_embedding, all_utt, [utt2wav[utt] for utt in all_utt], repeat(ort_session)),
68
+ total=len(utt2wav),
69
+ desc="Process data: "
70
  )
71
  )
72