Update modeling_typhoonaudio.py
refactor generate() and streaming for talk arena
modeling_typhoonaudio.py  CHANGED  +42 -21
@@ -15,10 +15,9 @@ from transformers import (
     WhisperModel,
     PreTrainedModel,
     AutoTokenizer,
-    AutoModelForCausalLM
+    AutoModelForCausalLM,
+    TextIteratorStreamer
 )
-import soundfile as sf
-import librosa
 from .configuration_typhoonaudio import TyphoonAudioConfig
 # ---------------------------------------------------- #
 # QFormer: https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
@@ -156,35 +155,26 @@ class TyphoonAudio(PreTrainedModel):
 
     def generate(
         self,
-
+        audio,
         prompt,
         prompt_pattern,
-
-        max_length=150,
+        max_new_tokens=1024,
         num_beams=4,
         do_sample=True,
-        min_length=1,
         top_p=0.9,
         repetition_penalty=1.0,
         length_penalty=1.0,
         temperature=1.0,
         streamer=None
-    ):
-
-
-        if len(wav.shape) == 2:
-            wav = wav[:, 0]
-        if len(wav) > 30 * sr:
-            wav = wav[: 30 * sr]
-        if sr != 16000:
-            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
-
+    ):
+        device = self.llama_model.device
+
         # whisper
-        spectrogram = self.feature_extractor(
+        spectrogram = self.feature_extractor(audio, return_tensors="pt", sampling_rate=16000).input_features.to(device).to(self.torch_dtype) # [1, 80, 3000]
         speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state
 
         # beats
-        raw_wav = torch.from_numpy(
+        raw_wav = torch.from_numpy(audio).to(device).unsqueeze(0)
         audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
         audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True, torch_dtype=self.torch_dtype)
 
@@ -256,10 +246,9 @@ class TyphoonAudio(PreTrainedModel):
         # generate
         output = self.llama_model.generate(
             inputs_embeds=embeds,
-
+            max_new_tokens=max_new_tokens,
             num_beams=num_beams,
             do_sample=do_sample,
-            min_length=min_length,
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             length_penalty=length_penalty,
@@ -273,6 +262,38 @@ class TyphoonAudio(PreTrainedModel):
         output_text = self.llama_tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
         return output_text[0]
 
+    def generate_stream(
+        self,
+        audio,
+        prompt,
+        prompt_pattern="<|start_header_id|>user<|end_header_id|>\n\n<Speech><SpeechHere></Speech> {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        length_penalty=1.0,
+        temperature=1.0,
+    ):
+        streamer = TextIteratorStreamer(self.llama_tokenizer)
+        _ = self.generate(
+            audio=audio,
+            prompt=prompt,
+            prompt_pattern=prompt_pattern,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            streamer=streamer,
+            num_beams=1,
+        )
+        response = ""
+        for new_tokens in streamer:
+            response += new_tokens.replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
+            yield response
+        return response
+
     def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, torch_dtype="float16"):
         encoder_config = BertConfig()
         encoder_config.num_hidden_layers = num_hidden_layers