|
import gradio as gr |
|
from PIL import Image |
|
import requests |
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
# Module-level pipelines, loaded once at import time and shared by the
# describe_image_* helpers below.
# NOTE(review): StableDiffusionPipeline is a text-to-image pipeline; this file
# calls these objects with an image to obtain a "description". Confirm that a
# captioning/tagging model was not intended instead.
general_model = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
)

anime_model = StableDiffusionPipeline.from_pretrained(
    "hakurei/waifu-diffusion"
)
|
|
|
|
|
def check_anime_image(image):
    """Report whether *image* is an anime image, with related lookup data.

    This is currently a stub: the argument is ignored and a fixed
    "not anime" result is returned.

    Args:
        image: The image to inspect (unused by the stub).

    Returns:
        A 3-tuple ``(is_anime, similar_images, original_tags)``; the stub
        always returns ``(False, [], [])``.
    """
    return False, [], []
|
|
|
def describe_image_general(image):
    """Run the module-level ``general_model`` on *image* and return its output.

    Args:
        image: The image handed to ``general_model`` as its only positional
            argument.

    Returns:
        Whatever ``general_model(image)`` returns, unmodified.
    """
    # NOTE(review): general_model is a StableDiffusionPipeline, whose first
    # positional argument is documented as a text prompt. Passing an image
    # here presumably fails or does not yield tags -- confirm intent.
    description = general_model(image)
    return description
|
|
|
def describe_image_anime(image):
    """Run the module-level ``anime_model`` on *image* and return its output.

    Args:
        image: The image handed to ``anime_model`` as its only positional
            argument.

    Returns:
        Whatever ``anime_model(image)`` returns, unmodified.
    """
    # NOTE(review): anime_model is a StableDiffusionPipeline, whose first
    # positional argument is documented as a text prompt. Passing an image
    # here presumably fails or does not yield tags -- confirm intent.
    description = anime_model(image)
    return description
|
|
|
def merge_tags(tags1, tags2):
    """Combine two tag collections into one list without duplicates.

    Tags keep the order in which they are first seen (all of *tags1*, then
    any new tags from *tags2*), so the result is stable between runs rather
    than depending on set iteration order.

    Args:
        tags1: Iterable of hashable tags.
        tags2: Iterable of hashable tags.

    Returns:
        A new list containing each distinct tag exactly once.
    """
    # dict keys preserve insertion order, giving an ordered de-duplication.
    return list(dict.fromkeys([*tags1, *tags2]))
|
|
|
|
|
def process_image(image, mode):
    """Resize *image* and produce tags for it according to *mode*.

    Args:
        image: An object exposing ``resize((width, height))`` (the UI supplies
            a PIL image), or ``None`` when no image was provided.
        mode: ``"Anime"`` selects the anime path; any other value selects the
            general path.

    Returns:
        A 2-tuple ``(tags, original_tags)``. In anime mode, if
        ``check_anime_image`` does not flag the image as anime, the result is
        ``(["Not an anime image"], [])``. If *image* is ``None`` the result is
        ``(["No image provided"], [])``.
    """
    # The UI passes None when nothing has been uploaded; report that in the
    # same message-list style as the other failure path instead of crashing
    # on ``None.resize``.
    if image is None:
        return ["No image provided"], []

    image = image.resize((256, 256))

    if mode != "Anime":
        tags = describe_image_general(image)
        return tags, []

    is_anime, _similar_images, original_tags = check_anime_image(image)
    if not is_anime:
        return ["Not an anime image"], []

    tags = describe_image_anime(image)
    return tags, original_tags
|
|
|
def describe(image, mode):
    """Click handler: tag *image* and fill the two tag text areas.

    Args:
        image: The image taken from the image input component.
        mode: The value selected in the mode dropdown.

    Returns:
        Two ``gr.update`` objects whose values are the described tags and the
        original tags, each joined one tag per line.
    """
    tags, original_tags = process_image(image, mode)

    # str.join requires both collections to be iterables of strings.
    described_text = "\n".join(tags)
    original_text = "\n".join(original_tags)
    return gr.update(value=described_text), gr.update(value=original_text)
|
|
|
def merge(tags, original_tags):
    """Click handler: merge two newline-separated tag strings into one.

    Each input is split into lines; surrounding whitespace is stripped and
    blank lines are dropped, so an empty text box contributes no tags
    (a plain ``"".split("\\n")`` would contribute one empty tag).

    Args:
        tags: Newline-separated described tags (``None`` is treated as empty).
        original_tags: Newline-separated original tags (``None`` is treated
            as empty).

    Returns:
        The merged tags from ``merge_tags``, joined one per line.
    """
    described = [
        line.strip() for line in (tags or "").splitlines() if line.strip()
    ]
    original = [
        line.strip()
        for line in (original_tags or "").splitlines()
        if line.strip()
    ]

    merged_tags = merge_tags(described, original)
    return "\n".join(merged_tags)
|
|
|
|
|
# Build the UI: image + mode inputs, two action buttons, tabbed tag outputs
# and a merged-tags output.
with gr.Blocks() as demo:
    with gr.Row():
        # NOTE(review): the `tool` argument belongs to the Gradio 3.x
        # gr.Image API -- confirm against the Gradio version in use.
        image_input = gr.Image(
            type="pil", tool="editor", label="Upload/Paste Image"
        )
        mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")

    describe_button = gr.Button("Describe")
    merge_button = gr.Button("Merge Tags")

    # gr.Tabs is the container for gr.TabItem children; gr.TabGroup does not
    # exist in Gradio and raised AttributeError.
    with gr.Tabs() as tab_group:
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")

    merged_tags = gr.TextArea(label="Merged Tags")

    # "Describe" fills both tag areas; "Merge Tags" combines their contents.
    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

demo.launch()
|
|