RaysDipesh committed
Commit dfef5ed · 1 Parent(s): 48a2587

Create processing_whisper.py

Files changed (1)
  1. processing_whisper.py +145 -0
processing_whisper.py ADDED
@@ -0,0 +1,145 @@
+import math
+
+import numpy as np
+from transformers import WhisperProcessor
+
+
+class WhisperPrePostProcessor(WhisperProcessor):
+    def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
+        inputs_len = inputs.shape[0]
+        step = chunk_len - stride_left - stride_right
+
+        all_chunk_start_idx = np.arange(0, inputs_len, step)
+        num_samples = len(all_chunk_start_idx)
+
+        num_batches = math.ceil(num_samples / batch_size)
+        batch_idx = np.array_split(np.arange(num_samples), num_batches)
+
+        for i, idx in enumerate(batch_idx):
+            chunk_start_idx = all_chunk_start_idx[idx]
+
+            chunk_end_idx = chunk_start_idx + chunk_len
+
+            chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
+            processed = self.feature_extractor(
+                chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
+            )
+
+            _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
+            is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
+            _stride_right = np.where(is_last, 0, stride_right)
+
+            chunk_lens = [chunk.shape[0] for chunk in chunks]
+            strides = [
+                (int(chunk_l), int(_stride_l), int(_stride_r))
+                for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
+            ]
+
+            yield {"stride": strides, **processed}
+
+    def preprocess_batch(self, inputs, chunk_length_s=0, stride_length_s=None, batch_size=None):
+        stride = None
+        if isinstance(inputs, dict):
+            stride = inputs.pop("stride", None)
+            # Accepting `"array"` which is the key defined in `datasets` for
+            # better integration
+            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+                raise ValueError(
+                    "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain a "
+                    '"raw" or "array" key containing the numpy array representing the audio, and a "sampling_rate" key '
+                    "containing the sampling rate associated with the audio array."
+                )
+
+            _inputs = inputs.pop("raw", None)
+            if _inputs is None:
+                # Remove path which will not be used from `datasets`.
+                inputs.pop("path", None)
+                _inputs = inputs.pop("array", None)
+            in_sampling_rate = inputs.pop("sampling_rate")
+            inputs = _inputs
+
+            if in_sampling_rate != self.feature_extractor.sampling_rate:
+                try:
+                    import librosa
+                except ImportError as err:
+                    raise ImportError(
+                        "To support resampling audio files, please install 'librosa' and 'soundfile'."
+                    ) from err
+
+                inputs = librosa.resample(
+                    inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
+                )
+                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
+            else:
+                ratio = 1
+
+        if not isinstance(inputs, np.ndarray):
+            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`.")
+        if len(inputs.shape) != 1:
+            raise ValueError(
+                f"We expect a single channel audio input for the Flax Whisper API, got {len(inputs.shape)} channels."
+            )
+
+        if stride is not None:
+            if stride[0] + stride[1] > inputs.shape[0]:
+                raise ValueError("Stride is too large for input.")
+
+            # Stride needs to get the chunk length here, it's going to get
+            # swallowed by the `feature_extractor` later, and then batching
+            # can add extra data in the inputs, so we need to keep track
+            # of the original length in the stride so we can cut properly.
+            stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
+
+        if chunk_length_s:
+            if stride_length_s is None:
+                stride_length_s = chunk_length_s / 6
+
+            if isinstance(stride_length_s, (int, float)):
+                stride_length_s = [stride_length_s, stride_length_s]
+
+            chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
+            stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
+            stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
+
+            if chunk_len < stride_left + stride_right:
+                raise ValueError("Chunk length must be superior to stride length.")
+
+            for item in self.chunk_iter_with_batch(
+                inputs,
+                chunk_len,
+                stride_left,
+                stride_right,
+                batch_size,
+            ):
+                yield item
+        else:
+            processed = self.feature_extractor(
+                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
+            )
+            if stride is not None:
+                processed["stride"] = stride
+            yield processed
+
+    def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
+        # unpack the outputs from list(dict(list)) to list(dict)
+        model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
+
+        time_precision = self.feature_extractor.chunk_length / 1500  # max source positions = 1500
+        # Send the chunking back to seconds, it's easier to handle in whisper
+        sampling_rate = self.feature_extractor.sampling_rate
+        for output in model_outputs:
+            if "stride" in output:
+                chunk_len, stride_left, stride_right = output["stride"]
+                # Go back in seconds
+                chunk_len /= sampling_rate
+                stride_left /= sampling_rate
+                stride_right /= sampling_rate
+                output["stride"] = chunk_len, stride_left, stride_right
+
+        text, optional = self.tokenizer._decode_asr(
+            model_outputs,
+            return_timestamps=return_timestamps,
+            return_language=return_language,
+            time_precision=time_precision,
+        )
+        return {"text": text, **optional}
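For context, a minimal usage sketch of the preprocessing side of this class follows. The "openai/whisper-tiny" checkpoint name, the 30 s chunk length, and the batch size of 4 are illustrative assumptions, not part of this commit; the model forward pass and the postprocess step that a whisper-jax-style pipeline would run on the generated tokens are omitted here.

# Minimal sketch, assuming processing_whisper.py sits in the working directory.
# Checkpoint name, chunk length, and batch size below are illustrative only.
import numpy as np

from processing_whisper import WhisperPrePostProcessor

processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-tiny")

# 60 seconds of dummy single-channel audio at the feature extractor's sampling rate (16 kHz).
audio = np.zeros(60 * processor.feature_extractor.sampling_rate, dtype=np.float32)

# Chunk the audio into overlapping 30 s windows and feature-extract them in batches of 4.
for batch in processor.preprocess_batch(audio, chunk_length_s=30.0, batch_size=4):
    # "input_features" holds the padded log-mel features; "stride" records
    # (chunk length, left stride, right stride) in samples for each chunk.
    print(batch["input_features"].shape, batch["stride"])

In the full pipeline, each yielded batch would presumably be passed through a Flax Whisper model, and the resulting token sequences (together with their strides) handed to postprocess to assemble the final transcription.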