guardiancc commited on
Commit
bfcd340
·
verified ·
1 Parent(s): 0fb75b6

Update engine.py

Browse files
Files changed (1) hide show
  1. engine.py +20 -1
engine.py CHANGED
@@ -18,6 +18,8 @@ import tqdm
18
  from tqdm import tqdm as loader
19
 
20
  import cv2
 
 
21
 
22
  from liveportrait.config.argument_config import ArgumentConfig
23
  from liveportrait.utils.camera import get_rotation_matrix
@@ -31,6 +33,21 @@ logger = logging.getLogger(__name__)
31
  # Global constants
32
  DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
33
  MODELS_DIR = os.path.join(DATA_ROOT, "models")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
36
  """
@@ -113,7 +130,9 @@ class Engine:
113
 
114
  _, frame = await self.transform_frame(processed_data, params)
115
  bgr_frame = cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB)
116
- video_writer.write(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
 
 
117
 
118
  video_writer.release()
119
  cap.release()
 
18
  from tqdm import tqdm as loader
19
 
20
  import cv2
21
+ import torch
22
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
23
 
24
  from liveportrait.config.argument_config import ArgumentConfig
25
  from liveportrait.utils.camera import get_rotation_matrix
 
33
  # Global constants
34
  DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
35
  MODELS_DIR = os.path.join(DATA_ROOT, "models")
36
+ os.system("pip freeze")
37
+
38
+ if not os.path.exists('RestoreFormer.pth'):
39
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
40
+
41
+ if not os.path.exists('realesr-general-x4v3.pth'):
42
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
43
+
44
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
45
+ model_path = 'realesr-general-x4v3.pth'
46
+ half = True if torch.cuda.is_available() else False
47
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
48
+
49
+ enhancer = GFPGANer(
50
+ model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
51
 
52
  def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
53
  """
 
130
 
131
  _, frame = await self.transform_frame(processed_data, params)
132
  bgr_frame = cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB)
133
+ new_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
134
+ _, _, output = enhancer.enhance(new_frame, has_aligned=False, only_center_face=False, paste_back=True)
135
+ video_writer.write(output)
136
 
137
  video_writer.release()
138
  cap.release()