Spaces:
Build error
Build error
File size: 3,112 Bytes
8ff0261 5dce03a bf20f73 7326e2c bf20f73 f1d50b1 e4b9c8b f1d50b1 e4b9c8b 2cf3514 83d94a8 8b6d3c7 fedeff8 2cf3514 e4b9c8b df702af 14261e1 df702af 7326e2c 83d94a8 e4b9c8b 83d94a8 7326e2c e4b9c8b 83d94a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import requests
import streamlit as st
from PIL import Image
from utils import load_model
def app(model_name):
model, processor = load_model(f"koclip/{model_name}")
st.title("Zero-shot Image Classification")
st.markdown(
"""
This demonstration explores capability of KoCLIP in the field of Zero-Shot Prediction. This demo takes a set of image and captions from the user, and predicts the most likely label among the different captions given.
KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence). Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder. Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
"""
)
query1 = st.text_input(
"Enter a URL to an image...",
value="http://images.cocodataset.org/val2017/000000039769.jpg",
)
query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
col1, col2 = st.beta_columns([3, 1])
with col2:
captions_count = st.selectbox(
"Number of labels", options=range(1, 6), index=2
)
compute = st.button("Classify")
with col1:
captions = []
defaults = ["κ·μ¬μ΄ κ³ μμ΄", "λ©μλ κ°μμ§", "ν¬λν¬λν νμ€ν°"]
for idx in range(captions_count):
value = defaults[idx] if idx < len(defaults) else ""
captions.append(st.text_input(f"Insert label {idx+1}", value=value))
if compute:
if not any([query1, query2]):
st.error("Please upload an image or paste an image URL.")
else:
st.markdown("""---""")
with st.spinner("Computing..."):
image_data = (
query2 if query2 is not None else requests.get(query1, stream=True).raw
)
image = Image.open(image_data)
# captions = [caption.strip() for caption in captions.split(",")]
captions = [f"μ΄κ²μ {caption.strip()}μ΄λ€." for caption in captions]
inputs = processor(
text=captions, images=image, return_tensors="jax", padding=True
)
inputs["pixel_values"] = jnp.transpose(
inputs["pixel_values"], axes=[0, 2, 3, 1]
)
outputs = model(**inputs)
probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
chart_data = pd.Series(probs[0], index=captions)
col1, col2 = st.beta_columns(2)
with col1:
st.image(image)
with col2:
st.bar_chart(chart_data)
|