pcuenq (HF staff) and radames committed
Commit 4e77c00
1 Parent(s): 98a9e6e

enable live conditioning (#2)

- enable live conditioning (219f54ab9f753d37099cb613ad68703220b44db9)


Co-authored-by: Radamés Ajna <[email protected]>

Files changed (1):
  1. app.py +50 -20
app.py CHANGED

@@ -3,7 +3,9 @@ import torch
 import dlib
 import numpy as np
 import PIL
-
+import base64
+from io import BytesIO
+from PIL import Image
 # Only used to convert to gray, could do it differently and remove this big dependency
 import cv2
 
@@ -35,6 +37,26 @@ pipe = pipe.to("cuda")
 # Generator seed,
 generator = torch.manual_seed(0)
 
+canvas_html = "<face-canvas id='canvas-root' style='display:flex;max-width: 500px;margin: 0 auto;'></face-canvas>"
+load_js = """
+async () => {
+  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
+  fetch(url)
+    .then(res => res.text())
+    .then(text => {
+      const script = document.createElement('script');
+      script.type = "module"
+      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
+      document.head.appendChild(script);
+    });
+}
+"""
+get_js_image = """
+async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
+  const canvasEl = document.getElementById("canvas-root");
+  return [image_in_img, prompt, image_file_live_opt, canvasEl._data]
+}
+"""
 
 def get_bounding_box(image):
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
@@ -119,13 +141,18 @@ def get_conditioning(image):
     return spiga_seg
 
 
-def generate_images(image, prompt, image_video=None):
-    if image is None and image_video is None:
+def generate_images(image_in_img, prompt, image_file_live_opt='file', live_conditioning=None):
+    if image_in_img is None and 'image' not in live_conditioning:
         raise gr.Error("Please provide an image")
-    if image_video is not None:
-        image = image_video
+
+    if image_file_live_opt == 'file':
+        conditioning = get_conditioning(image_in_img)
+    elif image_file_live_opt == 'webcam':
+        base64_img = live_conditioning['image']
+        image_data = base64.b64decode(base64_img.split(',')[1])
+        conditioning = Image.open(BytesIO(image_data)).convert('RGB').resize((512, 512))
+
     try:
-        conditioning = get_conditioning(image)
         output = pipe(
             prompt,
             conditioning,
@@ -139,11 +166,10 @@ def generate_images(image, prompt, image_video=None):
 
 
 def toggle(choice):
-    if choice == "webcam":
+    if choice == "file":
         return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
-    else:
-        return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
-
+    elif choice == "webcam":
+        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
 
 with gr.Blocks() as blocks:
     gr.Markdown("""
@@ -151,15 +177,17 @@ with gr.Blocks() as blocks:
    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
    """)
    with gr.Row():
+        live_conditioning = gr.JSON(value={}, visible=False)
        with gr.Column():
-            image_or_file_opt = gr.Radio(["file", "webcam"], value="file",
+            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
                                          label="How would you like to upload your image?")
-            image_in_video = gr.Image(
-                source="webcam", type="pil", visible=False)
-            image_in_img = gr.Image(
-                source="upload", visible=True, type="pil")
-            image_or_file_opt.change(fn=toggle, inputs=[image_or_file_opt],
-                                     outputs=[image_in_video, image_in_img], queue=False)
+            image_in_img = gr.Image(source="upload", visible=True, type="pil")
+            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)
+
+            image_file_live_opt.change(fn=toggle,
+                                       inputs=[image_file_live_opt],
+                                       outputs=[image_in_img, canvas],
+                                       queue=False)
            prompt = gr.Textbox(
                label="Enter your prompt",
                max_lines=1,
@@ -169,8 +197,10 @@ with gr.Blocks() as blocks:
        with gr.Column():
            gallery = gr.Gallery().style(grid=[2], height="auto")
    run_button.click(fn=generate_images,
-                     inputs=[image_in_img, prompt, image_in_video],
-                     outputs=[gallery])
+                     inputs=[image_in_img, prompt, image_file_live_opt, live_conditioning],
+                     outputs=[gallery],
+                     _js=get_js_image)
+    blocks.load(None, None, None, _js=load_js)
    gr.Examples(fn=generate_images,
                examples=[
                    ["./examples/pedro-512.jpg",
@@ -178,7 +208,7 @@ with gr.Blocks() as blocks:
                    ["./examples/image1.jpg",
                     "Highly detailed photograph of a scary clown"],
                    ["./examples/image0.jpg",
-                     "Highly detailed photograph of Barack Obama"],
+                     "Highly detailed photograph of Madonna"],
                    ],
                inputs=[image_in_img, prompt],
                outputs=[gallery],
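
Note on the webcam branch of generate_images: the face-canvas component stores its latest snapshot as a base64 data URL, which the Python side decodes into the ControlNet conditioning image. A minimal sketch of that decode path, assuming live_conditioning carries {"image": "data:image/png;base64,..."} as in the diff (the helper name is ours, not the app's):

    import base64
    from io import BytesIO
    from PIL import Image

    def decode_data_url(data_url):
        # Drop the "data:image/png;base64," prefix; keep only the payload.
        _, b64_payload = data_url.split(",", 1)
        image_bytes = base64.b64decode(b64_payload)
        # The app resizes the conditioning image to 512x512 RGB before the pipe call.
        return Image.open(BytesIO(image_bytes)).convert("RGB").resize((512, 512))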
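
The wiring around run_button.click depends on Gradio 3.x's _js hook: the JavaScript function runs in the browser first, and the array it returns becomes the inputs the Python fn receives, which is how canvasEl._data ends up in the hidden live_conditioning JSON. A minimal self-contained sketch of the same pattern, with hypothetical component names (only the _js mechanism itself comes from the diff):

    import gradio as gr

    # Runs client-side: replace the hidden JSON value with live browser state.
    grab_data = """
    async (text, payload) => {
      const el = document.getElementById("canvas-root");
      // Fall back to the original payload if the custom element is absent.
      return [text, el ? el._data : payload];
    }
    """

    with gr.Blocks() as demo:
        text = gr.Textbox()
        payload = gr.JSON(value={}, visible=False)
        out = gr.JSON()
        btn = gr.Button("Run")
        # _js (Gradio 3.x) preprocesses the inputs before fn is called.
        btn.click(fn=lambda t, p: {"text": t, "payload": p},
                  inputs=[text, payload], outputs=[out], _js=grab_data)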