File size: 2,441 Bytes
f1d50b1
7326e2c
e4b9c8b
8ff0261
5dce03a
e4b9c8b
7326e2c
f1d50b1
 
 
 
 
e4b9c8b
f1d50b1
e4b9c8b
2cf3514
 
8b6d3c7
 
fedeff8
 
2cf3514
e4b9c8b
7326e2c
 
 
 
 
 
 
e4b9c8b
 
7326e2c
 
e4b9c8b
7326e2c
 
8ff0261
 
 
 
 
 
 
 
7326e2c
 
 
 
 
e4b9c8b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import io

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def app(model_name):
    """Render the KoCLIP zero-shot image classification demo page.

    The user supplies an image (file upload or URL) and a comma-separated list
    of candidate captions; on "Query" the page shows the image and a bar chart
    of the softmax probability the model assigns to each caption.

    Args:
        model_name: Checkpoint name under the ``koclip/`` namespace, forwarded
            to ``utils.load_model`` (presumably ``"koclip"`` or
            ``"koclip-large"`` as described on the page — TODO confirm).
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        This demonstration explores the capability of KoCLIP in the field of Zero-Shot Prediction. This demo takes an image and a set of candidate captions from the user, and predicts the most likely label among the different captions given.   
        
        KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption annotations. Korean translations of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder. Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.    
        """
    )

    query1 = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    query2 = st.text_input("or a URL to an image...")

    captions = st.text_input(
        "Enter candidate captions in comma-separated form.",
        value="귀여운 고양이,멋있는 강아지,트랜스포머",
    )

    if not st.button("질문 (Query)"):
        return
    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        return

    image = _load_image(query1, query2)
    if image is None:
        # _load_image has already reported the failure to the user.
        return
    st.image(image)

    caption_list = _parse_captions(captions)
    if not caption_list:
        st.error("Please enter at least one non-empty caption.")
        return

    inputs = processor(text=caption_list, images=image, return_tensors="jax", padding=True)
    # Reorder pixel_values axes (0, 2, 3, 1); presumably the processor emits
    # NCHW and the Flax model expects NHWC — TODO confirm against the model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # One row per image (exactly one here), one column per caption.
    probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
    df = pd.DataFrame(np.asarray(probs[0]), index=caption_list)
    st.bar_chart(df)


def _load_image(uploaded_file, url, timeout=10):
    """Open the uploaded file, or else download *url*, as an RGB PIL image.

    Failures (network errors, HTTP error status, undecodable image data) are
    reported with ``st.error`` and ``None`` is returned instead of raising.
    """
    try:
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
        else:
            response = requests.get(url.strip(), timeout=timeout)
            response.raise_for_status()
            # Use the decoded body (handles gzip etc.) rather than the raw socket stream.
            image = Image.open(io.BytesIO(response.content))
        # RGBA/palette/grayscale inputs are normalised to 3-channel RGB for the processor.
        return image.convert("RGB")
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the image: {err}")
        return None


def _parse_captions(text):
    """Split comma-separated *text* into stripped, non-empty, unique captions (order kept).

    Duplicates are dropped because each would take its own share of the
    softmax and then collapse into a single bar on the chart.
    """
    stripped = (part.strip() for part in text.split(","))
    return list(dict.fromkeys(part for part in stripped if part))