# describe-test / app.py
# aolko's picture
# Rename describe.py to app.py
# 980f584 verified
# raw
# history blame
# 2.74 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import requests
from diffusers import StableDiffusionPipeline
# Load both diffusers pipelines once, at import time, so every request reuses
# the same in-memory weights.
# NOTE(review): StableDiffusionPipeline is documented by diffusers as a
# text-to-image pipeline; the functions below call these objects with an image
# tensor to "describe" it -- confirm that these are the intended models.
_GENERAL_REPO_ID = "CompVis/stable-diffusion-v1-4"
_ANIME_REPO_ID = "hakurei/waifu-diffusion"

general_model = StableDiffusionPipeline.from_pretrained(_GENERAL_REPO_ID)
anime_model = StableDiffusionPipeline.from_pretrained(_ANIME_REPO_ID)
# Placeholder functions for the actual implementations
def check_anime_image(image):
    """Placeholder for an anime reverse-image lookup.

    The intended implementation would query SauceNAO (or a similar service)
    to decide whether *image* is anime artwork and to fetch similar images
    and their tags. No lookup is performed yet.

    Args:
        image: The image to check. Currently unused.

    Returns:
        A ``(is_anime, similar_images, tags)`` tuple. The stub always returns
        ``(False, [], [])``, so callers currently always take their
        "not an anime image" path.
    """
    # TODO: replace with a real reverse-image-search call.
    is_anime = False
    similar_images = []
    tags = []
    return is_anime, similar_images, tags
def describe_image_general(image):
    """Run the general-purpose model on *image* and return its output.

    Args:
        image: The preprocessed image passed straight through to
            ``general_model``.

    Returns:
        Whatever ``general_model(image)`` returns, unmodified.

    NOTE(review): ``general_model`` is created with
    ``StableDiffusionPipeline.from_pretrained``; diffusers documents that
    pipeline as text-to-image with a prompt as its first call argument, and
    the ``describe`` handler in this file joins the result with newlines as if
    it were a list of strings. Confirm this call produces tags at all.
    """
    return general_model(image)
def describe_image_anime(image):
    """Run the anime model on *image* and return its output.

    Args:
        image: The preprocessed image passed straight through to
            ``anime_model``.

    Returns:
        Whatever ``anime_model(image)`` returns, unmodified.

    NOTE(review): ``anime_model`` is created with
    ``StableDiffusionPipeline.from_pretrained``; diffusers documents that
    pipeline as text-to-image with a prompt as its first call argument, and
    the ``describe`` handler in this file joins the result with newlines as if
    it were a list of strings. Confirm this call produces tags at all.
    """
    return anime_model(image)
def merge_tags(tags1, tags2):
    """Merge two tag collections, dropping duplicates.

    Tags keep the order of first appearance (all of *tags1*, then the new
    entries of *tags2*), so the result is deterministic across runs instead
    of depending on set iteration order.

    Args:
        tags1: Iterable of hashable tags.
        tags2: Iterable of hashable tags.

    Returns:
        A new list containing each distinct tag exactly once.
    """
    # dict preserves insertion order, which makes it an ordered "set";
    # unpacking also lets callers mix lists, tuples and other iterables.
    return list(dict.fromkeys([*tags1, *tags2]))
# Gradio app functions
def process_image(image, mode):
    """Preprocess an uploaded image and run the describer selected by *mode*.

    Args:
        image: The image from the Gradio image component (configured with
            ``type="pil"``), or ``None`` when nothing has been uploaded.
        mode: ``"Anime"`` selects the anime path; any other value, including
            ``None``, selects the general path.

    Returns:
        A ``(tags, original_tags)`` pair:

        * no image: ``(["No image provided"], [])``
        * general path: ``(describe_image_general(...), [])``
        * anime path, image recognised as anime:
          ``(describe_image_anime(...), original_tags)``
        * anime path, image not recognised: ``(["Not an anime image"], [])``
    """
    if image is None:
        # An empty image component yields None; report it the same way the
        # "not anime" case is reported instead of failing inside the
        # torchvision transform with an opaque error.
        return ["No image provided"], []

    # Resize to 256x256, convert to a tensor, then add a leading dimension of
    # size 1 (presumably the batch dimension the models expect -- confirm).
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    batch = preprocess(image).unsqueeze(0)

    if mode != "Anime":
        return describe_image_general(batch), []

    # The similar-images list is returned by the lookup but not used here.
    is_anime, _similar_images, original_tags = check_anime_image(batch)
    if not is_anime:
        return ["Not an anime image"], []
    return describe_image_anime(batch), original_tags
def describe(image, mode):
    """Click handler for the "Describe" button.

    Args:
        image: Value of the image input component.
        mode: Value of the mode dropdown.

    Returns:
        Two ``gr.update`` objects, in order: the described tags and the
        original tags, each rendered one tag per line.
    """
    tags, original_tags = process_image(image, mode)
    return (
        gr.update(value="\n".join(tags)),
        gr.update(value="\n".join(original_tags)),
    )
def merge(tags, original_tags):
    """Click handler for the "Merge Tags" button.

    Args:
        tags: Contents of the described-tags text area, one tag per line.
        original_tags: Contents of the original-tags text area, one tag per
            line.

    Returns:
        The merged, de-duplicated tags joined with newlines.
    """

    def _tag_lines(text):
        # "".split("\n") yields [""], which used to inject an empty tag
        # whenever a text area was blank. Strip each line and drop empty
        # ones; treat a missing value (None) as no tags.
        stripped = (line.strip() for line in (text or "").splitlines())
        return [line for line in stripped if line]

    merged = merge_tags(_tag_lines(tags), _tag_lines(original_tags))
    return "\n".join(merged)
# Gradio interface
# Layout: image + mode selector on one row, the two action buttons, a tab
# group with the described/original tags, and the merged result underneath.
# NOTE(review): the original nesting was lost; placing only the image and the
# dropdown inside the Row is an inference -- adjust if the buttons belong there.
with gr.Blocks() as demo:
    with gr.Row():
        # `tool="editor"` is omitted: it was already the default for uploaded
        # images in Gradio 3.x and the parameter no longer exists in 4.x.
        image_input = gr.Image(type="pil", label="Upload/Paste Image")
        mode = gr.Dropdown(choices=["Anime", "General"], label="Mode")

    describe_button = gr.Button("Describe")
    merge_button = gr.Button("Merge Tags")

    # Gradio's tab container is `gr.Tabs`; there is no `gr.TabGroup`.
    with gr.Tabs() as tab_group:
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")

    merged_tags = gr.TextArea(label="Merged Tags")

    # Event listeners must be registered inside the Blocks context.
    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

# Only start the server when run as a script, so importing this module
# (e.g. for tests) does not launch the app.
if __name__ == "__main__":
    demo.launch()