import os
import jax
import jax.numpy as jnp
import numpy as np
import requests
import streamlit as st
from PIL import Image
from utils import load_model
def split_image(im, num_rows=3, num_cols=3):
    """Split an image into a num_rows x num_cols grid of equally sized tiles."""
    im = np.array(im)
    row_size = im.shape[0] // num_rows
    col_size = im.shape[1] // num_cols
    tiles = [
        im[x : x + row_size, y : y + col_size]
        for x in range(0, num_rows * row_size, row_size)
        for y in range(0, num_cols * col_size, col_size)
    ]
    return tiles
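
# Worked example (illustrative numbers, not taken from the app): a 900x600 RGB
# image with num_rows=num_cols=3 gives row_size=300 and col_size=200, so x runs
# over (0, 300, 600), y over (0, 200, 400), and split_image returns nine
# 300x200x3 tiles in row-major order; any remainder pixels on the right and
# bottom edges are simply dropped.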

# def split_image(X):
#     num_rows = X.shape[0] // 224
#     num_cols = X.shape[1] // 224
#     Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
#     patches = []
#     for j in range(num_rows):
#         for i in range(num_cols):
#             patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
#     return patches

def app(model_name):
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Most Relevant Part of Image")
    st.markdown(
        """
Given a piece of text, the CLIP model finds the part of an image that best explains the text.

To try it out, you can

1) Upload an image
2) Describe a part of the image in text

The app will then show the most relevant tile from a 3x3 grid of the image.
"""
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="https://img.sbs.co.kr/newimg/news/20200823/201463830_1280.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])
    captions = st.text_input(
        "Enter a query to find the most relevant part of the image",
        # Default: "This is a picture of Gyeongbokgung Palace in Seoul."
        value="이건 서울의 경복궁 사진이다.",
    )

    if st.button("질문 (Query)"):
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            # Prefer the uploaded file; otherwise stream the image from the URL.
            image_data = (
                query2 if query2 is not None else requests.get(query1, stream=True).raw
            )
            image = Image.open(image_data)
            st.image(image)

            # Split the image into a 3x3 grid and score every tile against the caption.
            images = split_image(image)
            inputs = processor(
                text=captions, images=images, return_tensors="jax", padding=True
            )
            # The processor returns pixel values in NCHW order; the Flax vision model
            # expects NHWC, so move the channel axis to the end.
            inputs["pixel_values"] = jnp.transpose(
                inputs["pixel_values"], axes=[0, 2, 3, 1]
            )
            outputs = model(**inputs)
            # Softmax over axis 0 (the tiles) turns the image-text logits into a
            # probability distribution across the nine tiles for this single caption.
            probs = jax.nn.softmax(outputs.logits_per_image, axis=0)

            # Show the tiles from most to least relevant, with their scores.
            for idx, prob in sorted(enumerate(probs), key=lambda x: x[1], reverse=True):
                st.text(f"Score: {prob[0]:.3f}")
                st.image(images[idx])
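
# Minimal, hedged sanity check: this block is an illustration and is not part of
# the original Space. It only exercises split_image on a dummy array, so no model
# download is needed (it still assumes the local utils module and its dependencies
# are importable). In the actual Space, app() is presumably invoked from a separate
# launcher script with a concrete model name such as "koclip-base" (an assumption,
# not confirmed by this file).
if __name__ == "__main__":
    dummy = np.zeros((900, 600, 3), dtype=np.uint8)
    tiles = split_image(dummy, num_rows=3, num_cols=3)
    assert len(tiles) == 9
    assert tiles[0].shape == (300, 200, 3)
    print(f"split_image produced {len(tiles)} tiles of shape {tiles[0].shape}")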