describe-test / app.py
aolko's picture
Update app.py
1376e14 verified
raw
history blame
4.56 kB
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from diffusers import DiffusionPipeline
import requests
from PIL import Image
from io import BytesIO
# Model setup: every model is loaded once, at import time, and the module-level
# objects below are shared by transcribe_image for all requests.
# NOTE(review): the anime repo id ends in "-tagger", which suggests an image
# tagger rather than a diffusers pipeline; confirm that DiffusionPipeline can
# actually load it and that calling it on an image yields tag strings.
ANIME_MODEL_ID = "SmilingWolf/wd-v1-4-vit-tagger"
PHOTO_MODEL_ID = "facebook/florence-base-in21k-retrieval"

anime_model = DiffusionPipeline.from_pretrained(ANIME_MODEL_ID)
# The processor must come from the same checkpoint as the photo model so its
# preprocessing matches what the model was trained on.
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(PHOTO_MODEL_ID)
processor = AutoProcessor.from_pretrained(PHOTO_MODEL_ID)
def get_booru_image(booru, image_id, timeout=10):
    """Fetch an image and its tags from a booru.

    This is a placeholder: the URL pattern is not a real per-booru API call and
    the returned tags are dummy values. Implement the actual API calls for each
    booru before relying on it.

    Args:
        booru: Booru name as selected in the UI (e.g. "Gelbooru").
        image_id: Post/image id typed by the user; it is URL-quoted before being
            placed in the request path.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        (img, tags): the decoded PIL image and a list of tag strings.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.Timeout: no response arrived within ``timeout`` seconds.
        PIL.UnidentifiedImageError: the response body is not a decodable image.
    """
    from urllib.parse import quote

    # image_id is free text from a textbox; quote it (including "/") so it
    # cannot escape the intended path segment.
    safe_id = quote(str(image_id), safe="")
    url = f"https://api.{booru}.org/images/{safe_id}"
    # Without a timeout, requests can block indefinitely on an unresponsive host.
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    tags = ["tag1", "tag2", "tag3"]  # Placeholder
    return img, tags
def transcribe_image(image, image_type, transcriber, booru_tags=None, top_k=50):
    """Return a comma-separated tag string describing ``image``.

    Args:
        image: Image value delivered by the Gradio Image component (presumably a
            numpy array or PIL image depending on the component's ``type`` -
            confirm).
        image_type: "Anime" or "Photo/Other"; selects the model when
            ``transcriber`` is not set.
        transcriber: Value of the "Transcriber" dropdown ("Anime" or
            "Photo/Other"). When set, it takes precedence over ``image_type``.
        booru_tags: Optional iterable of extra tags (e.g. fetched from a booru)
            merged into the result.
        top_k: Maximum number of labels taken from the photo model.

    Returns:
        Tags joined with ", ", de-duplicated while keeping first-seen order
        (model tags first, then booru tags).
    """
    # The "Transcriber" dropdown exists to pick the model; fall back to the
    # image type while the user has not chosen one (dropdown value is None).
    model_choice = transcriber or image_type

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        if model_choice == "Anime":
            # NOTE(review): assumes anime_model(image) returns an iterable of tag
            # strings; it is loaded via DiffusionPipeline, so confirm.
            tags = anime_model(image)
        else:
            inputs = processor(images=image, return_tensors="pt")
            logits = photo_model(**inputs).logits
            # Clamp k: topk raises if k exceeds the number of labels.
            k = min(top_k, logits.shape[-1])
            # Index the single batch row instead of squeeze(), which collapses
            # to a bare int (not a list) when k == 1.
            label_ids = logits.topk(k).indices[0].tolist()
            # The label map lives on the model config; processors have no .config.
            tags = [photo_model.config.id2label[i] for i in label_ids]

    if booru_tags:
        # dict.fromkeys de-duplicates deterministically; set() scrambled order.
        tags = list(dict.fromkeys([*tags, *booru_tags]))
    return ", ".join(tags)
def update_image(image_type, booru, image_id, uploaded_image):
    """Resolve the image to transcribe and update the dependent Step 1 widgets.

    For "Anime" images with a real booru selected (not "Upload") and a non-empty
    image id, the image and its tags come from ``get_booru_image``. Otherwise the
    uploaded file is used, without booru tags.

    Args:
        image_type: Value of the "Image type" dropdown.
        booru: Value of the "Boorus" dropdown (may be None when unset).
        image_id: Text of the "Image ID" box (may be empty).
        uploaded_image: Value of the upload button, or None if nothing uploaded.

    Returns:
        A 3-tuple matching outputs [image_display, transcribe_btn, booru_tags]:
        an update setting the image component's value and visibility, an update
        for the Transcribe button's visibility, and the booru tags (or None).
    """
    has_id = image_id is not None and str(image_id).strip() != ""
    # Do not query a booru with an unset dropdown (None) or an empty id; that
    # would only build a bogus request URL.
    use_booru = image_type == "Anime" and booru not in (None, "Upload") and has_id

    if use_booru:
        image, booru_tags = get_booru_image(booru, image_id)
    elif uploaded_image is not None:
        image, booru_tags = uploaded_image, None
    else:
        return gr.update(value=None, visible=False), gr.update(visible=False), None

    # The image component is created hidden in the layout, so reveal it along
    # with its new value; otherwise the user never sees the image to transcribe.
    return gr.update(value=image, visible=True), gr.update(visible=True), booru_tags
def on_image_type_change(image_type):
    """Adjust widgets after the "Image type" dropdown changes.

    Args:
        image_type: New dropdown value ("Anime", "Photo/Other", or None).

    Returns:
        Updates for [anime_options, upload_btn, transcriber]:
        the booru options are visible only for "Anime", the upload button is
        always visible, and the transcriber dropdown lists the matching
        transcriber first and preselects it.
    """
    is_anime = image_type == "Anime"
    choices = ["Anime", "Photo/Other"] if is_anime else ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        # Reordering choices does not change a dropdown's value, so set it
        # explicitly to the transcriber matching the chosen image type.
        gr.update(choices=choices, value=choices[0]),
    )
def transcribe_and_update(image, image_type, transcriber, booru_tags):
    """Transcribe ``image`` and pass both the image and its tags to Step 2.

    Returns:
        (image, tags) for the outputs [transcribe_image_display, tags_output].
    """
    tags = transcribe_image(image, image_type, transcriber, booru_tags)
    return image, tags


def transcribe_without_booru_tags(image, image_type, transcriber):
    """Handler for the plain "Transcribe" button: ignores any booru tags.

    The separate "Transcribe with booru tags" button is the one that merges
    fetched tags into the result.
    """
    return transcribe_and_update(image, image_type, transcriber, None)


def _tags_button_update(booru_tags):
    """Show the "Transcribe with booru tags" button only when tags exist."""
    return gr.update(visible=bool(booru_tags))


with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        # Booru controls; on_image_type_change shows this column for "Anime" only.
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        # Holds the tags returned by update_image (None for uploaded images).
        booru_tags = gr.State(None)
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    # Event listeners must be registered inside the Blocks context.
    image_type.change(
        on_image_type_change,
        inputs=[image_type],
        outputs=[anime_options, upload_btn, transcriber],
    )

    # Both ways of loading an image share update_image; afterwards, reveal the
    # "with booru tags" button only if the lookup produced tags.
    get_image_btn.click(
        update_image,
        inputs=[image_type, booru, image_id, upload_btn],
        outputs=[image_display, transcribe_btn, booru_tags],
    ).then(
        _tags_button_update,
        inputs=[booru_tags],
        outputs=[transcribe_with_tags_btn],
    )
    upload_btn.upload(
        update_image,
        inputs=[image_type, booru, image_id, upload_btn],
        outputs=[image_display, transcribe_btn, booru_tags],
    ).then(
        _tags_button_update,
        inputs=[booru_tags],
        outputs=[transcribe_with_tags_btn],
    )

    # Plain "Transcribe" omits booru tags; previously both buttons were identical.
    transcribe_btn.click(
        transcribe_without_booru_tags,
        inputs=[image_display, image_type, transcriber],
        outputs=[transcribe_image_display, tags_output],
    )
    transcribe_with_tags_btn.click(
        transcribe_and_update,
        inputs=[image_display, image_type, transcriber, booru_tags],
        outputs=[transcribe_image_display, tags_output],
    )
    transcribe_btn_final.click(
        transcribe_image,
        inputs=[transcribe_image_display, image_type, transcriber],
        outputs=[tags_output],
    )

if __name__ == "__main__":
    app.launch()