Spaces:
Running
on
A10G
Running
on
A10G
Commit
·
ce3552b
1
Parent(s):
e3d135e
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
MAX_COLORS = 12
|
7 |
+
|
8 |
+
def get_high_freq_colors(image):
|
9 |
+
im = image.getcolors(maxcolors=1024*1024)
|
10 |
+
sorted_colors = sorted(im, key=lambda x: x[0], reverse=True)
|
11 |
+
|
12 |
+
freqs = [c[0] for c in sorted_colors]
|
13 |
+
mean_freq = sum(freqs) / len(freqs)
|
14 |
+
|
15 |
+
high_freq_colors = [c for c in sorted_colors if c[0] > max(2, mean_freq/3)] # Ignore colors that occur very few times (less than 2) or less than half the average frequency
|
16 |
+
return high_freq_colors
|
17 |
+
|
18 |
+
def color_quantization(image, n_colors):
|
19 |
+
# Get color histogram
|
20 |
+
hist, _ = np.histogramdd(image.reshape(-1, 3), bins=(256, 256, 256), range=((0, 256), (0, 256), (0, 256)))
|
21 |
+
# Get most frequent colors
|
22 |
+
colors = np.argwhere(hist > 0)
|
23 |
+
colors = colors[np.argsort(hist[colors[:, 0], colors[:, 1], colors[:, 2]])[::-1]]
|
24 |
+
colors = colors[:n_colors]
|
25 |
+
# Replace each pixel with the closest color
|
26 |
+
dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
|
27 |
+
labels = np.argmin(dists, axis=1)
|
28 |
+
return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
|
29 |
+
|
30 |
+
def create_binary_matrix(img_arr, target_color):
|
31 |
+
print(target_color)
|
32 |
+
# Create mask of pixels with target color
|
33 |
+
mask = np.all(img_arr == target_color, axis=-1)
|
34 |
+
|
35 |
+
# Convert mask to binary matrix
|
36 |
+
binary_matrix = mask.astype(int)
|
37 |
+
return binary_matrix
|
38 |
+
|
39 |
+
def process_sketch(image, binary_matrixes):
|
40 |
+
high_freq_colors = get_high_freq_colors(image)
|
41 |
+
how_many_colors = len(high_freq_colors)
|
42 |
+
im2arr = np.array(image) # im2arr.shape: height x width x channel
|
43 |
+
im2arr = color_quantization(im2arr, n_colors=how_many_colors)
|
44 |
+
|
45 |
+
colors_fixed = []
|
46 |
+
for color in high_freq_colors[1:]:
|
47 |
+
r = color[1][0]
|
48 |
+
g = color[1][1]
|
49 |
+
b = color[1][2]
|
50 |
+
binary_matrix = create_binary_matrix(im2arr, (r,g,b))
|
51 |
+
binary_matrixes.append(binary_matrix)
|
52 |
+
colors_fixed.append(gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'))
|
53 |
+
visibilities = []
|
54 |
+
colors = []
|
55 |
+
for n in range(MAX_COLORS):
|
56 |
+
visibilities.append(gr.update(visible=False))
|
57 |
+
colors.append(gr.update(value=f'<div class="color-bg-item" style="background-color: black"></div>'))
|
58 |
+
for n in range(how_many_colors-1):
|
59 |
+
visibilities[n] = gr.update(visible=True)
|
60 |
+
colors[n] = colors_fixed[n]
|
61 |
+
return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]
|
62 |
+
|
63 |
+
def process_generation(binary_matrixes, master_prompt, *prompts):
|
64 |
+
clipped_prompts = prompts[:len(binary_matrixes)]
|
65 |
+
print(clipped_prompts)
|
66 |
+
#Now: master_prompt can be used as the main prompt, and binary_matrixes and clipped_prompts can be used as the masked inputs
|
67 |
+
pass
|
68 |
+
|
69 |
+
css = '''
|
70 |
+
#color-bg{display:flex;justify-content: center;align-items: center;}
|
71 |
+
.color-bg-item{width: 100%; height: 32px}
|
72 |
+
#main_button{width:100%}
|
73 |
+
'''
|
74 |
+
def update_css(aspect):
|
75 |
+
if(aspect=='Square'):
|
76 |
+
width = 512
|
77 |
+
height = 512
|
78 |
+
elif(aspect == 'Horizontal'):
|
79 |
+
width = 768
|
80 |
+
height = 512
|
81 |
+
elif(aspect=='Vertical'):
|
82 |
+
width = 512
|
83 |
+
height = 768
|
84 |
+
return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
|
85 |
+
|
86 |
+
with gr.Blocks(css=css) as demo:
|
87 |
+
binary_matrixes = gr.State([])
|
88 |
+
with gr.Box(elem_id="main-image"):
|
89 |
+
with gr.Row():
|
90 |
+
image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
|
91 |
+
with gr.Row():
|
92 |
+
aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
|
93 |
+
button_run = gr.Button("I've finished my sketch",elem_id="main_button")
|
94 |
+
|
95 |
+
prompts = []
|
96 |
+
colors = []
|
97 |
+
color_row = [None] * MAX_COLORS
|
98 |
+
with gr.Column(visible=False) as post_sketch:
|
99 |
+
general_prompt = gr.Textbox(label="General Prompt")
|
100 |
+
for n in range(MAX_COLORS):
|
101 |
+
with gr.Row(visible=False) as color_row[n]:
|
102 |
+
with gr.Box(elem_id="color-bg"):
|
103 |
+
colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
|
104 |
+
prompts.append(gr.Textbox(label="Prompt for this color"))
|
105 |
+
final_run_btn = gr.Button("Generate!")
|
106 |
+
gallery = gr.Gallery()
|
107 |
+
|
108 |
+
css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
|
109 |
+
|
110 |
+
aspect.change(update_css, inputs=aspect, outputs=css_height)
|
111 |
+
button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
|
112 |
+
final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=gallery)
|
113 |
+
demo.launch(share=True, debug=True)
|