multimodalart HF staff commited on
Commit
527bf99
·
1 Parent(s): d64b071

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -56
app.py CHANGED
@@ -2,39 +2,11 @@ import gradio as gr
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
-
 
6
  MAX_COLORS = 12
7
 
8
- def get_high_freq_colors(image):
9
- im = image.getcolors(maxcolors=1024*1024)
10
- sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
11
-
12
- freqs = [c[0] for c in sorted_colors]
13
- mean_freq = sum(freqs) / len(freqs)
14
-
15
- high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq/3)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
16
- return high_freq_colors
17
-
18
- def color_quantization(image, n_colors):
19
- # Get color histogram
20
- hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
21
- # Get most frequent colors
22
- colors = np.argwhere(hist > 0)
23
- colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
24
- colors = colors[:n_colors]
25
- # Replace each pixel with the closest color
26
- dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
27
- labels = np.argmin(dists, axis=1)
28
- return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
29
-
30
- def create_binary_matrix(img_arr, target_color):
31
- print(target_color)
32
- # Create mask of pixels with target color
33
- mask = np.all(img_arr == target_color, axis=-1)
34
-
35
- # Convert mask to binary matrix
36
- binary_matrix = mask.astype(int)
37
- return binary_matrix
38
 
39
  def process_sketch(image, binary_matrixes):
40
  high_freq_colors = get_high_freq_colors(image)
@@ -43,13 +15,12 @@ def process_sketch(image, binary_matrixes):
43
  im2arr = color_quantization(im2arr, n_colors=how_many_colors)
44
 
45
  colors_fixed = []
46
- for color in high_freq_colors[1:]:
47
- r = color[1][0]
48
- g = color[1][1]
49
- b = color[1][2]
50
- binary_matrix = create_binary_matrix(im2arr, (r,g,b))
51
- binary_matrixes.append(binary_matrix)
52
- colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
53
  visibilities = []
54
  colors = []
55
  for n in range(MAX_COLORS):
@@ -62,8 +33,15 @@ def process_sketch(image, binary_matrixes):
62
 
63
  def process_generation(binary_matrixes, master_prompt, *prompts):
64
  clipped_prompts = prompts[:len(binary_matrixes)]
65
- #Now: master_prompt can be used as the main prompt, and binary_matrixes and clipped_prompts can be used as the masked inputs
66
- pass
 
 
 
 
 
 
 
67
 
68
  css = '''
69
  #color-bg{display:flex;justify-content: center;align-items: center;}
@@ -72,15 +50,11 @@ css = '''
72
  '''
73
  def update_css(aspect):
74
  if(aspect=='Square'):
75
- width = 512
76
- height = 512
77
  elif(aspect == 'Horizontal'):
78
- width = 768
79
- height = 512
80
  elif(aspect=='Vertical'):
81
- width = 512
82
- height = 768
83
- return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
84
 
85
  with gr.Blocks(css=css) as demo:
86
  binary_matrixes = gr.State([])
@@ -89,11 +63,13 @@ with gr.Blocks(css=css) as demo:
89
  ''')
90
  with gr.Row():
91
  with gr.Box(elem_id="main-image"):
92
- with gr.Row():
93
- image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
94
- with gr.Row():
95
- aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
96
- button_run = gr.Button("I've finished my sketch",elem_id="main_button")
 
 
97
 
98
  prompts = []
99
  colors = []
@@ -111,9 +87,8 @@ with gr.Blocks(css=css) as demo:
111
  gr.Markdown('''
112
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
113
  ''')
114
- css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
115
-
116
- aspect.change(update_css, inputs=aspect, outputs=css_height)
117
  button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
118
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
119
- demo.launch()
 
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
+ from region_control import MultiDiffusion, get_views, preprocess_mask
6
+ from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
7
  MAX_COLORS = 12
8
 
9
+ sd = MultiDiffusion("cuda", "2.1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def process_sketch(image, binary_matrixes):
12
  high_freq_colors = get_high_freq_colors(image)
 
15
  im2arr = color_quantization(im2arr, n_colors=how_many_colors)
16
 
17
  colors_fixed = []
18
+ for color in high_freq_colors:
19
+ r, g, b = color[1]
20
+ if any(c != 255 for c in (r, g, b)):
21
+ binary_matrix = create_binary_matrix(im2arr, (r,g,b))
22
+ binary_matrixes.append(binary_matrix)
23
+ colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
 
24
  visibilities = []
25
  colors = []
26
  for n in range(MAX_COLORS):
 
33
 
34
  def process_generation(binary_matrixes, master_prompt, *prompts):
35
  clipped_prompts = prompts[:len(binary_matrixes)]
36
+ prompts = [master_prompt] + list(clipped_prompts)
37
+ neg_prompts = [""] * len(prompts)
38
+ fg_masks = torch.cat([preprocess_mask(mask_path, 512 // 8, 512 // 8, "cuda") for mask_path in binary_matrixes])
39
+ bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
40
+ bg_mask[bg_mask < 0] = 0
41
+ masks = torch.cat([bg_mask, fg_masks])
42
+ print(masks.size())
43
+ image = sd.generate(masks, prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
44
+ return(image)
45
 
46
  css = '''
47
  #color-bg{display:flex;justify-content: center;align-items: center;}
 
50
  '''
51
  def update_css(aspect):
52
  if(aspect=='Square'):
53
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
 
54
  elif(aspect == 'Horizontal'):
55
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)]
 
56
  elif(aspect=='Vertical'):
57
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
 
 
58
 
59
  with gr.Blocks(css=css) as demo:
60
  binary_matrixes = gr.State([])
 
63
  ''')
64
  with gr.Row():
65
  with gr.Box(elem_id="main-image"):
66
+ #with gr.Row():
67
+ image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
68
+ #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
69
+ #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
70
+ #with gr.Row():
71
+ # aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
72
+ button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
73
 
74
  prompts = []
75
  colors = []
 
87
  gr.Markdown('''
88
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
89
  ''')
90
+ #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
91
+ #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
 
92
  button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
93
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
94
+ demo.launch(debug=True)