import contextlib
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import open_clip
import PIL.Image
import torch
import torch.nn.functional as F


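# Metadata (image ids and prompts) for the example gallery, taken from the
# original LiT demo page.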
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'


@contextlib.contextmanager
def timed(name):
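  """Context manager that logs the wall-clock duration of the wrapped block."""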
  t0 = time.monotonic()
  try:
    yield
  finally:
    logging.info('Timed %s: %.1f secs', name, time.monotonic() - t0)


@functools.cache
def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
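  """Loads the SigLIP model, preprocessing transform, and tokenizer (cached)."""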
  with timed('loading model, preprocess, tokenizer'):
    model, preprocess = open_clip.create_model_from_pretrained(name)
    tokenizer = open_clip.get_tokenizer(name)
  return model, preprocess, tokenizer


def generate_answers(image_path, prompts):
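  """Scores comma-separated `prompts` against the image at `image_path`.

  Returns a list of (prompt, probability) pairs. SigLIP scores every prompt
  independently with a sigmoid, so the probabilities do not sum to one.
  """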
  model, preprocess, tokenizer = load_model()

  with torch.no_grad(), torch.cuda.amp.autocast():
    logging.info('Opening image "%s"', image_path)
    with timed(f'opening image "{image_path}"'):
      image = PIL.Image.open(image_path)
    with timed('image features'):
      image = preprocess(image).unsqueeze(0)
      image_features = model.encode_image(image)

    with timed('text features'):
      # Split on commas and strip whitespace, so "a, b,c" yields three prompts.
      prompts = [p.strip() for p in prompts.split(',')]
      text = tokenizer(prompts, context_length=model.context_length)
      text_features = model.encode_text(text)

    # Unit-normalize so the dot product below is a cosine similarity.
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # SigLIP head: sigmoid of scaled cosine similarities plus a learned bias.
    exp, bias = model.logit_scale.exp(), model.logit_bias
    text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)

  return list(zip(prompts, [round(p.item(), 3) for p in text_probs[0]]))


def create_app():
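  """Builds the Gradio UI, with example images from the LiT demo page."""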
  info = json.load(urllib.request.urlopen(INFO_URL))

  with gr.Blocks() as demo:

    gr.Markdown('Minimal gradio clone of [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/)')
    gr.Markdown('Using `open_clip` implementation of SigLIP model `timm/ViT-SO400M-14-SigLIP-384`')

    with gr.Row():
      image = gr.Image(label='input_image', type='filepath')
      with gr.Column():
        prompts = gr.Textbox(label='prompts')
        answer = gr.Textbox(label='answer')
        run = gr.Button('Run')

    gr.Examples(
        examples=[
            [IMG_URL_FMT.format(ex['id']), ex['prompts']]
            for ex in info
        ],
        inputs=[image, prompts],
        outputs=[answer],
    )

    run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])

  return demo


if __name__ == "__main__":

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')
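
  # Dump the environment to the logs for debugging; note that this also logs
  # any secrets set as environment variables.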
  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  # Load the model once at startup so the first request doesn't block on it.
  _ = load_model()

  create_app().queue().launch()