koclip / image2text.py
jaketae's picture
feature: replace comma separated input w/ counter ui
83d94a8
raw
history blame
3.11 kB
from io import BytesIO

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model
# Upper bound (seconds) on how long the page waits for a remote image.
_REQUEST_TIMEOUT_SECONDS = 10


def _load_image(url, uploaded_file):
    """Return the user's image as an RGB PIL image.

    Args:
        url: Image URL to download when no file was uploaded.
        uploaded_file: File-like object from the upload widget, or ``None``.
            When present it takes precedence over ``url``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.RequestException: The download failed, timed out, or the
            server answered with an error status.
        OSError: The data could not be decoded as an image.
    """
    if uploaded_file is not None:
        source = uploaded_file
    else:
        response = requests.get(url, timeout=_REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Buffer the body so PIL reads fully decoded bytes instead of the raw
        # (possibly content-encoded) socket stream.
        source = BytesIO(response.content)
    # Uploads may be palette, greyscale or RGBA images.
    # NOTE(review): assumes the processor expects three-channel RGB input.
    return Image.open(source).convert("RGB")


def app(model_name):
    """Render the zero-shot image classification demo page.

    The page collects an image (URL or upload) and up to five candidate
    labels. When "Classify" is pressed, each label is wrapped in a Korean
    prompt sentence, image-to-text logits are computed with the loaded model,
    and the softmax probabilities are shown as a bar chart next to the image.

    Args:
        model_name: Checkpoint name; it is appended to ``koclip/`` and passed
            to ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
This demonstration explores capability of KoCLIP in the field of Zero-Shot Prediction. This demo takes a set of image and captions from the user, and predicts the most likely label among the different captions given.
KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder. Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
"""
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    # Prefer the stable layout API; older Streamlit releases only expose the
    # beta-prefixed name.
    make_columns = getattr(st, "columns", None) or st.beta_columns

    col1, col2 = make_columns([3, 1])
    with col2:
        captions_count = st.selectbox(
            "Number of labels", options=range(1, 6), index=2
        )
        compute = st.button("Classify")
    with col1:
        captions = []
        defaults = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]
        for idx in range(captions_count):
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert label {idx + 1}", value=value))

    if not compute:
        return

    url = query1.strip()
    if not (url or query2):
        st.error("Please upload an image or paste an image URL.")
        return

    # Blank inputs (the default for labels beyond the presets) would turn into
    # prompts with no subject, so drop them before building the prompts.
    labels = [caption.strip() for caption in captions if caption.strip()]
    if not labels:
        st.error("Please enter at least one label.")
        return

    st.markdown("""---""")
    with st.spinner("Computing..."):
        try:
            image = _load_image(url, query2)
        except (requests.RequestException, OSError) as err:
            st.error(f"Could not load the image: {err}")
            return

        # Wrap each label in a full sentence ("This is <label>.").
        prompts = [f"이것은 {label}이다." for label in labels]
        inputs = processor(
            text=prompts, images=image, return_tensors="jax", padding=True
        )
        # Move axis 1 to the end; presumably channels-first -> channels-last
        # as expected by the model — TODO confirm.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)
        probs = jax.nn.softmax(outputs.logits_per_image, axis=1)

    # Hand pandas a plain NumPy array rather than a JAX device array.
    chart_data = pd.Series(np.asarray(probs[0]), index=prompts)

    col1, col2 = make_columns(2)
    with col1:
        st.image(image)
    with col2:
        st.bar_chart(chart_data)