# intelliarts's picture
# Update app.py
# eb80b2b
import os
import numpy as np
import boto3
import streamlit as st
import faiss
import pandas as pd
from PIL import Image
from model_prediction import Ranker
from io import BytesIO
import cv2
@st.cache
def load_model():
    """Create the ``Ranker`` used to embed the uploaded image.

    Wrapped in Streamlit's cache so the object is built once and reused by
    later script runs instead of being constructed on every interaction.
    """
    return Ranker()


# The index and the label table go through the same Streamlit cache as the
# model. Streamlit re-executes the whole script on every interaction, so
# without the decorator both files would be read from disk on each rerun.
# allow_output_mutation=True returns the stored object as-is instead of
# hashing it on every cache hit.
# NOTE(review): hashing a FAISS index object is presumably unsupported or
# slow -- confirm against the deployed Streamlit version.
@st.cache(allow_output_mutation=True)
def load_faiss_index():
    """Read the FAISS index stored in ``embeddings.index``.

    The path is relative to the process working directory.
    """
    return faiss.read_index('embeddings.index')


@st.cache(allow_output_mutation=True)
def load_labels():
    """Read ``labels.csv`` into a pandas DataFrame.

    The path is relative to the process working directory. The caller
    reads the ``path`` column of the returned frame.
    """
    return pd.read_csv("labels.csv")
class ModelLoader:
    """Lazy holder for the model, the FAISS index and the label table.

    Each object is stored as a class attribute and is created by the
    matching module-level ``load_*`` function the first time its getter
    runs; later calls return the stored object.

    NOTE(review): the stored objects live only as long as this class
    object does. If the hosting framework re-executes the script, the
    attributes start again at ``None`` -- confirm this is intended.
    """

    # ``None`` means "not loaded yet"; filled in by the getters below.
    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        """Return the stored model, calling ``load_model()`` if unset."""
        if cls.model is not None:
            return cls.model
        cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        """Return the stored index, calling ``load_faiss_index()`` if unset."""
        if cls.index is not None:
            return cls.index
        cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        """Return the stored labels, calling ``load_labels()`` if unset."""
        if cls.labels is not None:
            return cls.labels
        cls.labels = load_labels()
        return cls.labels
# Size handed to PIL's ``resize`` for every retrieved image before display.
target_size = (224, 224)

# Shared by the browser tab title and the on-page heading.
_APP_TITLE = "Product Retrieval App"

# Introductory copy; contains a raw <a> tag, hence unsafe_allow_html below.
_INTRO_HTML = """The Product Retrieval App is a demonstration of a computer vision model created by <a href="https://intelliarts.com/">Intelliarts</a>. It can analyze and interpret visual data , i.e., shapes, colors, and textures from uploaded digital images. The data is then used to conduct a search on the web. The output of the computer vision model is a set of images that are predicted to be most similar to the input image.
To use the Product Retrieval App, you need to:
1. Select an image that depicts the item of interest. Acceptable formats are JPG, JPEG, and PNG.
2. Upload the image by either dragging and dropping the file into the search field or selecting a file from your computer using the “browse files” button.
3. Scroll to the bottom of the page to review the output results."""

# Page configuration is issued before any other Streamlit element.
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)
st.markdown(_INTRO_HTML, unsafe_allow_html=True)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# Placeholder that later shows, then clears, the "Loading predictions..." text.
loading_text = st.empty()
# No credentials are passed explicitly: boto3 resolves them through its
# default chain (the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment
# variables, the shared credentials file, or an attached role), so secrets
# stay out of the repository.
# SECURITY: an access key pair was previously committed in this file. It
# must be treated as leaked and revoked/rotated in IAM; removing it from
# the source does not remove it from version-control history.
s3 = boto3.client("s3", region_name="eu-west-1")

# Bucket holding the catalogue images, keyed by file name.
bucket_name = "product-retrieval"
# Main flow: runs only once the user has uploaded a file. Shows the upload,
# embeds it, looks up its nearest neighbours in the FAISS index and renders
# the matching catalogue images from S3 in a four-column grid.
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    # Let Pillow normalise the pixel layout instead of guessing from the raw
    # array's channel count: RGBA just loses its alpha channel, while
    # palette ("P") and CMYK uploads are converted to true RGB rather than
    # being read as index maps or as RGBA. RGB and grayscale pass through.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)

    loading_text.text("Loading predictions...")
    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()

    image_embedding = model.predict(image)
    top_k = 12
    distances, indices = index.search(image_embedding, top_k)

    # Only the first query row is used. FAISS pads the id row with -1 when
    # it finds fewer than ``top_k`` neighbours; drop those so the label
    # lookup cannot fail on a non-existent row.
    neighbour_ids = indices[0]
    neighbour_ids = neighbour_ids[neighbour_ids >= 0]
    # FAISS ids are row positions, so index positionally.
    predicted_images = labels["path"].iloc[neighbour_ids].to_list()
    loading_text.empty()

    columns = st.columns(4)
    for i, img_path in enumerate(predicted_images):
        # Objects are stored under the bare file name of the labelled path.
        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
        image_data = response['Body'].read()
        img = Image.open(BytesIO(image_data)).resize(target_size)
        # Fill the grid row by row: result i goes to column i modulo 4.
        with columns[i % len(columns)]:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)