Spaces:
Paused
Paused
Update engine.py
Browse files
engine.py
CHANGED
@@ -18,6 +18,8 @@ import tqdm
|
|
18 |
from tqdm import tqdm as loader
|
19 |
|
20 |
import cv2
|
|
|
|
|
21 |
|
22 |
from liveportrait.config.argument_config import ArgumentConfig
|
23 |
from liveportrait.utils.camera import get_rotation_matrix
|
@@ -31,6 +33,21 @@ logger = logging.getLogger(__name__)
|
|
31 |
# Global constants
|
32 |
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
|
33 |
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
|
36 |
"""
|
@@ -113,7 +130,9 @@ class Engine:
|
|
113 |
|
114 |
_, frame = await self.transform_frame(processed_data, params)
|
115 |
bgr_frame = cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB)
|
116 |
-
|
|
|
|
|
117 |
|
118 |
video_writer.release()
|
119 |
cap.release()
|
|
|
18 |
from tqdm import tqdm as loader
|
19 |
|
20 |
import cv2
|
21 |
+
import torch
|
22 |
+
from basicsr.archs.srvgg_arch import SRVGGNetCompact
|
23 |
|
24 |
from liveportrait.config.argument_config import ArgumentConfig
|
25 |
from liveportrait.utils.camera import get_rotation_matrix
|
|
|
33 |
# Global constants
|
34 |
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
|
35 |
MODELS_DIR = os.path.join(DATA_ROOT, "models")
|
36 |
+
os.system("pip freeze")
|
37 |
+
|
38 |
+
if not os.path.exists('RestoreFormer.pth'):
|
39 |
+
os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
|
40 |
+
|
41 |
+
if not os.path.exists('realesr-general-x4v3.pth'):
|
42 |
+
os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
|
43 |
+
|
44 |
+
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
45 |
+
model_path = 'realesr-general-x4v3.pth'
|
46 |
+
half = True if torch.cuda.is_available() else False
|
47 |
+
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
48 |
+
|
49 |
+
enhancer = GFPGANer(
|
50 |
+
model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2, bg_upsampler=upsampler)
|
51 |
|
52 |
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
|
53 |
"""
|
|
|
130 |
|
131 |
_, frame = await self.transform_frame(processed_data, params)
|
132 |
bgr_frame = cv2.cvtColor(np.array(frame), cv2.COLOR_BGR2RGB)
|
133 |
+
new_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
|
134 |
+
_, _, output = enhancer.enhance(new_frame, has_aligned=False, only_center_face=False, paste_back=True)
|
135 |
+
video_writer.write(output)
|
136 |
|
137 |
video_writer.release()
|
138 |
cap.release()
|