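"""Gradio demo for Vocabulary-free Image Classification."""
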
import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from torchvision.transforms.functional import to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"
MARKDOWN_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<h1>Vocabulary-free Image Classification</h1>
</div>
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
<a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
</a>
<a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
</a>
<a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
</a>
<a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;">
<img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
</a>
</div>
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def prepare_image(image: gr.Image):
    """Preprocess the uploaded image for display and keep a copy of the original."""
    if image is None:
        return None, None

    # Temporarily disable normalization so the preprocessed tensor can be displayed.
    PROCESSOR.image_processor.do_normalize = False
    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    PROCESSOR.image_processor.do_normalize = True

    image_tensor = image_tensor.pixel_values[0]
    curr_image = to_pil_image(image_tensor)

    return curr_image, image.copy()


def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Classify a single image and return a mapping from class name to score."""
    if image is None:
        return None

    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)

    # Forward pass: the model returns a candidate vocabulary and a score per candidate.
    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))

    return confidences
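

# Build the Gradio interface.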
with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(label="input", type="pil", height=300)
            orig_image = gr.Image(
                label="orig. image", type="pil", visible=False, interactive=False
            )
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")

            with gr.Row():
                clear_button = gr.ClearButton([curr_image, orig_image])
                run_button = gr.Button(value="Submit", variant="primary")

        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)
            examples = gr.Examples(
                examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
                inputs=[orig_image],
                outputs=[output_label],
                fn=image_inference,
                cache_examples=True,
            )

    gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")

    # Wire up the events: preprocess on upload/change, clear the hidden original, run on submit.
    curr_image.upload(prepare_image, [curr_image], [curr_image, orig_image])
    curr_image.clear(lambda: None, [], [orig_image])
    orig_image.change(prepare_image, [orig_image], [curr_image, orig_image])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])
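

# Enable request queueing and serve the demo on all interfaces (port 7860).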
if __name__ == "__main__":
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)