Spaces:
Running
Running
import os | |
import numpy as np | |
import boto3 | |
import streamlit as st | |
import faiss | |
import pandas as pd | |
from PIL import Image | |
from model_prediction import Ranker | |
from io import BytesIO | |
import cv2 | |
def load_model():
    """Construct a new ``Ranker`` and hand it back to the caller."""
    ranker = Ranker()
    return ranker
def load_faiss_index():
    """Read the FAISS index stored in ``embeddings.index`` and return it.

    The path is relative, so it is resolved against the current working
    directory.
    """
    index_path = 'embeddings.index'
    return faiss.read_index(index_path)
def load_labels():
    """Parse ``labels.csv`` (relative to the working directory) into a DataFrame."""
    labels_frame = pd.read_csv("labels.csv")
    return labels_frame
class ModelLoader:
    """Lazily create and memoise the model, the FAISS index and the labels.

    Each resource is built the first time its getter is called and is then
    stored on the class, so later calls return the same object.

    The getters are class methods: they are invoked directly on the class
    (``ModelLoader.get_model()``), never on an instance.

    NOTE(review): the cache lives in class attributes, so it only survives
    as long as this class object does. If the hosting framework re-executes
    the module on every interaction, the class is rebuilt and the cache is
    reset -- confirm whether a framework-level resource cache is needed.
    """

    # Cached resources; ``None`` means "not loaded yet".
    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        """Return the cached model, calling ``load_model()`` on first use."""
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        """Return the cached index, calling ``load_faiss_index()`` on first use."""
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        """Return the cached labels, calling ``load_labels()`` on first use."""
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels
# (width, height) that every result image is resized to before display.
target_size = (224, 224)

# Static page text, kept together so the layout calls below stay short.
_APP_TITLE = "Product Retrieval App"
_ACCEPTED_TYPES = ["jpg", "jpeg", "png"]
_INTRO_HTML = """The Product Retrieval App is a demonstration of a computer vision model created by <a href="https://intelliarts.com/">Intelliarts</a>. It can analyze and interpret visual data , i.e., shapes, colors, and textures from uploaded digital images. The data is then used to conduct a search on the web. The output of the computer vision model is a set of images that are predicted to be most similar to the input image.
To use the Product Retrieval App, you need to:
1. Select an image that depicts the item of interest. Acceptable formats are JPG, JPEG, and PNG.
2. Upload the image by either dragging and dropping the file into the search field or selecting a file from your computer using the “browse files” button.
3. Scroll to the bottom of the page to review the output results."""

# Page configuration comes first, then the header and the intro text.
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)
# The intro contains a raw <a> tag, hence unsafe_allow_html.
st.markdown(_INTRO_HTML, unsafe_allow_html=True)

# Upload widget plus an empty placeholder used later for a status message.
uploaded_file = st.file_uploader("Choose an image...", type=_ACCEPTED_TYPES)
loading_text = st.empty()
# SECURITY: AWS keys must never be embedded in source code. No explicit
# credentials are passed here, so boto3 resolves them through its default
# provider chain (the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment
# variables, the shared credentials file, or an attached IAM role).
# Any key pair that was previously committed in this file must be treated as
# compromised: revoke it and issue a new one through the deployment's secret
# store.
s3 = boto3.client('s3', region_name='eu-west-1')

# Bucket that holds the catalogue images returned as search results.
bucket_name = "product-retrieval"
if uploaded_file is not None:
    # --- 1. Decode the upload into an image for display and prediction ---
    uploaded_image = Image.open(uploaded_file)
    if uploaded_image.mode in ("P", "CMYK"):
        # Palette images would otherwise turn into an array of palette
        # indices, and 4-channel CMYK data would be mistaken for RGBA by the
        # channel-count check below.
        uploaded_image = uploaded_image.convert("RGB")
    image = np.asarray(uploaded_image)
    if image.ndim > 2 and image.shape[2] == 4:
        # Drop the 4th (alpha) channel; this conversion code keeps the order
        # of the remaining three channels as they are.
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    image = Image.fromarray(image)
    st.image(image, caption="Uploaded image", use_column_width=True)

    # --- 2. Embed the image and query the index for the 12 nearest items ---
    loading_text.text("Loading predictions...")
    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()
    image_embedding = model.predict(image)
    _distances, indices = index.search(image_embedding, 12)
    # FAISS fills missing neighbours with -1 when the index holds fewer
    # vectors than requested; those ids have no row in the label table.
    neighbour_ids = indices[0]
    neighbour_ids = neighbour_ids[neighbour_ids >= 0]
    predicted_images = labels["path"][neighbour_ids].to_list()
    loading_text.empty()

    # --- 3. Fetch each hit from S3 and lay the results out in 4 columns ---
    columns = st.columns(4)
    for i, img_path in enumerate(predicted_images):
        # Objects are keyed by file name only, so strip any directory part.
        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
        image_data = response['Body'].read()
        img = Image.open(BytesIO(image_data)).resize(target_size)
        # Fill the grid left to right, wrapping to the first column.
        with columns[i % len(columns)]:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)