mishig HF staff commited on
Commit
310e819
·
1 Parent(s): 7225d85
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -1,11 +1,17 @@
1
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
2
  from diffusers import UniPCMultistepScheduler
3
  import gradio as gr
 
4
  import torch
5
  import base64
 
6
  from io import BytesIO
7
  from PIL import Image, ImageFilter
8
 
 
 
 
 
9
  canvas_html = '<pose-maker/>'
10
  load_js = """
11
  async () => {
@@ -31,7 +37,7 @@ async (canvas, prompt) => {
31
 
32
  # Models
33
  controlnet = ControlNetModel.from_pretrained(
34
- "lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16
35
  )
36
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
37
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
@@ -48,6 +54,15 @@ pipe.enable_xformers_memory_efficient_attention()
48
  # Generator seed,
49
  generator = torch.manual_seed(0)
50
 
 
 
 
 
 
 
 
 
 
51
 
52
  def generate_images(canvas, prompt):
53
  try:
@@ -56,6 +71,7 @@ def generate_images(canvas, prompt):
56
  input_img = Image.open(BytesIO(image_data)).convert(
57
  'RGB').resize((512, 512))
58
  input_img = input_img.filter(ImageFilter.GaussianBlur(radius=5))
 
59
  output = pipe(
60
  prompt,
61
  input_img,
 
1
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
2
  from diffusers import UniPCMultistepScheduler
3
  import gradio as gr
4
+ import numpy as np
5
  import torch
6
  import base64
7
+ import cv2
8
  from io import BytesIO
9
  from PIL import Image, ImageFilter
10
 
11
+ # Constants
12
+ low_threshold = 100
13
+ high_threshold = 200
14
+
15
  canvas_html = '<pose-maker/>'
16
  load_js = """
17
  async () => {
 
37
 
38
  # Models
39
  controlnet = ControlNetModel.from_pretrained(
40
+ "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
41
  )
42
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
43
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
 
54
  # Generator seed,
55
  generator = torch.manual_seed(0)
56
 
57
+ def get_canny_filter(image):
58
+ if not isinstance(image, np.ndarray):
59
+ image = np.array(image)
60
+
61
+ image = cv2.Canny(image, low_threshold, high_threshold)
62
+ image = image[:, :, None]
63
+ image = np.concatenate([image, image, image], axis=2)
64
+ canny_image = Image.fromarray(image)
65
+ return canny_image
66
 
67
  def generate_images(canvas, prompt):
68
  try:
 
71
  input_img = Image.open(BytesIO(image_data)).convert(
72
  'RGB').resize((512, 512))
73
  input_img = input_img.filter(ImageFilter.GaussianBlur(radius=5))
74
+ input_img = get_canny_filter(input_img)
75
  output = pipe(
76
  prompt,
77
  input_img,