import contextlib
import functools
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import open_clip  # works on open-clip-torch>=2.23.0, timm>=0.9.8
import PIL.Image
import torch
import torch.nn.functional as F


# Example images (referenced by id) and their suggested prompts, as used by the LiT demo page.
INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'


@contextlib.contextmanager
def timed(name):
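  """Context manager that logs the wall-clock duration of the wrapped block."""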
  t0 = time.monotonic()
  try:
    yield
  finally:
    logging.info('Timed %s: %.1f secs', name, time.monotonic() - t0)


@functools.cache
def load_model(name='hf-hub:timm/ViT-SO400M-14-SigLIP-384'):
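  """Loads the SigLIP model, preprocessing transform and tokenizer from the hub.

  `functools.cache` turns this into a per-process singleton, so the checkpoint
  is downloaded and instantiated only once.
  """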
  with timed('loading model, preprocess, tokenizer'):
    model, preprocess = open_clip.create_model_from_pretrained(name)
    tokenizer = open_clip.get_tokenizer(name)
    return model, preprocess, tokenizer


def generate_answers(image_path, prompts):
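  """Scores each prompt in the comma-separated `prompts` string against the image.

  Returns a list of (prompt, probability) pairs.
  """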

  model, preprocess, tokenizer = load_model()

  # Inference only: no gradients; autocast runs in mixed precision on CUDA and
  # is disabled (with a warning) when no GPU is available.
  with torch.no_grad(), torch.cuda.amp.autocast():
    logging.info('Opening image "%s"', image_path)
    with timed(f'opening image "{image_path}"'):
      image = PIL.Image.open(image_path)
    with timed('image features'):
      image = preprocess(image).unsqueeze(0)
      image_features = model.encode_image(image)

    with timed('text features'):
      # The textbox holds a single comma-separated string of prompts.
      prompts = prompts.split(', ')
      text = tokenizer(prompts, context_length=model.context_length)
      text_features = model.encode_text(text)
      image_features = F.normalize(image_features, dim=-1)
      text_features = F.normalize(text_features, dim=-1)

    # SigLIP scoring: a sigmoid over scaled and biased cosine similarities, so
    # each prompt gets an independent probability (no softmax across prompts).
    exp, bias = model.logit_scale.exp(), model.logit_bias
    text_probs = torch.sigmoid(image_features @ text_features.T * exp + bias)
    return list(zip(prompts, [round(p.item(), 3) for p in text_probs[0]]))


def create_app():
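  """Builds the Gradio UI: an image and prompts as inputs, probabilities as the answer."""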
  info = json.load(urllib.request.urlopen(INFO_URL))

  with gr.Blocks() as demo:

    gr.Markdown('A minimal Gradio clone of the [lit-tuning-demo](https://google-research.github.io/vision_transformer/lit/).')
    gr.Markdown('Uses the `open_clip` implementation of the SigLIP model `timm/ViT-SO400M-14-SigLIP-384`.')

    with gr.Row():
      image = gr.Image(label='input_image', type='filepath')
      with gr.Column():
        prompts = gr.Textbox(label='prompts')
        answer = gr.Textbox(label='answer')
        run = gr.Button('Run')

    # Pre-fill the demo with the example images and prompts listed in info.json.
    gr.Examples(
        examples=[
            [IMG_URL_FMT.format(ex['id']), ex['prompts']]
            for ex in info
        ],
        inputs=[image, prompts],
        outputs=[answer],
    )

    run.click(fn=generate_answers, inputs=[image, prompts], outputs=[answer])

  return demo


if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  # Log all environment variables, which helps when debugging the deployment.
  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  # Warm the load_model() cache so the first request does not pay the loading cost.
  _ = load_model()

  create_app().queue().launch()