|
from io import BytesIO

import gradio as gr
import numpy as np
import onnxruntime as ort
import requests
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
|
|
|
|
|
# --- Model setup (runs at import time; downloads weights on first use) ---

# Anime tagger: SmilingWolf's WD ConvNeXt v3 multi-label tagger, run via ONNX Runtime.
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")

anime_model = ort.InferenceSession(anime_model_path)

# Photo captioner: Microsoft Florence-2 (ships custom modeling code, hence trust_remote_code).
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)


# Tag vocabulary for the anime tagger: keep only the first CSV column (the tag
# name), skipping the header row. Row order must match the model's output order.
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")

with open(labels_path, 'r') as f:
    anime_labels = [line.strip().split(',')[0] for line in f.readlines()[1:]]
|
|
|
def preprocess_image(image):
    """Prepare a PIL image for the WD v3 tagger ONNX model.

    The SmilingWolf WD v3 taggers expect a float32 NHWC batch of shape
    (1, 448, 448, 3) in BGR channel order with raw 0-255 pixel values.
    The previous version produced a normalized NCHW tensor, which does
    not match the model's declared input.

    Args:
        image: a PIL.Image in any mode.

    Returns:
        np.ndarray of shape (1, 448, 448, 3), dtype float32, BGR, 0-255.
    """
    image = image.convert('RGB')

    # Pad to a white square first so the resize preserves aspect ratio.
    width, height = image.size
    side = max(width, height)
    canvas = Image.new('RGB', (side, side), (255, 255, 255))
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))

    canvas = canvas.resize((448, 448), Image.LANCZOS)

    arr = np.asarray(canvas, dtype=np.float32)
    arr = arr[:, :, ::-1]  # RGB -> BGR, as the tagger was trained on BGR input
    # No /255 normalization: the model takes raw 0-255 values. Add batch dim.
    return np.expand_dims(arr, 0)
|
|
|
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Produce a comma-separated transcription for *image*.

    Anime images go through the WD ONNX tagger (top-50 tags); everything
    else is captioned with Florence-2.

    Args:
        image: PIL image to transcribe.
        image_type: "Anime" selects the tagger; any other value the captioner.
        transcriber: UI selection (currently unused by this function).
        booru_tags: optional list of tags fetched from a booru; merged into
            the tagger output without duplicates.

    Returns:
        A single comma-separated string of tags, or the raw caption.
    """
    if image_type == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]

        # Keep the 50 highest-probability labels, most confident first.
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]

        # Merge booru-provided tags (the "Transcribe with booru tags" button
        # passes them); previously this parameter was silently ignored.
        if booru_tags:
            seen = set(tags)
            tags.extend(t for t in booru_tags if t not in seen)

        return ", ".join(tags)

    prompt = "<MORE_DETAILED_CAPTION>"
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    generated_ids = photo_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Bug fix: the caption is already a string. The old code did
    # ", ".join(generated_text), which inserted ", " between every character.
    return generated_text
|
|
|
|
|
def get_booru_image(booru, image_id):
    """Fetch an image and its tag list from a booru by post ID.

    Args:
        booru: one of "Gelbooru", "Danbooru", "rule34.xxx".
        image_id: the numeric post ID on that site.

    Returns:
        (PIL.Image, list[str]) — the downloaded image and its tags.

    Raises:
        ValueError: unsupported booru name, or no post found for the ID.
        requests.HTTPError: on a non-2xx API or image response.
    """
    if booru == "Gelbooru":
        url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    elif booru == "Danbooru":
        url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    elif booru == "rule34.xxx":
        url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    else:
        raise ValueError("Unsupported booru")

    response = requests.get(url, timeout=30)
    response.raise_for_status()
    data = response.json()

    # Normalize the three APIs' response shapes down to a single post dict:
    # Gelbooru wraps results as {"post": [...]}; rule34 returns a bare list;
    # Danbooru returns the post dict directly.
    if isinstance(data, dict) and 'post' in data:
        data = data['post']
    if isinstance(data, list):
        if not data:
            raise ValueError(f"No post found for ID {image_id}")
        post = data[0]
    else:
        post = data

    image_url = post['file_url']
    # Danbooru names the field 'tag_string'; Gelbooru/rule34 use 'tags'.
    tag_field = post.get('tags', post.get('tag_string', ''))
    tags = tag_field.split()

    img_response = requests.get(image_url, timeout=60)
    img_response.raise_for_status()
    img = Image.open(BytesIO(img_response.content))

    return img, tags
|
|
|
def update_image(image_type, booru, image_id, uploaded_image):
    """Resolve which image to show in Step 1.

    Anime + a real booru selection fetches the post by ID; otherwise an
    uploaded image (if any) is used. Returns the image, a visibility
    update for the transcribe button, and any booru tags.
    """
    wants_booru_fetch = image_type == "Anime" and booru != "Upload"
    if wants_booru_fetch:
        fetched, fetched_tags = get_booru_image(booru, image_id)
        return fetched, gr.update(visible=True), fetched_tags
    if uploaded_image is None:
        # Nothing to show yet; keep the transcribe button hidden.
        return None, gr.update(visible=False), None
    return uploaded_image, gr.update(visible=True), None
|
|
|
def on_image_type_change(image_type):
    """React to the image-type dropdown.

    Shows the booru options column only for anime, always reveals the
    upload button, and reorders the transcriber choices so the matching
    transcriber is listed first.
    """
    is_anime = image_type == "Anime"
    ordered_choices = ["Anime", "Photo/Other"] if is_anime else ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        gr.update(choices=ordered_choices),
    )
|
|
|
# --- UI layout and event wiring ---
with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    # Step 1: choose the source image (booru fetch or upload).
    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")

        # Booru controls are hidden until "Anime" is selected.
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")

        upload_btn = gr.UploadButton("Upload Image", visible=False)

        image_display = gr.Image(label="Image to transcribe", visible=False)
        # Tags returned by the booru API, carried between event handlers.
        booru_tags = gr.State(None)

        transcribe_btn = gr.Button("Transcribe", visible=False)
        # NOTE(review): no handler ever sets this button visible — confirm intent.
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    # Step 2: run the transcriber and show the result.
    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    # Toggle the booru column / upload button and reorder transcriber choices.
    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, booru_tags])

    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, booru_tags])

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        # Run transcription and mirror the source image into the Step 2 tab.
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber, booru_tags],
                         outputs=[transcribe_image_display, tags_output])

    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])

    # Step 2's own button re-runs transcription without booru tags.
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()