|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
from region_control import MultiDiffusion, get_views, preprocess_mask |
|
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix |
|
# Number of colour-region slots in the UI: one hidden mask/prompt row is
# created per slot, and process_sketch returns one visibility update and one
# swatch update per slot.
MAX_COLORS = 12

# MultiDiffusion pipeline, constructed once at import time and shared by every
# call to process_generation. The arguments presumably select the device
# ("cuda") and the Stable Diffusion version ("2.0") -- confirm in region_control.
sd = MultiDiffusion("cuda", "2.0")
|
|
|
def process_sketch(image, binary_matrixes):
    """Split a colour sketch into one binary mask per non-white colour.

    Args:
        image: The PIL image coming from the sketch canvas.
        binary_matrixes: Previous value of the mask ``gr.State``. It is
            intentionally NOT extended: a fresh list is built on every call so
            that re-submitting a sketch does not keep the masks of an earlier
            one (which would misalign masks with the prompt rows).

    Returns:
        A list matching the click handler's outputs
        ``[post_sketch, binary_matrixes, *color_row, *colors]``:
        an update making the prompt column visible, the new mask list,
        ``MAX_COLORS`` row-visibility updates, and ``MAX_COLORS`` swatch
        HTML updates.
    """
    high_freq_colors, image = get_high_freq_colors(image)
    # Snap the sketch to its dominant colours before extracting masks.
    quantized = color_quantization(np.array(image), high_freq_colors)

    masks = []
    swatches = []
    for color in high_freq_colors:
        # NOTE(review): entries presumably look like (count, (r, g, b)) -- confirm
        # against sketch_helper.get_high_freq_colors.
        r, g, b = color[1]
        if not any(c != 255 for c in (r, g, b)):
            # Pure white is not treated as a region (presumably the blank canvas).
            continue
        if len(masks) == MAX_COLORS:
            # The UI only has MAX_COLORS prompt rows; further colours would
            # have no prompt slot to pair with.
            break
        # Presumably a file path: it is served via "file/..." below and later
        # passed to preprocess_mask as a mask path.
        mask = create_binary_matrix(quantized, (r, g, b))
        masks.append(mask)
        swatches.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{mask}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

    # Show exactly as many rows as masks were created; reset unused swatches.
    placeholder = '<div class="color-bg-item" style="background-color: black"></div>'
    visibilities = [gr.update(visible=n < len(swatches)) for n in range(MAX_COLORS)]
    colors = swatches + [gr.update(value=placeholder) for _ in range(MAX_COLORS - len(swatches))]

    return [gr.update(visible=True), masks, *visibilities, *colors]
|
|
|
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image with region-based MultiDiffusion from the sketch masks.

    Args:
        binary_matrixes: Mask references produced by ``process_sketch``; each
            is passed to ``preprocess_mask`` (presumably a file path -- confirm).
        master_prompt: General prompt. It is placed first in the prompt list,
            paired by position with the background mask (the area not
            covered by any drawn mask).
        *prompts: One prompt per UI mask row, in the same order as
            ``binary_matrixes``. Rows without a matching mask are ignored.

    Returns:
        The result of ``sd.generate`` (shown in the result image component).

    Raises:
        gr.Error: If there is no mask to condition on (no finished sketch, or
            the sketch contained only white).
    """
    # Keep masks and region prompts aligned one-to-one.
    n_regions = min(len(binary_matrixes), len(prompts))
    if n_regions == 0:
        # torch.cat([]) would otherwise fail with an opaque RuntimeError.
        raise gr.Error("Draw a sketch and click 'I've finished my sketch' before generating.")

    mask_paths = binary_matrixes[:n_regions]
    all_prompts = [master_prompt, *prompts[:n_regions]]
    neg_prompts = [""] * len(all_prompts)

    # Masks are built at 512 // 8 per side, presumably the latent resolution
    # of a 512x512 image.
    latent_size = 512 // 8
    fg_masks = torch.cat([preprocess_mask(path, latent_size, latent_size, "cuda") for path in mask_paths])
    # Background = area the foreground masks leave uncovered; clamp at 0
    # where overlapping masks sum to more than 1.
    bg_mask = torch.clamp(1 - torch.sum(fg_masks, dim=0, keepdim=True), min=0)
    masks = torch.cat([bg_mask, fg_masks])

    return sd.generate(masks, all_prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
|
|
|
# Styles passed to gr.Blocks(css=...):
#   #color-bg       -- the Box (elem_id="color-bg") wrapping each colour swatch
#   .color-bg-item  -- the swatch <div> emitted by process_sketch / the UI rows
#   #main_button    -- the "I've finished my sketch" button (elem_id="main_button")
css = '''

#color-bg{display:flex;justify-content: center;align-items: center;}

.color-bg-item{width: 100%; height: 32px}

#main_button{width:100%}

'''
|
def update_css(aspect):
    """Return three visibility updates, showing only the one for ``aspect``.

    The outputs are ordered ('Square', 'Horizontal', 'Vertical'); presumably
    they target one layout component per aspect -- confirm at the call site.
    Any other ``aspect`` value yields ``None``.
    """
    layouts = ('Square', 'Horizontal', 'Vertical')
    if aspect not in layouts:
        return None
    return [gr.update(visible=(aspect == layout)) for layout in layouts]
|
|
|
# Build the Gradio UI. Layout follows the order in which components are
# created inside the nested ``with`` contexts, so statement order matters.
with gr.Blocks(css=css) as demo:
    # Mask list written by process_sketch and read by process_generation
    # (see the .click wiring at the bottom of this block).
    binary_matrixes = gr.State([])

    gr.Markdown('''## Control your Stable Diffusion generation with Sketches

    This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).

    ''')

    with gr.Row():
        with gr.Box(elem_id="main-image"):
            # Colour-sketch canvas; its PIL image is the input to process_sketch.
            # NOTE(review): 512x512 presumably matches the generation size used
            # in process_generation -- keep in sync.
            image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
            button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)

            # Per-slot components, filled in the loop below.
            prompts = []
            colors = []
            color_row = [None] * MAX_COLORS

            # Hidden until process_sketch returns gr.update(visible=True) for it.
            with gr.Column(visible=False) as post_sketch:
                general_prompt = gr.Textbox(label="General Prompt")
                # One initially hidden row per possible colour region: a swatch
                # (HTML) plus the prompt for that mask.
                for n in range(MAX_COLORS):
                    with gr.Row(visible=False) as color_row[n]:
                        with gr.Box(elem_id="color-bg"):
                            colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                        prompts.append(gr.Textbox(label="Prompt for this mask"))
                final_run_btn = gr.Button("Generate!")

        out_image = gr.Image(label="Result")

    gr.Markdown('''

    ![Examples](https://multidiffusion.github.io/pics/tight.jpg)

    ''')

    # Sketch -> masks: reveal the prompt column, store the masks, and update
    # each row's visibility and colour swatch (output order must match
    # process_sketch's return list).
    button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
    # Masks + prompts -> image. Row prompts are passed positionally after the
    # general prompt and collected by process_generation's *prompts.
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)

# Runs at module level (no __main__ guard), so executing or importing this
# file starts the app.
demo.launch(debug=True)