"""Gradio Space demo: Stable Diffusion Reference Only automatic coloring."""
import os

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import torch
from PIL import Image

from anime_segmentation import character_segment as anime_character_segment
from anime_segmentation import get_model as get_anime_segmentation_model
from diffusers.schedulers import UniPCMultistepScheduler
from stable_diffusion_reference_only.pipelines.pipeline_stable_diffusion_reference_only import (
    StableDiffusionReferenceOnlyPipeline,
)

print(f"Is CUDA available: {torch.cuda.is_available()}")
# Run on GPU when available; the diffusion pipeline is very slow on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pipeline that colors a line-art "blueprint" image guided by a colored
# reference image (called "prompt" throughout this app).
automatic_coloring_pipeline = StableDiffusionReferenceOnlyPipeline.from_pretrained(
    "AisingioroHao0/stable-diffusion-reference-only-automatic-coloring-0.1.2"
).to(device)
# Swap in the UniPC multistep scheduler (keeps the pipeline's scheduler config).
automatic_coloring_pipeline.scheduler = UniPCMultistepScheduler.from_config(
    automatic_coloring_pipeline.scheduler.config
)

# Anime character segmentation model; weights are downloaded from the
# HuggingFace Hub on first run.
segment_model = get_anime_segmentation_model(
    model_path=huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.ckpt")
).to(device)
def character_segment(img):
    """Remove the background from *img*, keeping only the anime character.

    Args:
        img: image as a numpy array, or None (empty Gradio image component).

    Returns:
        RGB numpy array with the background removed, or None if img is None.
    """
    if img is None:
        return None
    segmented = anime_character_segment(segment_model, img)
    # Segmentation yields RGBA; flatten to RGB for the downstream steps.
    return cv2.cvtColor(segmented, cv2.COLOR_RGBA2RGB)
def color_inversion(img):
    """Invert image colors channel-wise (each value v becomes 255 - v).

    Args:
        img: uint8 numpy image, or None (empty Gradio image component).

    Returns:
        The inverted image, or None if img is None.
    """
    if img is None:
        return None
    return 255 - img
def get_line_art(img):
    """Extract black-on-white line art from a color image.

    Applies mean adaptive thresholding to the grayscale image, which keeps
    local edges and washes out flat shaded regions.

    Args:
        img: RGB numpy image, or None (empty Gradio image component).

    Returns:
        3-channel RGB line-art image, or None if img is None.
    """
    if img is None:
        return None
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    lines = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        blockSize=5,  # neighborhood for the local threshold (must stay odd)
        C=7,          # subtracted constant: higher C -> fewer/cleaner lines
    )
    # Back to 3 channels so the Gradio image component and the pipeline
    # always receive RGB.
    return cv2.cvtColor(lines, cv2.COLOR_GRAY2RGB)
def inference(prompt, blueprint, num_inference_steps):
    """Run the reference-only coloring pipeline.

    Args:
        prompt: colored reference image as a numpy array, or None.
        blueprint: line-art image as a numpy array, or None. Expected to be
            already preprocessed (inverted) by the caller.
        num_inference_steps: number of denoising steps.

    Returns:
        Generated image as a numpy array, or None when either input is missing.
    """
    if prompt is None or blueprint is None:
        return None
    result = automatic_coloring_pipeline(
        prompt=Image.fromarray(prompt),
        blueprint=Image.fromarray(blueprint),
        num_inference_steps=num_inference_steps,
    ).images[0]
    return np.array(result)
def automatic_coloring(prompt, blueprint, num_inference_steps):
    """Color a line-art blueprint using *prompt* as the reference image.

    The blueprint is color-inverted before being fed to the pipeline
    (the pipeline works on white-on-black line art — see style_transfer,
    which applies the same inversion).

    Returns:
        Generated image as a numpy array, or None when either input is missing.
    """
    if prompt is None or blueprint is None:
        return None
    inverted = color_inversion(blueprint)
    return inference(prompt, inverted, num_inference_steps)
def style_transfer(prompt, blueprint, num_inference_steps):
    """Re-draw *blueprint* in the style/colors of *prompt*.

    Full preprocessing chain: both images are character-segmented, then the
    blueprint is reduced to line art and color-inverted before inference.

    Returns:
        Generated image as a numpy array, or None when either input is missing.
    """
    if prompt is None or blueprint is None:
        return None
    reference = character_segment(prompt)
    lineart = color_inversion(get_line_art(character_segment(blueprint)))
    return inference(reference, lineart, num_inference_steps)
def resize(img, new_height, new_width):
    """Resize *img* to (new_width, new_height) with bilinear interpolation.

    Args:
        img: numpy image, or None (empty Gradio image component).
        new_height: target height; gr.Number delivers a float, so it is cast.
        new_width: target width; cast to int for the same reason.

    Returns:
        Resized numpy image, or None if img is None.
    """
    # Guard the empty-component case, consistent with the other button
    # handlers (previously this crashed in Image.fromarray(None)).
    if img is None:
        return None
    resized = Image.fromarray(img).resize(
        (int(new_width), int(new_height)), Image.BILINEAR
    )
    return np.array(resized)
# --- Gradio UI: three columns (prompt | blueprint | result) ----------------
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Stable Diffusion Reference Only Automatic Coloring 0.1.2\n\n
        demo for [https://github.com/aihao2000/stable-diffusion-reference-only](https://github.com/aihao2000/stable-diffusion-reference-only)
        """
    )
    with gr.Row():
        # Column 1: the colored reference ("prompt") image with its
        # preprocessing tools (resize, character segmentation).
        with gr.Column():
            prompt_input_compoent = gr.Image(label="prompt")
            with gr.Row():
                prompt_new_height = gr.Number(512, label="height", minimum=1)
                prompt_new_width = gr.Number(512, label="width", minimum=1)
                prompt_resize_button = gr.Button("prompt resize")
                prompt_resize_button.click(
                    resize,
                    inputs=[
                        prompt_input_compoent,
                        prompt_new_height,
                        prompt_new_width,
                    ],
                    outputs=prompt_input_compoent,
                )
            prompt_character_segment_button = gr.Button(
                "character segment",
            )
            prompt_character_segment_button.click(
                character_segment,
                inputs=prompt_input_compoent,
                outputs=prompt_input_compoent,
            )
        # Column 2: the blueprint image with its preprocessing tools
        # (resize, segmentation, line-art extraction, color inversion).
        # Each button rewrites the component in place so steps can be chained.
        with gr.Column():
            blueprint_input_compoent = gr.Image(label="blueprint")
            with gr.Row():
                blueprint_new_height = gr.Number(512, label="height", minimum=1)
                blueprint_new_width = gr.Number(512, label="width", minimum=1)
                blueprint_resize_button = gr.Button("blueprint resize")
                blueprint_resize_button.click(
                    resize,
                    inputs=[
                        blueprint_input_compoent,
                        blueprint_new_height,
                        blueprint_new_width,
                    ],
                    outputs=blueprint_input_compoent,
                )
            blueprint_character_segment_button = gr.Button("character segment")
            blueprint_character_segment_button.click(
                character_segment,
                inputs=blueprint_input_compoent,
                outputs=blueprint_input_compoent,
            )
            get_line_art_button = gr.Button(
                "get line art",
            )
            get_line_art_button.click(
                get_line_art,
                inputs=blueprint_input_compoent,
                outputs=blueprint_input_compoent,
            )
            color_inversion_button = gr.Button(
                "color inversion",
            )
            color_inversion_button.click(
                color_inversion,
                inputs=blueprint_input_compoent,
                outputs=blueprint_input_compoent,
            )
        # Column 3: the generated result plus the three generation entry
        # points (raw inference, automatic coloring, full style transfer).
        with gr.Column():
            result_output_component = gr.Image(label="result")
            num_inference_steps_input_component = gr.Number(
                20, label="num inference steps", minimum=1, maximum=1000, step=1
            )
            inference_button = gr.Button("inference")
            inference_button.click(
                inference,
                inputs=[
                    prompt_input_compoent,
                    blueprint_input_compoent,
                    num_inference_steps_input_component,
                ],
                outputs=result_output_component,
            )
            automatic_coloring_button = gr.Button("automatic coloring")
            automatic_coloring_button.click(
                automatic_coloring,
                inputs=[
                    prompt_input_compoent,
                    blueprint_input_compoent,
                    num_inference_steps_input_component,
                ],
                outputs=result_output_component,
            )
            style_transfer_button = gr.Button("style transfer")
            style_transfer_button.click(
                style_transfer,
                inputs=[
                    prompt_input_compoent,
                    blueprint_input_compoent,
                    num_inference_steps_input_component,
                ],
                outputs=result_output_component,
            )
    with gr.Row():
        # Example prompt/blueprint pair shipped with the repository.
        gr.Examples(
            examples=[
                [
                    os.path.join(
                        os.path.dirname(__file__), "README.assets", "3x9_prompt.png"
                    ),
                    os.path.join(
                        os.path.dirname(__file__), "README.assets", "3x9_blueprint.png"
                    ),
                ],
            ],
            inputs=[prompt_input_compoent, blueprint_input_compoent],
            outputs=result_output_component,
            # No-op fn: examples only pre-fill the inputs; cache_examples
            # requires fn/outputs to be supplied.
            fn=lambda x, y: None,
            cache_examples=True,
        )
if __name__ == "__main__":
    # Queue requests (at most 5 waiting) so concurrent users don't contend
    # for the single loaded pipeline, then start the Gradio server.
    demo.queue(max_size=5).launch()