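# Streamlit demo: zero-shot image classification with CLIP in 201 languages.
# Class labels can be typed in any NLLB-200 language; they are encoded with the NLLB
# tokenizer and scored against the uploaded image by a quantized ONNX model
# (model-quant.onnx). `lang_map.langs` presumably maps the human-readable language
# names shown in the UI to NLLB language codes.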
from typing import Optional

import onnxruntime
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from lang_map import langs
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

st.set_page_config(layout="wide")

# Language names shown in the selectboxes.
options = list(langs.keys())


class SessionState:
    """Mutable container that survives Streamlit reruns via st.session_state."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def get_state(**kwargs):
    # Create the state object once and reuse it on subsequent reruns.
    if "session_state" not in st.session_state:
        st.session_state["session_state"] = SessionState(**kwargs)
    return st.session_state["session_state"]


def add_selectbox_and_input(key):
    # One row per class: a language selectbox and a free-text label input.
    col1, col2 = st.columns(2)
    with col1:
        select = st.selectbox("Select a language", options, key=f"{key}_select")
    with col2:
        user_input = st.text_input("Input text", key=f"{key}_text")
    state.inputs[key] = (select, user_input)


state = get_state(count=1, inputs={})

st.title("Zero-shot image classification with CLIP in 201 languages")

col1, col2 = st.columns(2)

image: Optional[Image.Image] = None

with col1:
    st.subheader("Image")
    uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image.", use_column_width=True)
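

# Inference: preprocess the image with the CLIP image processor, tokenize each class
# label with the NLLB tokenizer (using the special tokens of its source language), run
# the quantized ONNX model, and plot the resulting class probabilities.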
def process():
    # Load the quantized ONNX model with full graph optimizations enabled.
    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    onnx_path = "model-quant.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path, session_options)

    # CLIP's image processor prepares pixel_values; the NLLB tokenizer encodes the
    # multilingual class labels.
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    ).image_processor
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    image_inputs = processor(images=image, return_tensors="pt")

    # Collect the (language, label) pairs entered by the user and map the language
    # names to NLLB language codes.
    classes = []
    languages = []
    for value in state.inputs.values():
        languages.append(str(value[0]))
        classes.append(str(value[1]))
    languages = [langs[lang] for lang in languages]

    # Tokenize each label with its language-specific special tokens, padding to a
    # fixed length so the encodings can be batched.
    input_ids = []
    attention_mask = []
    for i, _ in enumerate(languages):
        tokenizer.set_src_lang_special_tokens(languages[i])
        encoded = tokenizer.batch_encode_plus(
            [classes[i]],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
        )
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])
    input_ids = torch.concat(input_ids, dim=0)
    attention_mask = torch.concat(attention_mask, dim=0)

    # Run the ONNX model and turn its logits into per-class probabilities.
    ort_inputs = {
        "pixel_values": image_inputs["pixel_values"].numpy(),
        "input_ids": input_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    logits = torch.tensor(ort_outputs[0])
    probabilities = logits.softmax(dim=-1).squeeze().detach().numpy()

    # Show the probabilities as a horizontal bar chart in the right-hand column.
    chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
    chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
    fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
    with col2:
        st.subheader("Predictions")
        st.write(fig)
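

# Right-hand column: the class inputs, an "Add class" button for more rows, and the
# "Generate" button that triggers inference.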
with col2:
    st.subheader("Classes")
    add_selectbox_and_input("Input 1")
    for i in range(2, state.count + 1):
        add_selectbox_and_input(f"Input {i}")
    if st.button("Add class"):
        state.count += 1
        add_selectbox_and_input(f"Input {state.count}")
    st.markdown("""---""")
    if st.button("Generate"):
        with st.spinner("Processing the data"):
            process()