pcuenq HF staff radames commited on
Commit
556fb16
·
1 Parent(s): 46aba97

optional - Use retinaface for face detection (#7)

Browse files

- optional - Use retinaface for face detection (c57f7aedf4dedc338af570caaecc4cee74d7f2bc)
- remove print (478af14d5268d7cec5efefcc368bee4c9e99b9f7)


Co-authored-by: Radamés Ajna <[email protected]>

Files changed (2) hide show
  1. app.py +29 -19
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,27 +1,30 @@
1
  import gradio as gr
2
  import torch
3
- import dlib
4
  import numpy as np
5
  import PIL
6
  import base64
7
  from io import BytesIO
8
  from PIL import Image
9
- # Only used to convert to gray, could do it differently and remove this big dependency
10
- import cv2
11
 
12
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
13
  from diffusers import UniPCMultistepScheduler
14
 
15
  from spiga.inference.config import ModelConfig
16
  from spiga.inference.framework import SPIGAFramework
 
17
 
18
  import matplotlib.pyplot as plt
19
  from matplotlib.path import Path
20
  import matplotlib.patches as patches
21
 
22
  # Bounding boxes
23
- face_detector = dlib.get_frontal_face_detector()
24
-
 
 
 
25
  # Landmark extraction
26
  spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))
27
 
@@ -59,14 +62,19 @@ async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
59
  }
60
  """
61
 
 
62
  def get_bounding_box(image):
63
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
64
- faces = face_detector(gray)
65
- if len(faces) == 0:
66
- raise Exception("No face detected in image")
67
- face = faces[0]
68
- bbox = [face.left(), face.top(), face.width(), face.height()]
69
- return bbox
 
 
 
 
70
 
71
 
72
  def get_landmarks(image, bbox):
@@ -145,7 +153,6 @@ def get_conditioning(image):
145
  def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
146
  if image_in_img is None and 'image' not in live_conditioning:
147
  raise gr.Error("Please provide an image")
148
-
149
  try:
150
  if image_file_live_opt == 'file':
151
  conditioning = get_conditioning(image_in_img)
@@ -166,29 +173,31 @@ def generate_images(image_in_img, prompt, image_file_live_opt='file', live_condi
166
  except Exception as e:
167
  raise gr.Error(str(e))
168
 
 
169
  def toggle(choice):
170
  if choice == "file":
171
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
172
  elif choice == "webcam":
173
  return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
174
 
 
175
  with gr.Blocks() as blocks:
176
  gr.Markdown("""
177
  ## Generate Uncanny Faces with ControlNet Stable Diffusion
178
  [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
179
  """)
180
  with gr.Row():
181
- live_conditioning = gr.JSON(value={}, visible=False)
182
  with gr.Column():
183
  image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
184
- label="How would you like to upload your image?")
185
  image_in_img = gr.Image(source="upload", visible=True, type="pil")
186
  canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
187
 
188
  image_file_live_opt.change(fn=toggle,
189
- inputs=[image_file_live_opt],
190
- outputs=[image_in_img, canvas],
191
- queue=False)
192
  prompt = gr.Textbox(
193
  label="Enter your prompt",
194
  max_lines=1,
@@ -198,7 +207,8 @@ with gr.Blocks() as blocks:
198
  with gr.Column():
199
  gallery = gr.Gallery().style(grid=[2], height="auto")
200
  run_button.click(fn=generate_images,
201
- inputs=[image_in_img, prompt, image_file_live_opt, live_conditioning],
 
202
  outputs=[gallery],
203
  _js=get_js_image)
204
  blocks.load(None, None, None, _js=load_js)
 
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  import PIL
5
  import base64
6
  from io import BytesIO
7
  from PIL import Image
8
+ # import for face detection
9
+ import retinaface
10
 
11
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
12
  from diffusers import UniPCMultistepScheduler
13
 
14
  from spiga.inference.config import ModelConfig
15
  from spiga.inference.framework import SPIGAFramework
16
+ import spiga.demo.analyze.track.retinasort.config as cfg
17
 
18
  import matplotlib.pyplot as plt
19
  from matplotlib.path import Path
20
  import matplotlib.patches as patches
21
 
22
  # Bounding boxes
23
+ config = cfg.cfg_retinasort
24
+ face_detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name'],
25
+ device='cuda' if torch.cuda.is_available() else 'cpu',
26
+ extra_features=config['retina']['extra_features'],
27
+ cfg_postreat=config['retina']['postreat'])
28
  # Landmark extraction
29
  spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))
30
 
 
62
  }
63
  """
64
 
65
+
66
  def get_bounding_box(image):
67
+ pil_image = Image.fromarray(image)
68
+ face_detector.set_input_shape(pil_image.size[1], pil_image.size[0])
69
+ features = face_detector.inference(pil_image)
70
+
71
+ if (features is None) and (len(features['bbox']) <= 0):
72
+ raise Exception("No face detected")
73
+ # get the first face detected
74
+ bbox = features['bbox'][0]
75
+ x1, y1, x2, y2 = bbox[:4]
76
+ bbox_wh = [x1, y1, x2-x1, y2-y1]
77
+ return bbox_wh
78
 
79
 
80
  def get_landmarks(image, bbox):
 
153
  def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
154
  if image_in_img is None and 'image' not in live_conditioning:
155
  raise gr.Error("Please provide an image")
 
156
  try:
157
  if image_file_live_opt == 'file':
158
  conditioning = get_conditioning(image_in_img)
 
173
  except Exception as e:
174
  raise gr.Error(str(e))
175
 
176
+
177
  def toggle(choice):
178
  if choice == "file":
179
  return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
180
  elif choice == "webcam":
181
  return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
182
 
183
+
184
  with gr.Blocks() as blocks:
185
  gr.Markdown("""
186
  ## Generate Uncanny Faces with ControlNet Stable Diffusion
187
  [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
188
  """)
189
  with gr.Row():
190
+ live_conditioning = gr.JSON(value={}, visible=False)
191
  with gr.Column():
192
  image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
193
+ label="How would you like to upload your image?")
194
  image_in_img = gr.Image(source="upload", visible=True, type="pil")
195
  canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
196
 
197
  image_file_live_opt.change(fn=toggle,
198
+ inputs=[image_file_live_opt],
199
+ outputs=[image_in_img, canvas],
200
+ queue=False)
201
  prompt = gr.Textbox(
202
  label="Enter your prompt",
203
  max_lines=1,
 
207
  with gr.Column():
208
  gallery = gr.Gallery().style(grid=[2], height="auto")
209
  run_button.click(fn=generate_images,
210
+ inputs=[image_in_img, prompt,
211
+ image_file_live_opt, live_conditioning],
212
  outputs=[gallery],
213
  _js=get_js_image)
214
  blocks.load(None, None, None, _js=load_js)
requirements.txt CHANGED
@@ -7,3 +7,4 @@ dlib
7
  opencv-python
8
  matplotlib
9
  Pillow
 
 
7
  opencv-python
8
  matplotlib
9
  Pillow
10
+ retinaface-py>=0.0.2