|
import gradio as gr |
|
import torch |
|
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification |
|
from diffusers import DiffusionPipeline |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
|
|
|
|
# Model used when the image type is "Anime".
# NOTE(review): "SmilingWolf/wd-v1-4-vit-tagger" is loaded through diffusers'
# DiffusionPipeline; transcribe_image then treats the call result as an
# iterable of tag strings. Confirm this checkpoint really loads and behaves
# that way with DiffusionPipeline.
anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-v1-4-vit-tagger")

# Classifier and its preprocessor for every non-"Anime" image, both loaded
# from the same checkpoint.
_PHOTO_CHECKPOINT = "facebook/florence-base-in21k-retrieval"
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(_PHOTO_CHECKPOINT)
processor = AutoProcessor.from_pretrained(_PHOTO_CHECKPOINT)
|
|
|
def get_booru_image(booru, image_id, timeout=10):
    """Download an image from a booru site and return it with its tags.

    Args:
        booru: Site name as chosen in the UI (e.g. "Gelbooru"). It is
            lower-cased and used as the middle label of
            ``https://api.<booru>.org``.
        image_id: Identifier appended to the ``/images/`` path; surrounding
            whitespace is stripped.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``(PIL.Image.Image, list[str])`` tuple.

    Raises:
        ValueError: if *booru* is empty or *image_id* is empty/blank.
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    image_id = str(image_id if image_id is not None else "").strip()
    if not booru or not image_id:
        raise ValueError("Both a booru and an image ID are required")

    # NOTE(review): this endpoint shape is a placeholder; the real Gelbooru
    # and Danbooru APIs use different hosts/paths — confirm before relying on it.
    url = f"https://api.{str(booru).lower()}.org/images/{image_id}"
    # A timeout keeps a stalled server from blocking the UI handler forever.
    response = requests.get(url, timeout=timeout)
    # Fail loudly rather than handing an HTML error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))

    # TODO: parse the real tags from the API response; these are placeholders.
    tags = ["tag1", "tag2", "tag3"]
    return img, tags
|
|
|
def transcribe_image(image, image_type, transcriber, booru_tags=None, *, top_k=50):
    """Produce a comma-separated tag string for *image*.

    Args:
        image: Image value from a Gradio component; ``None`` when nothing
            has been loaded yet.
        image_type: ``"Anime"`` routes to ``anime_model``; any other value
            routes to the photo classifier.
        transcriber: Value of the "Transcriber" dropdown. Currently NOT used
            to choose the model — *image_type* decides.
        booru_tags: Optional extra tags merged into the result. Duplicates
            are removed and first-occurrence order is kept.
        top_k: Maximum number of highest-scoring photo labels to keep.

    Returns:
        The tags joined with ``", "``; ``""`` when *image* is ``None``.
    """
    if image is None:
        # Nothing to transcribe yet; avoid crashing inside the model/processor.
        return ""

    # Inference only: no autograd graph needed for either model.
    with torch.no_grad():
        if image_type == "Anime":
            # NOTE(review): assumes anime_model(image) yields an iterable of
            # tag strings — confirm against the actual pipeline output.
            tags = anime_model(image)
        else:
            inputs = processor(images=image, return_tensors="pt")
            # NOTE(review): assumes the model output exposes `.logits` with
            # a leading batch dimension — confirm for this model class.
            logits = photo_model(**inputs).logits
            # topk raises if k exceeds the number of classes.
            k = min(top_k, logits.shape[-1])
            # Take batch row 0 explicitly; squeeze() would turn k == 1 into a scalar.
            indices = logits[0].topk(k).indices.tolist()
            # Label names live on the model config; processors have no `.config`.
            id2label = photo_model.config.id2label
            tags = [id2label[i] for i in indices]

    if booru_tags:
        # dict.fromkeys dedupes while keeping a deterministic order (set() did not).
        tags = list(dict.fromkeys([*tags, *booru_tags]))

    return ", ".join(tags)
|
|
|
def update_image(image_type, booru, image_id, uploaded_image):
    """Pick the Step 1 image and reveal the display and Transcribe button.

    A booru fetch is attempted only when the type is "Anime", a real booru
    (not "Upload") is selected, and a non-blank image ID was given. Otherwise
    the uploaded file, if any, is used.

    Args:
        image_type: Value of the "Image type" dropdown.
        booru: Selected booru name, or "Upload".
        image_id: Text from the "Image ID" box.
        uploaded_image: Value of the upload button (``None`` if nothing uploaded).

    Returns:
        A 3-tuple: (update for the image display, update for the Transcribe
        button, booru tags or ``None``). The image display is made visible
        together with the button because it starts out hidden.
    """
    has_id = image_id is not None and str(image_id).strip() != ""
    if image_type == "Anime" and booru and booru != "Upload" and has_id:
        image, tags = get_booru_image(booru, image_id)
        return gr.update(value=image, visible=True), gr.update(visible=True), tags

    if uploaded_image is not None:
        return gr.update(value=uploaded_image, visible=True), gr.update(visible=True), None

    # Nothing usable: clear and hide the display, hide the button, drop tags.
    return gr.update(value=None, visible=False), gr.update(visible=False), None
|
|
|
def on_image_type_change(image_type):
    """Adjust Step 1 controls and the Step 2 transcriber when the type changes.

    Args:
        image_type: New value of the "Image type" dropdown.

    Returns:
        Updates for (booru options column, upload button, transcriber
        dropdown). The booru controls are shown only for "Anime"; the upload
        button is always shown; the transcriber lists the matching type first
        and preselects it.
    """
    is_anime = image_type == "Anime"
    choices = ["Anime", "Photo/Other"] if is_anime else ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        # Reordering alone does not select anything; set the value explicitly.
        gr.update(choices=choices, value=image_type),
    )
|
|
|
with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")

        # Booru controls; visibility is toggled by on_image_type_change.
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")

        upload_btn = gr.UploadButton("Upload Image", visible=False)

        image_display = gr.Image(label="Image to transcribe", visible=False)
        # Tags returned alongside a booru image (None for uploads).
        booru_tags = gr.State(None)

        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    def show_tags_button(tags):
        """Reveal the "with booru tags" button only when there are tags to merge."""
        return gr.update(visible=bool(tags))

    # Fetching and uploading share the same handler and wiring; afterwards,
    # toggle the with-tags button based on the freshly stored booru tags.
    load_image_kwargs = dict(
        inputs=[image_type, booru, image_id, upload_btn],
        outputs=[image_display, transcribe_btn, booru_tags],
    )
    get_image_btn.click(update_image, **load_image_kwargs).then(
        show_tags_button, inputs=[booru_tags], outputs=[transcribe_with_tags_btn])
    upload_btn.upload(update_image, **load_image_kwargs).then(
        show_tags_button, inputs=[booru_tags], outputs=[transcribe_with_tags_btn])

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        """Transcribe *image* and mirror it into the Step 2 image slot."""
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    def transcribe_without_tags(image, image_type, transcriber):
        """Same as transcribe_and_update, but never merges booru tags."""
        return transcribe_and_update(image, image_type, transcriber, None)

    # The plain "Transcribe" button ignores booru tags so that it differs from
    # "Transcribe with booru tags".
    transcribe_btn.click(transcribe_without_tags,
                         inputs=[image_display, image_type, transcriber],
                         outputs=[transcribe_image_display, tags_output])

    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])

    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()