|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
|
|
# Maximum number of non-background sketch colors the UI can handle; one
# (initially hidden) swatch + prompt row is pre-built for each of them.
MAX_COLORS = 12
|
|
|
def get_high_freq_colors(image):
    """Return the frequently used colors of *image*, most frequent first.

    Args:
        image: a PIL image (anything exposing ``getcolors``, ``width`` and
            ``height``).

    Returns:
        A list of ``(count, color)`` tuples as produced by
        ``Image.getcolors``, sorted by descending count. Only colors whose
        count exceeds ``max(2, mean_count / 3)`` are kept, where
        ``mean_count`` is the average count over all distinct colors
        (presumably to drop stray / edge pixels -- confirm intent). An image
        with no colors yields an empty list.
    """
    # getcolors() returns None once the image has more distinct colors than
    # `maxcolors`. width * height bounds the number of distinct colors, so
    # this limit can never be exceeded (and is never lower than before).
    max_colors = max(1024 * 1024, image.width * image.height)
    color_counts = image.getcolors(maxcolors=max_colors)
    if not color_counts:
        return []

    sorted_colors = sorted(color_counts, key=lambda c: c[0], reverse=True)
    mean_freq = sum(c[0] for c in sorted_colors) / len(sorted_colors)
    threshold = max(2, mean_freq / 3)
    return [c for c in sorted_colors if c[0] > threshold]
|
|
|
def color_quantization(image, n_colors):
    """Snap every pixel of *image* to the nearest of its most common colors.

    Args:
        image: array of shape (H, W, C) with C >= 3, e.g.
            ``np.array(pil_image)``. Only the first three channels are used,
            so an RGBA array is accepted as well.
        n_colors: palette size. The palette is the ``n_colors`` most frequent
            colors of the image (or all of them if it has fewer).

    Returns:
        uint8 array of shape (H, W, 3) in which every pixel is replaced by
        the palette color closest to it (squared Euclidean distance).

    Raises:
        ValueError: if ``n_colors`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be at least 1, got {n_colors}")

    image = np.asarray(image)
    height, width = image.shape[:2]
    pixels = image[..., :3].reshape(-1, 3)

    # Count each distinct color once instead of allocating a dense
    # 256x256x256 float64 histogram (~128 MiB) on every call.
    uniq, inverse, counts = np.unique(
        pixels, axis=0, return_inverse=True, return_counts=True
    )
    inverse = inverse.reshape(-1)  # NumPy 2.0.0 returned a 2-D inverse here
    order = np.argsort(counts, kind="stable")[::-1]
    palette = uniq[order[:n_colors]].astype(np.int64)

    # Distances are only needed per distinct color; `inverse` maps the result
    # back to pixels. int64 avoids uint8 wrap-around in the subtraction.
    diffs = uniq.astype(np.int64)[:, None, :] - palette[None, :, :]
    nearest = np.argmin(np.sum(diffs ** 2, axis=2), axis=1)
    return palette[nearest][inverse].reshape(height, width, 3).astype(np.uint8)
|
|
|
def create_binary_matrix(img_arr, target_color):
    """Return an integer mask of the pixels in *img_arr* equal to *target_color*.

    Args:
        img_arr: array whose last axis holds the color channels, e.g. the
            (H, W, 3) output of ``color_quantization``.
        target_color: one value per channel, e.g. ``(r, g, b)``.

    Returns:
        Integer array with the shape of ``img_arr`` minus its last axis:
        1 where every channel matches ``target_color``, 0 elsewhere.
    """
    mask = np.all(img_arr == target_color, axis=-1)
    return mask.astype(int)
|
|
|
def process_sketch(image, binary_matrixes):
    """Split a color sketch into one binary mask per drawn color.

    The most frequent color is treated as the background and gets no mask.
    Each other high-frequency color (at most MAX_COLORS of them) gets a
    mask, a visible prompt row and a color swatch.

    Args:
        image: PIL image from the sketch canvas, or None if nothing was sent.
        binary_matrixes: previous value of the masks ``gr.State``. It is
            neither mutated nor reused: masks are rebuilt from *image* on
            every call, so re-running on an edited sketch does not accumulate
            stale masks.

    Returns:
        Updates matching the click handler's outputs: [sketch accordion,
        post-sketch column, new mask list, *MAX_COLORS row visibilities,
        *MAX_COLORS swatch HTML updates].

    Raises:
        gr.Error: if no sketch image was provided.
    """
    if image is None:
        raise gr.Error("Please draw a sketch before continuing.")
    # getcolors() tuples and the quantized array must share a channel layout;
    # normalize to RGB so masks are compared against (r, g, b) colors.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Keep the background (first entry) plus at most MAX_COLORS drawn colors:
    # the UI only has MAX_COLORS prompt rows.
    high_freq_colors = get_high_freq_colors(image)[:MAX_COLORS + 1]
    drawn_colors = [tuple(c[1][:3]) for c in high_freq_colors[1:]]

    new_matrixes = []
    if high_freq_colors:
        im2arr = color_quantization(np.array(image), n_colors=len(high_freq_colors))
        new_matrixes = [create_binary_matrix(im2arr, rgb) for rgb in drawn_colors]

    visibilities = [gr.update(visible=False) for _ in range(MAX_COLORS)]
    colors = [
        gr.update(value=f'<div class="color-bg-item" style="background-color: black"></div>')
        for _ in range(MAX_COLORS)
    ]
    for n, (r, g, b) in enumerate(drawn_colors):
        visibilities[n] = gr.update(visible=True)
        colors[n] = gr.update(value=f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>')

    return [gr.update(open=False), gr.update(visible=True), new_matrixes, *visibilities, *colors]
|
|
|
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Handle the "Generate!" click (placeholder implementation).

    Args:
        binary_matrixes: list of per-color masks produced by ``process_sketch``.
        master_prompt: text of the "General Prompt" box (not used yet).
        *prompts: values of the per-color prompt boxes, in row order.

    Returns:
        None -- image generation is not implemented yet.
    """
    # Only the first len(binary_matrixes) prompt boxes belong to colors that
    # process_sketch found; the remaining rows stay hidden.
    active_count = len(binary_matrixes)
    region_prompts = prompts[:active_count]
    print(region_prompts)
    # TODO: run the region-based generation here.
    return None
|
|
|
# Page CSS passed to gr.Blocks:
#   #color-bg      -- flex-centers the contents of each swatch box,
#   .color-bg-item -- full-width, 32px-tall color swatch <div>,
#   #main_button   -- stretches the "I've finished my sketch" button.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
|
def update_css(aspect):
    """Return an HTML update whose <style> resizes the canvas for *aspect*.

    Args:
        aspect: one of "Square" (512x512), "Horizontal" (768x512) or
            "Vertical" (512x768), as offered by the aspect-ratio radio.
            Any other value falls back to the square layout instead of
            failing with an unbound local variable.

    Returns:
        A ``gr.update`` setting the value of the injected <style> HTML block.
    """
    sizes = {
        "Square": (512, 512),
        "Horizontal": (768, 512),
        "Vertical": (512, 768),
    }
    width, height = sizes.get(aspect, sizes["Square"])
    return gr.update(value=f"<style>#main-image{{width: {width}px}} .fixed-height{{height: {height}px !important}}</style>")
|
|
|
with gr.Blocks(css=css) as demo:
    # Per-session list of binary color masks: written by process_sketch,
    # read by process_generation.
    binary_matrixes = gr.State([])
    with gr.Box(elem_id="main-image"):
        # Step 1: draw the color sketch and choose the canvas aspect ratio.
        with gr.Accordion(open=True, label="Your color sketch") as sketch_accordion:
            with gr.Row():
                image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil")
            with gr.Row():
                aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
            button_run = gr.Button("I've finished my sketch", elem_id="main_button")

        # Step 2 (hidden until the sketch is processed): one pre-built row per
        # possible color, each holding a swatch and a prompt box. process_sketch
        # toggles row visibility and swatch colors.
        prompts = []
        colors = []
        color_row = [None] * MAX_COLORS
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            for n in range(MAX_COLORS):
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                    prompts.append(gr.Textbox(label="Prompt for this color"))
            final_run_btn = gr.Button("Generate!")
            out_image = gr.Image()

    # Injected <style> block that update_css rewrites when the aspect changes.
    css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")

    aspect.change(update_css, inputs=aspect, outputs=css_height)
    button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[sketch_accordion, post_sketch, binary_matrixes, *color_row, *colors])
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)

# Launch only when executed as a script, so importing this module does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch()