potsawee committed
Commit 40ca02b · verified · 1 Parent(s): 346bf2e

Update modeling_typhoonaudio.py


refactor generate() and streaming for talk arena
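With this commit the model no longer reads audio from disk: the soundfile/librosa preprocessing is removed, generate() takes a pre-loaded waveform (audio) instead of wav_path and max_new_tokens instead of max_length/min_length, and a new generate_stream() method wraps generate() with a TextIteratorStreamer. A minimal sketch of the new call pattern follows; the checkpoint id, file name, and AutoModel loading are illustrative assumptions, not part of this commit.

```python
import librosa
from transformers import AutoModel

# Hypothetical repo id and loading path -- substitute the actual Typhoon-Audio checkpoint.
model = AutoModel.from_pretrained("scb10x/typhoon-audio-preview", trust_remote_code=True)
model = model.to("cuda").eval()

# Audio loading/resampling now happens in the caller (mirroring the removed
# soundfile/librosa code): mono waveform, 16 kHz, truncated to 30 seconds.
audio, _ = librosa.load("example.wav", sr=16000, mono=True)
audio = audio[: 30 * 16000]

# Non-streaming: pass the raw waveform instead of a file path.
text = model.generate(
    audio=audio,
    prompt="Transcribe this audio.",
    prompt_pattern=(
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "<Speech><SpeechHere></Speech> {}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    ),
    max_new_tokens=256,
)

# Streaming: generate_stream() is a generator that yields the growing response string.
for partial in model.generate_stream(audio=audio, prompt="Transcribe this audio."):
    print(partial)
```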

Files changed (1)
  1. modeling_typhoonaudio.py +42 -21
modeling_typhoonaudio.py CHANGED
@@ -15,10 +15,9 @@ from transformers import (
     WhisperModel,
     PreTrainedModel,
     AutoTokenizer,
-    AutoModelForCausalLM
+    AutoModelForCausalLM,
+    TextIteratorStreamer
 )
-import soundfile as sf
-import librosa
 from .configuration_typhoonaudio import TyphoonAudioConfig
 # ---------------------------------------------------- #
 # QFormer: https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
@@ -156,35 +155,26 @@ class TyphoonAudio(PreTrainedModel):
 
     def generate(
         self,
-        wav_path,
+        audio,
         prompt,
         prompt_pattern,
-        device='cuda:0',
-        max_length=150,
+        max_new_tokens=1024,
         num_beams=4,
         do_sample=True,
-        min_length=1,
         top_p=0.9,
         repetition_penalty=1.0,
         length_penalty=1.0,
         temperature=1.0,
         streamer=None
-    ):
-        # read wav
-        wav, sr = sf.read(wav_path)
-        if len(wav.shape) == 2:
-            wav = wav[:, 0]
-        if len(wav) > 30 * sr:
-            wav = wav[: 30 * sr]
-        if sr != 16000:
-            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
-
+    ):
+        device = self.llama_model.device
+
         # whisper
-        spectrogram = self.feature_extractor(wav, return_tensors="pt", sampling_rate=16000).input_features.to(device).to(self.torch_dtype) # [1, 80, 3000]
+        spectrogram = self.feature_extractor(audio, return_tensors="pt", sampling_rate=16000).input_features.to(device).to(self.torch_dtype) # [1, 80, 3000]
         speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state
 
         # beats
-        raw_wav = torch.from_numpy(wav).to(device).unsqueeze(0)
+        raw_wav = torch.from_numpy(audio).to(device).unsqueeze(0)
         audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
         audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True, torch_dtype=self.torch_dtype)
 
@@ -256,10 +246,9 @@ class TyphoonAudio(PreTrainedModel):
         # generate
         output = self.llama_model.generate(
             inputs_embeds=embeds,
-            max_length=max_length,
+            max_new_tokens=max_new_tokens,
             num_beams=num_beams,
             do_sample=do_sample,
-            min_length=min_length,
             top_p=top_p,
             repetition_penalty=repetition_penalty,
             length_penalty=length_penalty,
@@ -273,6 +262,38 @@ class TyphoonAudio(PreTrainedModel):
         output_text = self.llama_tokenizer.batch_decode(output, add_special_tokens=False, skip_special_tokens=True)
         return output_text[0]
 
+    def generate_stream(
+        self,
+        audio,
+        prompt,
+        prompt_pattern="<|start_header_id|>user<|end_header_id|>\n\n<Speech><SpeechHere></Speech> {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.9,
+        repetition_penalty=1.0,
+        length_penalty=1.0,
+        temperature=1.0,
+    ):
+        streamer = TextIteratorStreamer(self.llama_tokenizer)
+        _ = self.generate(
+            audio=audio,
+            prompt=prompt,
+            prompt_pattern=prompt_pattern,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            streamer=streamer,
+            num_beams=1,
+        )
+        response = ""
+        for new_tokens in streamer:
+            response += new_tokens.replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
+            yield response
+        return response
+
     def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, torch_dtype="float16"):
         encoder_config = BertConfig()
         encoder_config.num_hidden_layers = num_hidden_layers
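As written, generate_stream() only starts draining the streamer after the blocking self.generate() call has returned, so the caller receives the response incrementally but not while decoding is still in progress. For token-by-token streaming in a talk-arena-style UI, the usual TextIteratorStreamer pattern runs generation in a background thread; a rough sketch under that assumption (not what this commit implements, and `model`/`audio` are assumed to exist as in the earlier example):

```python
from threading import Thread
from transformers import TextIteratorStreamer

# Standard threaded streaming pattern: generation runs in a worker thread while
# the main thread consumes decoded text as soon as it becomes available.
streamer = TextIteratorStreamer(model.llama_tokenizer, skip_special_tokens=True)
generation_kwargs = dict(
    audio=audio,
    prompt="Transcribe this audio.",
    prompt_pattern=(
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "<Speech><SpeechHere></Speech> {}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    ),
    num_beams=1,
    streamer=streamer,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

response = ""
for new_text in streamer:   # yields chunks while decoding is still running
    response += new_text
    print(response, end="\r")
thread.join()
```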