Ii commited on
Commit
dcca02f
·
verified ·
1 Parent(s): 68d4f08

Update refacer.py

Browse files
Files changed (1) hide show
  1. refacer.py +35 -35
refacer.py CHANGED
@@ -24,20 +24,20 @@ import re
24
  import subprocess
25
 
26
  class RefacerMode(Enum):
27
- CPU, CUDA, COREML, TENSORRT = range(1, 5)
28
 
29
  class Refacer:
30
- def __init__(self,force_cpu=False,colab_performance=False):
31
  self.first_face = False
32
  self.force_cpu = force_cpu
33
  self.colab_performance = colab_performance
34
- self.__check_encoders()
35
  self.__check_providers()
36
  self.total_mem = psutil.virtual_memory().total
37
  self.__init_apps()
38
 
39
  def __check_providers(self):
40
- if self.force_cpu :
41
  self.providers = ['CPUExecutionProvider']
42
  else:
43
  self.providers = rt.get_available_providers()
@@ -48,18 +48,18 @@ class Refacer:
48
 
49
  if len(self.providers) == 1 and 'CPUExecutionProvider' in self.providers:
50
  self.mode = RefacerMode.CPU
51
- self.use_num_cpus = mp.cpu_count()-1
52
- self.sess_options.intra_op_num_threads = int(self.use_num_cpus/3)
53
  print(f"CPU mode with providers {self.providers}")
54
  elif self.colab_performance:
55
  self.mode = RefacerMode.TENSORRT
56
- self.use_num_cpus = mp.cpu_count()-1
57
- self.sess_options.intra_op_num_threads = int(self.use_num_cpus/3)
58
  print(f"TENSORRT mode with providers {self.providers}")
59
  elif 'CoreMLExecutionProvider' in self.providers:
60
  self.mode = RefacerMode.COREML
61
- self.use_num_cpus = mp.cpu_count()-1
62
- self.sess_options.intra_op_num_threads = int(self.use_num_cpus/3)
63
  print(f"CoreML mode with providers {self.providers}")
64
  elif 'CUDAExecutionProvider' in self.providers:
65
  self.mode = RefacerMode.CUDA
@@ -74,25 +74,25 @@ class Refacer:
74
 
75
  model_path = os.path.join(assets_dir, 'det_10g.onnx')
76
  sess_face = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
77
- self.face_detector = SCRFD(model_path,sess_face)
78
- self.face_detector.prepare(0,input_size=(640, 640))
79
 
80
- model_path = os.path.join(assets_dir , 'w600k_r50.onnx')
81
  sess_rec = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
82
- self.rec_app = ArcFaceONNX(model_path,sess_rec)
83
  self.rec_app.prepare(0)
84
 
85
  model_path = 'inswapper_128.onnx'
86
  sess_swap = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
87
- self.face_swapper = INSwapper(model_path,sess_swap)
88
 
89
  def prepare_faces(self, faces):
90
- self.replacement_faces=[]
91
  for face in faces:
92
  if "origin" in face:
93
  face_threshold = face['threshold']
94
- bboxes1, kpss1 = self.face_detector.autodetect(face['origin'], max_num=1)
95
- if len(kpss1)<1:
96
  raise Exception('No face detected on "Face to replace" image')
97
  feat_original = self.rec_app.get(face['origin'], kpss1[0])
98
  else:
@@ -100,13 +100,13 @@ class Refacer:
100
  self.first_face = True
101
  feat_original = None
102
  print('No origin image: First face change')
103
- _faces = self.__get_faces(face['destination'],max_num=1)
104
- if len(_faces)<1:
105
  raise Exception('No face detected on "Destination face" image')
106
- self.replacement_faces.append((feat_original,_faces[0],face_threshold))
107
 
108
- def __get_faces(self,frame,max_num=0):
109
- bboxes, kpss = self.face_detector.detect(frame,max_num=max_num,metric='default')
110
 
111
  if bboxes.shape[0] == 0:
112
  return []
@@ -122,24 +122,24 @@ class Refacer:
122
  ret.append(face)
123
  return ret
124
 
125
- def process_first_face(self,frame):
126
- faces = self.__get_faces(frame,max_num=1)
127
  if len(faces) != 0:
128
  frame = self.face_swapper.get(frame, faces[0], self.replacement_faces[0][1], paste_back=True)
129
  return frame
130
 
131
- def process_faces(self,frame):
132
- faces = self.__get_faces(frame,max_num=0)
133
  for rep_face in self.replacement_faces:
134
  for i in range(len(faces) - 1, -1, -1):
135
  sim = self.rec_app.compute_sim(rep_face[0], faces[i].embedding)
136
- if sim>=rep_face[2]:
137
  frame = self.face_swapper.get(frame, faces[i], rep_face[1], paste_back=True)
138
  del faces[i]
139
  break
140
  return frame
141
 
142
- def __check_video_has_audio(self,video_path):
143
  self.video_has_audio = False
144
  probe = ffmpeg.probe(video_path)
145
  audio_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
@@ -148,11 +148,11 @@ class Refacer:
148
 
149
  def reface_group(self, faces, frames):
150
  results = []
151
- with ThreadPoolExecutor(max_workers = self.use_num_cpus) as executor:
152
  if self.first_face:
153
- results = list(tqdm(executor.map(self.process_first_face, frames), total=len(frames),desc="Processing frames"))
154
  else:
155
- results = list(tqdm(executor.map(self.process_faces, frames), total=len(frames),desc="Processing frames"))
156
  return results
157
 
158
  def reface(self, video_path, faces):
@@ -162,13 +162,13 @@ class Refacer:
162
  cap = cv2.VideoCapture(video_path)
163
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
164
  print(f"Total frames: {total_frames}")
165
-
166
  fps = cap.get(cv2.CAP_PROP_FPS)
167
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
168
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
169
 
170
  frames = []
171
- with tqdm(total=total_frames,desc="Extracting frames") as pbar:
172
  while cap.isOpened():
173
  flag, frame = cap.read()
174
  if flag and len(frame) > 0:
@@ -181,7 +181,7 @@ class Refacer:
181
  pbar.close()
182
 
183
  refaced_frames = self.reface_group(faces, frames)
184
-
185
  video_buffer = io.BytesIO()
186
  out = cv2.VideoWriter('temp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
187
 
 
24
  import subprocess
25
 
26
  class RefacerMode(Enum):
27
+ CPU, CUDA, COREML, TENSORRT = range(1, 5)
28
 
29
  class Refacer:
30
+ def __init__(self, force_cpu=False, colab_performance=False):
31
  self.first_face = False
32
  self.force_cpu = force_cpu
33
  self.colab_performance = colab_performance
34
+ self.__check_encoders() # Correct method name
35
  self.__check_providers()
36
  self.total_mem = psutil.virtual_memory().total
37
  self.__init_apps()
38
 
39
  def __check_providers(self):
40
+ if self.force_cpu:
41
  self.providers = ['CPUExecutionProvider']
42
  else:
43
  self.providers = rt.get_available_providers()
 
48
 
49
  if len(self.providers) == 1 and 'CPUExecutionProvider' in self.providers:
50
  self.mode = RefacerMode.CPU
51
+ self.use_num_cpus = mp.cpu_count() - 1
52
+ self.sess_options.intra_op_num_threads = int(self.use_num_cpus / 3)
53
  print(f"CPU mode with providers {self.providers}")
54
  elif self.colab_performance:
55
  self.mode = RefacerMode.TENSORRT
56
+ self.use_num_cpus = mp.cpu_count() - 1
57
+ self.sess_options.intra_op_num_threads = int(self.use_num_cpus / 3)
58
  print(f"TENSORRT mode with providers {self.providers}")
59
  elif 'CoreMLExecutionProvider' in self.providers:
60
  self.mode = RefacerMode.COREML
61
+ self.use_num_cpus = mp.cpu_count() - 1
62
+ self.sess_options.intra_op_num_threads = int(self.use_num_cpus / 3)
63
  print(f"CoreML mode with providers {self.providers}")
64
  elif 'CUDAExecutionProvider' in self.providers:
65
  self.mode = RefacerMode.CUDA
 
74
 
75
  model_path = os.path.join(assets_dir, 'det_10g.onnx')
76
  sess_face = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
77
+ self.face_detector = SCRFD(model_path, sess_face)
78
+ self.face_detector.prepare(0, input_size=(640, 640))
79
 
80
+ model_path = os.path.join(assets_dir, 'w600k_r50.onnx')
81
  sess_rec = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
82
+ self.rec_app = ArcFaceONNX(model_path, sess_rec)
83
  self.rec_app.prepare(0)
84
 
85
  model_path = 'inswapper_128.onnx'
86
  sess_swap = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
87
+ self.face_swapper = INSwapper(model_path, sess_swap)
88
 
89
  def prepare_faces(self, faces):
90
+ self.replacement_faces = []
91
  for face in faces:
92
  if "origin" in face:
93
  face_threshold = face['threshold']
94
+ bboxes1, kpss1 = self.face_detector.autodetect(face['origin'], max_num=1)
95
+ if len(kpss1) < 1:
96
  raise Exception('No face detected on "Face to replace" image')
97
  feat_original = self.rec_app.get(face['origin'], kpss1[0])
98
  else:
 
100
  self.first_face = True
101
  feat_original = None
102
  print('No origin image: First face change')
103
+ _faces = self.__get_faces(face['destination'], max_num=1)
104
+ if len(_faces) < 1:
105
  raise Exception('No face detected on "Destination face" image')
106
+ self.replacement_faces.append((feat_original, _faces[0], face_threshold))
107
 
108
+ def __get_faces(self, frame, max_num=0):
109
+ bboxes, kpss = self.face_detector.detect(frame, max_num=max_num, metric='default')
110
 
111
  if bboxes.shape[0] == 0:
112
  return []
 
122
  ret.append(face)
123
  return ret
124
 
125
+ def process_first_face(self, frame):
126
+ faces = self.__get_faces(frame, max_num=1)
127
  if len(faces) != 0:
128
  frame = self.face_swapper.get(frame, faces[0], self.replacement_faces[0][1], paste_back=True)
129
  return frame
130
 
131
+ def process_faces(self, frame):
132
+ faces = self.__get_faces(frame, max_num=0)
133
  for rep_face in self.replacement_faces:
134
  for i in range(len(faces) - 1, -1, -1):
135
  sim = self.rec_app.compute_sim(rep_face[0], faces[i].embedding)
136
+ if sim >= rep_face[2]:
137
  frame = self.face_swapper.get(frame, faces[i], rep_face[1], paste_back=True)
138
  del faces[i]
139
  break
140
  return frame
141
 
142
+ def __check_video_has_audio(self, video_path):
143
  self.video_has_audio = False
144
  probe = ffmpeg.probe(video_path)
145
  audio_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'audio'), None)
 
148
 
149
  def reface_group(self, faces, frames):
150
  results = []
151
+ with ThreadPoolExecutor(max_workers=self.use_num_cpus) as executor:
152
  if self.first_face:
153
+ results = list(tqdm(executor.map(self.process_first_face, frames), total=len(frames), desc="Processing frames"))
154
  else:
155
+ results = list(tqdm(executor.map(self.process_faces, frames), total=len(frames), desc="Processing frames"))
156
  return results
157
 
158
  def reface(self, video_path, faces):
 
162
  cap = cv2.VideoCapture(video_path)
163
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
164
  print(f"Total frames: {total_frames}")
165
+
166
  fps = cap.get(cv2.CAP_PROP_FPS)
167
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
168
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
169
 
170
  frames = []
171
+ with tqdm(total=total_frames, desc="Extracting frames") as pbar:
172
  while cap.isOpened():
173
  flag, frame = cap.read()
174
  if flag and len(frame) > 0:
 
181
  pbar.close()
182
 
183
  refaced_frames = self.reface_group(faces, frames)
184
+
185
  video_buffer = io.BytesIO()
186
  out = cv2.VideoWriter('temp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
187