# itda / app.py
# Last commit 7d37333 by leedoming: "Update app.py" (6.97 kB).
import base64
import json
import os
import tempfile
import time
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import requests
import streamlit as st
import torch
from inference_sdk import InferenceHTTPClient
from PIL import Image
# Exception type defined for error handling.
class APIError(Exception):
    """Error type for API failures (not raised by the code shown here)."""
# Load model, image preprocessing transform and tokenizer
@st.cache_resource
def load_model():
    """Load the Marqo FashionSigLIP model, cached by ``st.cache_resource``.

    Returns:
        A tuple ``(model, preprocess_val, tokenizer, device)``:
        the model moved to ``device`` and switched to eval mode, the
        validation image transform, the text tokenizer for the same model,
        and the ``torch.device`` in use (CUDA when available, else CPU).
    """
    model_name = "hf-hub:Marqo/marqo-fashionSigLIP"
    # create_model_and_transforms returns (model, train_transform, val_transform).
    # The train transform applies random augmentation, so only the val
    # transform is suitable for retrieval; the tokenizer is fetched separately.
    model, _, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # open_clip models are created in train mode; inference needs eval mode.
    model.eval()
    return model, preprocess_val, tokenizer, device


model, preprocess_val, tokenizer, device = load_model()
# Roboflow client factory
def setup_roboflow_client(api_key):
    """Create an ``InferenceHTTPClient`` for the Roboflow outline endpoint.

    Args:
        api_key: Roboflow API key passed straight to the client.

    Returns:
        A new ``InferenceHTTPClient`` bound to ``https://outline.roboflow.com``.
    """
    endpoint = "https://outline.roboflow.com"
    client = InferenceHTTPClient(api_url=endpoint, api_key=api_key)
    return client
# Streamlit app
st.title("Fashion Search App with Segmentation")
# API Key input
api_key = st.text_input("Enter your Roboflow API Key", type="password")
if api_key:
CLIENT = setup_roboflow_client(api_key)
def segment_image(image_path):
try:
# ์ด๋ฏธ์ง€ ํŒŒ์ผ ์ฝ๊ธฐ
with open(image_path, "rb") as image_file:
image_data = image_file.read()
# ์ด๋ฏธ์ง€๋ฅผ base64๋กœ ์ธ์ฝ”๋”ฉ
encoded_image = base64.b64encode(image_data).decode('utf-8')
# ์›๋ณธ ์ด๋ฏธ์ง€ ๋กœ๋“œ
image = cv2.imread(image_path)
image = cv2.resize(image, (800, 600))
mask = np.zeros(image.shape, dtype=np.uint8)
try:
# Roboflow API ํ˜ธ์ถœ
results = CLIENT.infer(encoded_image, model_id="closet/1")
except Exception as api_error:
st.error(f"API Error: {str(api_error)}")
return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
if 'predictions' in results:
for prediction in results['predictions']:
points = prediction['points']
pts = np.array([[p['x'], p['y']] for p in points], np.int32)
scale_x = image.shape[1] / results['image']['width']
scale_y = image.shape[0] / results['image']['height']
pts = pts * [scale_x, scale_y]
pts = pts.astype(np.int32)
pts = pts.reshape((-1, 1, 2))
cv2.fillPoly(mask, [pts], color=(255, 255, 255)) # White mask
segmented_image = cv2.bitwise_and(image, mask)
else:
st.warning("No predictions found in the image. Returning original image.")
segmented_image = image
return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
except Exception as e:
st.error(f"Error in segmentation: {str(e)}")
# ์›๋ณธ ์ด๋ฏธ์ง€๋ฅผ ๋‹ค์‹œ ์ฝ์–ด ๋ฐ˜ํ™˜
return Image.open(image_path)
def get_image_embedding(image):
image_tensor = preprocess_val(image).unsqueeze(0).to(device)
with torch.no_grad():
image_features = model.encode_image(image_tensor)
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy()
# Load and process data
@st.cache_data
def load_data():
with open('musinsa-final.json', 'r', encoding='utf-8') as f:
return json.load(f)
data = load_data()
# Process database with segmentation
@st.cache_data
def process_database():
database_embeddings = []
database_info = []
for item in data:
image_url = item['์ด๋ฏธ์ง€ ๋งํฌ'][0]
# '\ufeff์ƒํ’ˆ ID' ๋Œ€์‹  '์ƒํ’ˆ ID'๋ฅผ ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜, ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ˆ˜์ •
product_id = item.get('\ufeff์ƒํ’ˆ ID') or item.get('์ƒํ’ˆ ID')
image_path = "temp_{}.jpg".format(product_id)
response = requests.get(image_url)
with open(image_path, 'wb') as f:
f.write(response.content)
segmented_image = segment_image(image_path)
embedding = get_image_embedding(segmented_image)
database_embeddings.append(embedding)
database_info.append({
'id': product_id,
'category': item['์นดํ…Œ๊ณ ๋ฆฌ'],
'brand': item['๋ธŒ๋žœ๋“œ๋ช…'],
'name': item['์ œํ’ˆ๋ช…'],
'price': item['์ •๊ฐ€'],
'discount': item['ํ• ์ธ์œจ'],
'image_url': image_url
})
return np.vstack(database_embeddings), database_info
database_embeddings, database_info = process_database()
def find_similar_images(query_embedding, top_k=5):
similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
top_indices = np.argsort(similarities)[::-1][:top_k]
results = []
for idx in top_indices:
results.append({
'info': database_info[idx],
'similarity': similarities[idx]
})
return results
uploaded_file = st.file_uploader("Choose an image...", type="jpg")
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption='Uploaded Image', use_column_width=True)
if st.button('Find Similar Items'):
with st.spinner('Processing...'):
# Save uploaded image temporarily
temp_path = "temp_upload.jpg"
image.save(temp_path)
# Segment the uploaded image
segmented_image = segment_image(temp_path)
st.image(segmented_image, caption='Segmented Image', use_column_width=True)
# Get embedding for segmented image
query_embedding = get_image_embedding(segmented_image)
similar_images = find_similar_images(query_embedding)
st.subheader("Similar Items:")
for img in similar_images:
col1, col2 = st.columns(2)
with col1:
st.image(img['info']['image_url'], use_column_width=True)
with col2:
st.write(f"Name: {img['info']['name']}")
st.write(f"Brand: {img['info']['brand']}")
st.write(f"Category: {img['info']['category']}")
st.write(f"Price: {img['info']['price']}")
st.write(f"Discount: {img['info']['discount']}%")
st.write(f"Similarity: {img['similarity']:.2f}")
else:
st.warning("Please enter your Roboflow API Key to use the app.")