dongyh20 committed
Commit 2658df1 · 1 Parent(s): 1ea63a6

update space

Files changed (1)
  1. app.py +19 -18
app.py CHANGED
@@ -73,10 +73,11 @@ beats_path = hf_hub_download(
 
 model_path = "THUdyh/Ola-7b"
 tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None)
-model = model.to('cuda').eval()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device).eval()
 model = model.bfloat16()
 
-tts_model = CosyVoice('iic/CosyVoice-300M-SFT', load_jit=True, fp16=True)
+# tts_model = CosyVoice('iic/CosyVoice-300M-SFT', load_jit=True, fp16=True)
 # tts_model = CosyVoice('FunAudioLLM/CosyVoice-300M-SFT', load_jit=True, fp16=True)
 OUTPUT_SPEECH = False
 
@@ -186,10 +187,10 @@ def ola_inference(multimodal, audio_path):
     if USE_SPEECH and audio_path:
         audio_path = audio_path
         speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path)
-        speechs.append(speech.bfloat16().to('cuda'))
-        speech_lengths.append(speech_length.to('cuda'))
-        speech_chunks.append(speech_chunk.to('cuda'))
-        speech_wavs.append(speech_wav.to('cuda'))
+        speechs.append(speech.bfloat16().to(device))
+        speech_lengths.append(speech_length.to(device))
+        speech_chunks.append(speech_chunk.to(device))
+        speech_wavs.append(speech_wav.to(device))
         print('load audio')
     elif USE_SPEECH and not audio_path:
         # parse audio in the video
@@ -197,15 +198,15 @@ def ola_inference(multimodal, audio_path):
         audio.write_audiofile("./video_audio.wav")
         video_audio_path = './video_audio.wav'
         speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path)
-        speechs.append(speech.bfloat16().to('cuda'))
-        speech_lengths.append(speech_length.to('cuda'))
-        speech_chunks.append(speech_chunk.to('cuda'))
-        speech_wavs.append(speech_wav.to('cuda'))
+        speechs.append(speech.bfloat16().to(device))
+        speech_lengths.append(speech_length.to(device))
+        speech_chunks.append(speech_chunk.to(device))
+        speech_wavs.append(speech_wav.to(device))
     else:
-        speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
-        speech_lengths = [torch.LongTensor([3000]).to('cuda')]
-        speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
-        speech_chunks = [torch.LongTensor([1]).to('cuda')]
+        speechs = [torch.zeros(1, 3000, 128).bfloat16().to(device)]
+        speech_lengths = [torch.LongTensor([3000]).to(device)]
+        speech_wavs = [torch.zeros([1, 480000]).to(device)]
+        speech_chunks = [torch.LongTensor([1]).to(device)]
 
     conv_mode = "qwen_1_5"
     if text:
@@ -224,11 +225,11 @@ def ola_inference(multimodal, audio_path):
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()
     if USE_SPEECH and audio_path:
-        input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+        input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
     elif USE_SPEECH:
-        input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+        input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
     else:
-        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
 
     if modality == "video":
         video_processed = []
@@ -272,7 +273,7 @@ def ola_inference(multimodal, audio_path):
 
     pad_token_ids = 151643
 
-    attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
+    attention_masks = input_ids.ne(pad_token_ids).long().to(device)
     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
     keywords = [stop_str]
     stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
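
The change replaces every hard-coded .to('cuda') call with a shared device variable chosen once at startup, so the Space can still load when no GPU is attached. A minimal sketch of that fallback pattern, using a stand-in nn.Linear instead of the actual Ola-7b model:

    import torch
    import torch.nn as nn

    # Pick the accelerator once; every module and tensor is then moved with .to(device).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = nn.Linear(16, 4).to(device).eval()   # stand-in for the real model, for illustration only
    x = torch.zeros(1, 16).to(device)            # inputs must live on the same device as the model

    with torch.inference_mode():
        y = model(x)
    print(y.shape, y.device)

Defining device at module level, as the commit does, lets helper functions such as ola_inference reuse it without extra arguments.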