jhj0517 commited on
Commit
46fa2af
·
1 Parent(s): 7785d3b

Apply model type enum

Browse files
app.py CHANGED
@@ -22,6 +22,8 @@ class App:
22
  @staticmethod
23
  def create_parameters():
24
  return [
 
 
25
  gr.Slider(label=_("Rotate Pitch"), minimum=-20, maximum=20, step=0.5, value=0),
26
  gr.Slider(label=_("Rotate Yaw"), minimum=-20, maximum=20, step=0.5, value=0),
27
  gr.Slider(label=_("Rotate Roll"), minimum=-20, maximum=20, step=0.5, value=0),
 
22
  @staticmethod
23
  def create_parameters():
24
  return [
25
+ gr.Dropdown(label=_("Model Type"),
26
+ choices=[item.value for item in ModelType], value=ModelType.HUMAN.value),
27
  gr.Slider(label=_("Rotate Pitch"), minimum=-20, maximum=20, step=0.5, value=0),
28
  gr.Slider(label=_("Rotate Yaw"), minimum=-20, maximum=20, step=0.5, value=0),
29
  gr.Slider(label=_("Rotate Roll"), minimum=-20, maximum=20, step=0.5, value=0),
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
  import cv2
3
  import time
4
  import copy
@@ -6,7 +7,10 @@ import dill
6
  from ultralytics import YOLO
7
  import safetensors.torch
8
  import gradio as gr
 
9
  from ultralytics.utils import LOGGER as ultralytics_logger
 
 
10
 
11
  from modules.utils.paths import *
12
  from modules.utils.image_helper import *
@@ -14,6 +18,7 @@ from modules.live_portrait.model_downloader import *
14
  from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
15
  from modules.utils.camera import get_rotation_matrix
16
  from modules.utils.helper import load_yaml
 
17
  from modules.config.inference_config import InferenceConfig
18
  from modules.live_portrait.spade_generator import SPADEDecoder
19
  from modules.live_portrait.warping_network import WarpingNetwork
@@ -27,6 +32,7 @@ class LivePortraitInferencer:
27
  model_dir: str = MODELS_DIR,
28
  output_dir: str = OUTPUTS_DIR):
29
  self.model_dir = model_dir
 
30
  self.output_dir = output_dir
31
  self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
32
 
@@ -38,6 +44,7 @@ class LivePortraitInferencer:
38
  self.pipeline = None
39
  self.detect_model = None
40
  self.device = self.get_device()
 
41
 
42
  self.mask_img = None
43
  self.temp_img_idx = 0
@@ -52,8 +59,22 @@ class LivePortraitInferencer:
52
  self.d_info = None
53
 
54
  def load_models(self,
 
55
  progress=gr.Progress()):
56
- self.download_if_no_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  total_models_num = 5
59
  progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
@@ -61,7 +82,7 @@ class LivePortraitInferencer:
61
  self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
62
  self.appearance_feature_extractor = self.load_safe_tensor(
63
  self.appearance_feature_extractor,
64
- os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
65
  )
66
 
67
  progress(1/total_models_num, desc="Loading Motion Extractor model...")
@@ -69,7 +90,7 @@ class LivePortraitInferencer:
69
  self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
70
  self.motion_extractor = self.load_safe_tensor(
71
  self.motion_extractor,
72
- os.path.join(self.model_dir, "motion_extractor.safetensors")
73
  )
74
 
75
  progress(2/total_models_num, desc="Loading Warping Module model...")
@@ -77,7 +98,7 @@ class LivePortraitInferencer:
77
  self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
78
  self.warping_module = self.load_safe_tensor(
79
  self.warping_module,
80
- os.path.join(self.model_dir, "warping_module.safetensors")
81
  )
82
 
83
  progress(3/total_models_num, desc="Loading Spade generator model...")
@@ -85,7 +106,7 @@ class LivePortraitInferencer:
85
  self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
86
  self.spade_generator = self.load_safe_tensor(
87
  self.spade_generator,
88
- os.path.join(self.model_dir, "spade_generator.safetensors")
89
  )
90
 
91
  progress(4/total_models_num, desc="Loading Stitcher model...")
@@ -93,7 +114,7 @@ class LivePortraitInferencer:
93
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
94
  self.stitching_retargeting_module = self.load_safe_tensor(
95
  self.stitching_retargeting_module,
96
- os.path.join(self.model_dir, "stitching_retargeting_module.safetensors"),
97
  True
98
  )
99
  self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
@@ -111,6 +132,7 @@ class LivePortraitInferencer:
111
  self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"]).to(self.device)
112
 
113
  def edit_expression(self,
 
114
  rotate_pitch=0,
115
  rotate_yaw=0,
116
  rotate_roll=0,
@@ -131,8 +153,15 @@ class LivePortraitInferencer:
131
  sample_image=None,
132
  motion_link=None,
133
  add_exp=None):
134
- if self.pipeline is None:
135
- self.load_models()
 
 
 
 
 
 
 
136
 
137
  try:
138
  rotate_yaw = -rotate_yaw
@@ -330,14 +359,27 @@ class LivePortraitInferencer:
330
  return out_imgs
331
 
332
  def download_if_no_models(self,
333
- progress=gr.Progress()):
 
334
  progress(0, desc="Downloading models...")
335
- for model_name, model_url in MODELS_URL.items():
 
 
 
 
 
 
 
 
 
 
336
  if model_url.endswith(".pt"):
337
  model_name += ".pt"
 
 
338
  else:
339
  model_name += ".safetensors"
340
- model_path = os.path.join(self.model_dir, model_name)
341
  if not os.path.exists(model_path):
342
  download_model(model_path, model_url)
343
 
@@ -779,3 +821,4 @@ class Command:
779
  self.es = es
780
  self.change = change
781
  self.keep = keep
 
 
1
  import logging
2
+ import os
3
  import cv2
4
  import time
5
  import copy
 
7
  from ultralytics import YOLO
8
  import safetensors.torch
9
  import gradio as gr
10
+ from gradio_i18n import Translate, gettext as _
11
  from ultralytics.utils import LOGGER as ultralytics_logger
12
+ from enum import Enum
13
+ from typing import Union
14
 
15
  from modules.utils.paths import *
16
  from modules.utils.image_helper import *
 
18
  from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
19
  from modules.utils.camera import get_rotation_matrix
20
  from modules.utils.helper import load_yaml
21
+ from modules.utils.constants import *
22
  from modules.config.inference_config import InferenceConfig
23
  from modules.live_portrait.spade_generator import SPADEDecoder
24
  from modules.live_portrait.warping_network import WarpingNetwork
 
32
  model_dir: str = MODELS_DIR,
33
  output_dir: str = OUTPUTS_DIR):
34
  self.model_dir = model_dir
35
+ os.makedirs(os.path.join(self.model_dir, "animal"), exist_ok=True)
36
  self.output_dir = output_dir
37
  self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
38
 
 
44
  self.pipeline = None
45
  self.detect_model = None
46
  self.device = self.get_device()
47
+ self.model_type = ModelType.HUMAN.value
48
 
49
  self.mask_img = None
50
  self.temp_img_idx = 0
 
59
  self.d_info = None
60
 
61
  def load_models(self,
62
+ model_type: str = ModelType.HUMAN.value,
63
  progress=gr.Progress()):
64
+ if isinstance(model_type, ModelType):
65
+ model_type = model_type.value
66
+ if model_type not in [mode.value for mode in ModelType]:
67
+ model_type = ModelType.HUMAN.value
68
+
69
+ self.model_type = model_type
70
+ if model_type == ModelType.ANIMAL.value:
71
+ model_dir = os.path.join(self.model_dir, "animal")
72
+ else:
73
+ model_dir = self.model_dir
74
+
75
+ self.download_if_no_models(
76
+ model_type=model_type
77
+ )
78
 
79
  total_models_num = 5
80
  progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
 
82
  self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
83
  self.appearance_feature_extractor = self.load_safe_tensor(
84
  self.appearance_feature_extractor,
85
+ os.path.join(model_dir, "appearance_feature_extractor.safetensors")
86
  )
87
 
88
  progress(1/total_models_num, desc="Loading Motion Extractor model...")
 
90
  self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
91
  self.motion_extractor = self.load_safe_tensor(
92
  self.motion_extractor,
93
+ os.path.join(model_dir, "motion_extractor.safetensors")
94
  )
95
 
96
  progress(2/total_models_num, desc="Loading Warping Module model...")
 
98
  self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
99
  self.warping_module = self.load_safe_tensor(
100
  self.warping_module,
101
+ os.path.join(model_dir, "warping_module.safetensors")
102
  )
103
 
104
  progress(3/total_models_num, desc="Loading Spade generator model...")
 
106
  self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
107
  self.spade_generator = self.load_safe_tensor(
108
  self.spade_generator,
109
+ os.path.join(model_dir, "spade_generator.safetensors")
110
  )
111
 
112
  progress(4/total_models_num, desc="Loading Stitcher model...")
 
114
  self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
115
  self.stitching_retargeting_module = self.load_safe_tensor(
116
  self.stitching_retargeting_module,
117
+ os.path.join(model_dir, "stitching_retargeting_module.safetensors"),
118
  True
119
  )
120
  self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
 
132
  self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"]).to(self.device)
133
 
134
  def edit_expression(self,
135
+ model_type: str = ModelType.HUMAN.value,
136
  rotate_pitch=0,
137
  rotate_yaw=0,
138
  rotate_roll=0,
 
153
  sample_image=None,
154
  motion_link=None,
155
  add_exp=None):
156
+ if isinstance(model_type, ModelType):
157
+ model_type = model_type.value
158
+ if model_type not in [mode.value for mode in ModelType]:
159
+ model_type = ModelType.HUMAN
160
+
161
+ if self.pipeline is None or model_type != self.model_type:
162
+ self.load_models(
163
+ model_type=model_type
164
+ )
165
 
166
  try:
167
  rotate_yaw = -rotate_yaw
 
359
  return out_imgs
360
 
361
  def download_if_no_models(self,
362
+ model_type: str = ModelType.HUMAN.value,
363
+ progress=gr.Progress(), ):
364
  progress(0, desc="Downloading models...")
365
+
366
+ if isinstance(model_type, ModelType):
367
+ model_type = model_type.value
368
+ if model_type == ModelType.ANIMAL.value:
369
+ models_urls_dic = MODELS_ANIMAL_URL
370
+ model_dir = os.path.join(self.model_dir, "animal")
371
+ else:
372
+ models_urls_dic = MODELS_URL
373
+ model_dir = self.model_dir
374
+
375
+ for model_name, model_url in models_urls_dic.items():
376
  if model_url.endswith(".pt"):
377
  model_name += ".pt"
378
+ # Exception for face_yolov8n.pt
379
+ model_dir = self.model_dir
380
  else:
381
  model_name += ".safetensors"
382
+ model_path = os.path.join(model_dir, model_name)
383
  if not os.path.exists(model_path):
384
  download_model(model_path, model_url)
385
 
 
821
  self.es = es
822
  self.change = change
823
  self.keep = keep
824
+