Jonny001 commited on
Commit
fb94566
1 Parent(s): 19136f3

Update roop/processors/frame/face_enhancer.py

Browse files
roop/processors/frame/face_enhancer.py CHANGED
@@ -2,6 +2,8 @@ from typing import Any, List, Callable
2
  import cv2
3
  import threading
4
  import gfpgan
 
 
5
 
6
  import roop.globals
7
  import roop.processors.frame.core
@@ -15,6 +17,8 @@ THREAD_SEMAPHORE = threading.Semaphore()
15
  THREAD_LOCK = threading.Lock()
16
  NAME = 'ROOP.FACE-ENHANCER'
17
 
 
 
18
 
19
  def get_face_enhancer() -> Any:
20
  global FACE_ENHANCER
@@ -22,60 +26,93 @@ def get_face_enhancer() -> Any:
22
  with THREAD_LOCK:
23
  if FACE_ENHANCER is None:
24
  model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
25
- # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
26
- FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=5) # type: ignore[attr-defined]
 
 
 
 
27
  return FACE_ENHANCER
28
 
29
-
30
  def pre_check() -> bool:
31
  download_directory_path = resolve_relative_path('../models')
32
- conditional_download(download_directory_path, ['https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth'])
33
- return True
34
-
 
 
 
 
35
 
36
  def pre_start() -> bool:
37
- if not is_image(roop.globals.target_path) and not is_video(roop.globals.target_path):
38
- update_status('Select an image or video for target path.', NAME)
 
 
 
 
 
 
39
  return False
40
- return True
41
-
42
 
43
  def post_process() -> None:
44
  global FACE_ENHANCER
45
 
46
  FACE_ENHANCER = None
47
-
48
 
49
  def enhance_face(temp_frame: Frame) -> Frame:
50
- with THREAD_SEMAPHORE:
51
- _, _, temp_frame = get_face_enhancer().enhance(
52
- temp_frame,
53
- paste_back=True
54
- )
55
- return temp_frame
56
-
 
 
 
57
 
58
  def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
59
- target_face = get_one_face(temp_frame)
60
- if target_face:
61
- temp_frame = enhance_face(temp_frame)
62
- return temp_frame
63
-
 
 
 
64
 
65
  def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
66
- for temp_frame_path in temp_frame_paths:
67
- temp_frame = cv2.imread(temp_frame_path)
68
- result = process_frame(None, temp_frame)
69
- cv2.imwrite(temp_frame_path, result)
70
- if update:
71
- update()
72
-
 
 
 
 
 
 
73
 
74
  def process_image(source_path: str, target_path: str, output_path: str) -> None:
75
- target_frame = cv2.imread(target_path)
76
- result = process_frame(None, target_frame)
77
- cv2.imwrite(output_path, result)
78
-
 
 
 
 
 
 
79
 
80
  def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
81
- roop.processors.frame.core.process_video(None, temp_frame_paths, process_frames)
 
 
 
 
 
2
  import cv2
3
  import threading
4
  import gfpgan
5
+ import os
6
+ import logging
7
 
8
  import roop.globals
9
  import roop.processors.frame.core
 
17
  THREAD_LOCK = threading.Lock()
18
  NAME = 'ROOP.FACE-ENHANCER'
19
 
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
 
23
  def get_face_enhancer() -> Any:
24
  global FACE_ENHANCER
 
26
  with THREAD_LOCK:
27
  if FACE_ENHANCER is None:
28
  model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
29
+ try:
30
+ FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=5) # type: ignore[attr-defined]
31
+ logging.info(f"Loaded face enhancer model from {model_path}")
32
+ except Exception as e:
33
+ logging.error(f"Failed to load face enhancer model: {e}")
34
+ FACE_ENHANCER = None
35
  return FACE_ENHANCER
36
 
 
37
  def pre_check() -> bool:
38
  download_directory_path = resolve_relative_path('../models')
39
+ try:
40
+ conditional_download(download_directory_path, ['https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth'])
41
+ logging.info("Pre-check completed successfully.")
42
+ return True
43
+ except Exception as e:
44
+ logging.error(f"Pre-check failed: {e}")
45
+ return False
46
 
47
  def pre_start() -> bool:
48
+ try:
49
+ if not is_image(roop.globals.target_path) and not is_video(roop.globals.target_path):
50
+ update_status('Select an image or video for target path.', NAME)
51
+ return False
52
+ logging.info("Pre-start checks passed.")
53
+ return True
54
+ except Exception as e:
55
+ logging.error(f"Pre-start check failed: {e}")
56
  return False
 
 
57
 
58
  def post_process() -> None:
59
  global FACE_ENHANCER
60
 
61
  FACE_ENHANCER = None
62
+ logging.info("Post-process cleanup done.")
63
 
64
  def enhance_face(temp_frame: Frame) -> Frame:
65
+ try:
66
+ with THREAD_SEMAPHORE:
67
+ _, _, temp_frame = get_face_enhancer().enhance(
68
+ temp_frame,
69
+ paste_back=True
70
+ )
71
+ return temp_frame
72
+ except Exception as e:
73
+ logging.error(f"Error enhancing face: {e}")
74
+ return temp_frame # Return the unmodified frame in case of error
75
 
76
  def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
77
+ try:
78
+ target_face = get_one_face(temp_frame)
79
+ if target_face:
80
+ temp_frame = enhance_face(temp_frame)
81
+ return temp_frame
82
+ except Exception as e:
83
+ logging.error(f"Error processing frame: {e}")
84
+ return temp_frame # Return the unmodified frame in case of error
85
 
86
  def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
87
+ try:
88
+ for temp_frame_path in temp_frame_paths:
89
+ temp_frame = cv2.imread(temp_frame_path)
90
+ if temp_frame is None:
91
+ raise ValueError(f"Failed to read frame from path: {temp_frame_path}")
92
+
93
+ result = process_frame(None, temp_frame)
94
+ cv2.imwrite(temp_frame_path, result)
95
+ if update:
96
+ update()
97
+ logging.info("Frames processed successfully.")
98
+ except Exception as e:
99
+ logging.error(f"Error processing frames: {e}")
100
 
101
  def process_image(source_path: str, target_path: str, output_path: str) -> None:
102
+ try:
103
+ target_frame = cv2.imread(target_path)
104
+ if target_frame is None:
105
+ raise ValueError("Failed to read target frame.")
106
+
107
+ result = process_frame(None, target_frame)
108
+ cv2.imwrite(output_path, result)
109
+ logging.info(f"Image processed and saved to {output_path}.")
110
+ except Exception as e:
111
+ logging.error(f"Error processing image: {e}")
112
 
113
  def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
114
+ try:
115
+ roop.processors.frame.core.process_video(None, temp_frame_paths, process_frames)
116
+ logging.info("Video processing completed.")
117
+ except Exception as e:
118
+ logging.error(f"Error processing video: {e}")