"""Gradio app that produces a comma-separated tag list for an image.

Step 1 selects an image (fetched from a booru or uploaded); step 2 runs a
tagging model over it and shows the resulting tags.
"""

from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

# Seconds to wait for a booru server; without a timeout a stalled server
# would block the UI callback indefinitely.
REQUEST_TIMEOUT_SECONDS = 10

# Upper bound on the number of labels reported for the photo model.
MAX_PHOTO_TAGS = 50

# Initialize models
# NOTE(review): the anime checkpoint name suggests an image tagger rather than
# a diffusion pipeline -- confirm that DiffusionPipeline can load it and that
# calling it on an image yields an iterable of tag strings.
anime_model = DiffusionPipeline.from_pretrained("SmilingWolf/wd-v1-4-vit-tagger")
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(
    "facebook/florence-base-in21k-retrieval"
)
processor = AutoProcessor.from_pretrained("facebook/florence-base-in21k-retrieval")


def get_booru_image(booru: str, image_id: str):
    """Download an image from a booru and return ``(image, tags)``.

    This is a placeholder: the URL scheme is not a real booru API and the
    returned tags are hard-coded. Implement the actual API calls per booru.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
    """
    url = f"https://api.{booru}.org/images/{image_id}"
    response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
    # Fail here with a clear HTTP error instead of letting PIL choke on an
    # HTML error page further down.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    tags = ["tag1", "tag2", "tag3"]  # Placeholder
    return img, tags


def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Return a comma-separated tag string for ``image``.

    Args:
        image: The image to tag; ``None`` yields an empty string.
        image_type: ``"Anime"`` selects the anime model, anything else the
            photo model.
        transcriber: Currently unused.
            NOTE(review): the UI has a separate "Transcriber" dropdown, so
            this was presumably meant to select the model -- confirm.
        booru_tags: Optional extra tags merged into the model's tags.
    """
    if image is None:
        # A transcribe button was clicked before any image was selected.
        return ""

    with torch.no_grad():  # inference only, for both models
        if image_type == "Anime":
            tags = anime_model(image)
        else:
            inputs = processor(images=image, return_tensors="pt")
            outputs = photo_model(**inputs)
            # NOTE(review): assumes the model output exposes `.logits` --
            # confirm for this checkpoint.
            logits = outputs.logits
            # topk raises if asked for more entries than there are classes.
            k = min(MAX_PHOTO_TAGS, logits.shape[-1])
            label_ids = logits.topk(k).indices.squeeze().tolist()
            # The id -> label mapping lives on the model config; the
            # processor has no `config` attribute.
            tags = [photo_model.config.id2label[i] for i in label_ids]

    if booru_tags:
        # De-duplicate while keeping a stable, model-first ordering.
        tags = list(dict.fromkeys([*tags, *booru_tags]))
    return ", ".join(tags)


def update_image(image_type, booru, image_id, uploaded_image):
    """Resolve the image to show and return UI updates for step 1.

    Returns a 3-tuple matching the outputs
    ``[image_display, transcribe_btn, booru_tags]``.
    """
    # Only hit the booru when there is something to fetch; otherwise fall
    # through so an upload still works while "Anime" is selected.
    wants_booru = (
        image_type == "Anime" and booru and booru != "Upload" and image_id
    )
    if wants_booru:
        image, fetched_tags = get_booru_image(booru, image_id)
        return (
            # image_display starts hidden, so it must be revealed explicitly.
            gr.update(value=image, visible=True),
            gr.update(visible=True),
            fetched_tags,
        )
    if uploaded_image is not None:
        return (
            gr.update(value=uploaded_image, visible=True),
            gr.update(visible=True),
            None,
        )
    return gr.update(value=None, visible=False), gr.update(visible=False), None


def on_image_type_change(image_type):
    """Toggle the anime options and reorder the transcriber choices.

    Returns a 3-tuple matching the outputs
    ``[anime_options, upload_btn, transcriber]``.
    """
    if image_type == "Anime":
        return (
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(choices=["Anime", "Photo/Other"]),
        )
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(choices=["Photo/Other", "Anime"]),
    )


def transcribe_and_update(image, image_type, transcriber, booru_tags):
    """Tag ``image`` and return ``(image, tags)`` to populate the step 2 tab."""
    tags = transcribe_image(image, image_type, transcriber, booru_tags)
    return image, tags


with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        booru_tags = gr.State(None)
        transcribe_btn = gr.Button("Transcribe", visible=False)
        # NOTE(review): no handler ever makes this button visible, and it is
        # wired identically to transcribe_btn below -- confirm the intent.
        transcribe_with_tags_btn = gr.Button(
            "Transcribe with booru tags", visible=False
        )

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(
        on_image_type_change,
        inputs=[image_type],
        outputs=[anime_options, upload_btn, transcriber],
    )

    get_image_btn.click(
        update_image,
        inputs=[image_type, booru, image_id, upload_btn],
        outputs=[image_display, transcribe_btn, booru_tags],
    )
    upload_btn.upload(
        update_image,
        inputs=[image_type, booru, image_id, upload_btn],
        outputs=[image_display, transcribe_btn, booru_tags],
    )

    transcribe_btn.click(
        transcribe_and_update,
        inputs=[image_display, image_type, transcriber, booru_tags],
        outputs=[transcribe_image_display, tags_output],
    )
    transcribe_with_tags_btn.click(
        transcribe_and_update,
        inputs=[image_display, image_type, transcriber, booru_tags],
        outputs=[transcribe_image_display, tags_output],
    )
    # booru_tags is not passed here, so transcribe_image uses its default.
    transcribe_btn_final.click(
        transcribe_image,
        inputs=[transcribe_image_display, image_type, transcriber],
        outputs=[tags_output],
    )

app.launch()