jhj0517 committed
Commit · 7f7cda2 · 1 Parent(s): f0d2b3d
Refactor to create videos
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -14,6 +14,7 @@ from typing import Union, List, Dict, Tuple
 
 from modules.utils.paths import *
 from modules.utils.image_helper import *
+from modules.utils.video_helper import *
 from modules.live_portrait.model_downloader import *
 from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
 from modules.utils.camera import get_rotation_matrix
@@ -241,15 +242,21 @@ class LivePortraitInferencer:
             raise
 
     def create_video(self,
-
-
-
-
-
+                     model_type: str = ModelType.HUMAN.value,
+                     retargeting_eyes: bool = True,
+                     retargeting_mouth: bool = True,
+                     tracking_src_vid: bool = True,
+                     animate_without_vid: bool = False,
+                     crop_factor: float = 1.5,
                      src_image_list: Optional[List[np.ndarray]] = None,
                      driving_images: Optional[List[np.ndarray]] = None,
                      progress: gr.Progress = gr.Progress()
                      ):
+        if self.pipeline is None or model_type != self.model_type:
+            self.load_models(
+                model_type=model_type
+            )
+
         src_length = 1
 
         if src_image_list is not None:
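The six new keyword arguments expose the retargeting and tracking switches at the create_video call site, and the guard at the top reloads weights only when no pipeline is loaded yet or a different model_type is requested. A hypothetical call, assuming ModelType.HUMAN.value is the string "human" and that LivePortraitInferencer can be constructed without arguments (neither is shown in this diff):

# Hypothetical call site for the create_video() signature added in this
# commit. The no-argument constructor and the "human" literal are
# assumptions for illustration, not taken from this diff.
import numpy as np
from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer

inferencer = LivePortraitInferencer()

src_image_list = [np.zeros((512, 512, 3), dtype=np.uint8)]   # placeholder source image
driving_images = [np.zeros((512, 512, 3), dtype=np.uint8)]   # placeholder driving frames

video_path = inferencer.create_video(
    model_type="human",            # assumed value of ModelType.HUMAN.value
    retargeting_eyes=True,
    retargeting_mouth=True,
    tracking_src_vid=True,
    animate_without_vid=False,
    crop_factor=1.5,
    src_image_list=src_image_list,
    driving_images=driving_images,
)
print(video_path)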
@@ -322,7 +329,14 @@ class LivePortraitInferencer:
             return None
 
         out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
-
+        out_imgs = [tensor.permute(1, 2, 0).cpu().numpy() for tensor in out_imgs]
+        for img in out_imgs:
+            out_frame_path = get_auto_incremental_file_path(TEMP_VIDEO_OUT_FRAMES_DIR, "png")
+            save_image(img, out_frame_path)
+
+        video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR)
+
+        return video_path
 
     def download_if_no_models(self,
                               model_type: str = ModelType.HUMAN.value,
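The permute(1, 2, 0) call reorders each output tensor from CHW to HWC before converting it to a numpy array, so every frame can be written out as a PNG; the frames are then stitched into a video by create_video_from_frames from the newly imported modules/utils/video_helper, whose implementation is not part of this commit. A minimal sketch of what such a helper could look like, assuming the temp directory holds only sequentially numbered PNG frames and using OpenCV's VideoWriter:

# A minimal sketch of what a helper like create_video_from_frames() could
# do; the real implementation lives in modules/utils/video_helper and is
# not shown in this diff. The fps default is an assumption.
import os
import cv2

def create_video_from_frames(frames_dir: str, fps: int = 25) -> str:
    frames = sorted(f for f in os.listdir(frames_dir) if f.endswith(".png"))
    if not frames:
        raise ValueError(f"No PNG frames found in {frames_dir}")

    # Read the first frame to fix the output resolution.
    first = cv2.imread(os.path.join(frames_dir, frames[0]))
    height, width = first.shape[:2]

    out_path = os.path.join(frames_dir, "output.mp4")
    writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )
    for name in frames:
        writer.write(cv2.imread(os.path.join(frames_dir, name)))
    writer.release()
    return out_path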