jhj0517 commited on
Commit
7f7cda2
·
1 Parent(s): f0d2b3d

Refactor to create videos

Browse files
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -14,6 +14,7 @@ from typing import Union, List, Dict, Tuple
14
 
15
  from modules.utils.paths import *
16
  from modules.utils.image_helper import *
 
17
  from modules.live_portrait.model_downloader import *
18
  from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
19
  from modules.utils.camera import get_rotation_matrix
@@ -241,15 +242,21 @@ class LivePortraitInferencer:
241
  raise
242
 
243
  def create_video(self,
244
- retargeting_eyes: bool,
245
- retargeting_mouth: bool,
246
- tracking_src_vid: bool,
247
- animate_without_vid: bool,
248
- crop_factor: float,
 
249
  src_image_list: Optional[List[np.ndarray]] = None,
250
  driving_images: Optional[List[np.ndarray]] = None,
251
  progress: gr.Progress = gr.Progress()
252
  ):
 
 
 
 
 
253
  src_length = 1
254
 
255
  if src_image_list is not None:
@@ -322,7 +329,14 @@ class LivePortraitInferencer:
322
  return None
323
 
324
  out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
325
- return out_imgs
 
 
 
 
 
 
 
326
 
327
  def download_if_no_models(self,
328
  model_type: str = ModelType.HUMAN.value,
 
14
 
15
  from modules.utils.paths import *
16
  from modules.utils.image_helper import *
17
+ from modules.utils.video_helper import *
18
  from modules.live_portrait.model_downloader import *
19
  from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
20
  from modules.utils.camera import get_rotation_matrix
 
242
  raise
243
 
244
  def create_video(self,
245
+ model_type: str = ModelType.HUMAN.value,
246
+ retargeting_eyes: bool = True,
247
+ retargeting_mouth: bool = True,
248
+ tracking_src_vid: bool = True,
249
+ animate_without_vid: bool = False,
250
+ crop_factor: float = 1.5,
251
  src_image_list: Optional[List[np.ndarray]] = None,
252
  driving_images: Optional[List[np.ndarray]] = None,
253
  progress: gr.Progress = gr.Progress()
254
  ):
255
+ if self.pipeline is None or model_type != self.model_type:
256
+ self.load_models(
257
+ model_type=model_type
258
+ )
259
+
260
  src_length = 1
261
 
262
  if src_image_list is not None:
 
329
  return None
330
 
331
  out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
332
+ out_imgs = [tensor.permute(1, 2, 0).cpu().numpy() for tensor in out_imgs]
333
+ for img in out_imgs:
334
+ out_frame_path = get_auto_incremental_file_path(TEMP_VIDEO_OUT_FRAMES_DIR, "png")
335
+ save_image(img, out_frame_path)
336
+
337
+ video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR)
338
+
339
+ return video_path
340
 
341
  def download_if_no_models(self,
342
  model_type: str = ModelType.HUMAN.value,