|
import json
import logging
import time
from io import BytesIO

import numpy as np
import open_clip
import requests
import streamlit as st
import torch
from PIL import Image
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the Marqo fashion SigLIP model, its eval transform, and tokenizer.

    Wrapped in ``st.cache_resource`` so the weights are loaded once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(model, preprocess_val, tokenizer, device)`` tuple. ``model`` has
        been moved to ``device`` (CUDA when available, otherwise CPU) and
        switched to inference (eval) mode.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    # The training-time transform (second return value) is not needed here.
    model, _, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # open_clip returns models in train mode; switch to eval so dropout /
    # stochastic-depth layers are disabled and embeddings are deterministic.
    model.eval()

    return model, preprocess_val, tokenizer, device


model, preprocess_val, tokenizer, device = load_model()
|
|
|
|
|
@st.cache_data
def load_data():
    """Parse the catalogue JSON file ``./musinsa-final.json``.

    The path is resolved against the current working directory. The parsed
    object (presumably a list of product records -- confirm against the file)
    is cached by ``st.cache_data`` so reruns do not re-read the file.
    """
    catalogue_path = './musinsa-final.json'
    with open(catalogue_path, 'r', encoding='utf-8') as catalogue_file:
        catalogue = json.load(catalogue_file)
    return catalogue


data = load_data()
|
|
|
|
|
def load_image_from_url(url, max_retries=3, retry_delay=1.0):
    """Download an image and decode it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        max_retries: Total number of attempts before giving up. Values below
            1 mean no attempt is made and ``None`` is returned.
        retry_delay: Seconds to sleep between consecutive attempts.

    Returns:
        The decoded image converted to RGB mode, or ``None`` when every
        attempt failed (network/HTTP error, or bytes that could not be
        decoded as an image).
    """
    last_error = None
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            # convert() forces a full decode, so truncated/corrupt payloads
            # surface here (as OSError) instead of later in preprocessing.
            return Image.open(BytesIO(response.content)).convert('RGB')
        except (requests.RequestException, OSError) as exc:
            # PIL's UnidentifiedImageError is an OSError subclass, so
            # undecodable payloads are covered by the OSError clause too.
            last_error = exc
            if attempt < max_retries - 1:
                time.sleep(retry_delay)

    if last_error is not None:
        # Keep the failure reason visible instead of silently discarding it.
        logging.getLogger(__name__).warning(
            "Giving up on image %s after %d attempt(s): %s",
            url, max_retries, last_error,
        )
    return None
|
|
|
def get_image_embedding_from_url(image_url):
    """Return the L2-normalised image embedding for the image at ``image_url``.

    Returns ``None`` when the image cannot be downloaded or decoded. Otherwise
    returns a CPU NumPy array whose leading batch dimension has size 1.
    """
    pil_image = load_image_from_url(image_url)
    if pil_image is None:
        return None

    # Add a batch axis of size 1 and place the tensor on the model's device.
    batch = preprocess_val(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(batch)
        # Scale each row to unit length so dot products are cosine similarities.
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy()
|
|
|
@st.cache_data |
|
def process_database(): |
|
database_embeddings = [] |
|
database_info = [] |
|
|
|
for item in data: |
|
image_url = item['์ด๋ฏธ์ง ๋งํฌ'][0] |
|
embedding = get_image_embedding_from_url(image_url) |
|
|
|
if embedding is not None: |
|
database_embeddings.append(embedding) |
|
database_info.append({ |
|
'id': item['\ufeff์ํ ID'], |
|
'category': item['์นดํ
๊ณ ๋ฆฌ'], |
|
'brand': item['๋ธ๋๋๋ช
'], |
|
'name': item['์ ํ๋ช
'], |
|
'price': item['์ ๊ฐ'], |
|
'discount': item['ํ ์ธ์จ'], |
|
'image_url': image_url |
|
}) |
|
else: |
|
st.warning(f"Skipping item {item['๏ปฟ์ํ ID']} due to image loading failure") |
|
|
|
if database_embeddings: |
|
return np.vstack(database_embeddings), database_info |
|
else: |
|
st.error("No valid embeddings were generated.") |
|
return None, None |
|
|
|
database_embeddings, database_info = process_database() |
|
|
|
def get_text_embedding(text):
    """Encode ``text`` into an L2-normalised text embedding.

    The string is tokenized as a batch of one, so the result is a CPU NumPy
    array whose leading batch dimension has size 1.
    """
    tokens = tokenizer([text])
    tokens = tokens.to(device)
    with torch.no_grad():
        features = model.encode_text(tokens)
        # Unit-normalise each row so dot products are cosine similarities.
        row_norms = features.norm(dim=-1, keepdim=True)
        features /= row_norms
    return features.cpu().numpy()
|
|
|
def find_similar_images(query_embedding, top_k=5, embeddings=None, info=None):
    """Rank indexed items by dot-product similarity to ``query_embedding``.

    Args:
        query_embedding: Query vector of shape ``(1, D)`` or ``(D,)``. When it
            and the index rows are L2-normalised, scores are cosine similarities.
        top_k: Maximum number of results; ``<= 0`` returns an empty list.
        embeddings: ``(N, D)`` matrix of item embeddings. Defaults to the
            module-level ``database_embeddings``.
        info: List of N metadata entries aligned with ``embeddings``.
            Defaults to the module-level ``database_info``.

    Returns:
        Up to ``top_k`` dicts ``{'info': ..., 'similarity': float}`` ordered
        from most to least similar; empty when the index is missing or empty.
    """
    if embeddings is None:
        embeddings = database_embeddings
    if info is None:
        info = database_info
    if embeddings is None or len(embeddings) == 0 or top_k <= 0:
        return []

    # ravel() rather than squeeze(): squeeze() collapses a 1-item index to a
    # 0-d array, which argsort and slicing cannot handle.
    similarities = np.dot(embeddings, np.asarray(query_embedding).T).ravel()
    top_indices = np.argsort(similarities)[::-1][:top_k]

    return [
        {'info': info[idx], 'similarity': float(similarities[idx])}
        for idx in top_indices
    ]
|
|
|
|
|
def _render_results(similar_images):
    """Show each search hit as its product image beside its metadata and score.

    Args:
        similar_images: Result dicts from ``find_similar_images``, each with an
            ``'info'`` metadata dict and a ``'similarity'`` score.
    """
    st.subheader("Similar Items:")
    for result in similar_images:
        item = result['info']
        col1, col2 = st.columns(2)
        with col1:
            st.image(item['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {item['name']}")
            st.write(f"Brand: {item['brand']}")
            st.write(f"Category: {item['category']}")
            st.write(f"Price: {item['price']}")
            st.write(f"Discount: {item['discount']}%")
            st.write(f"Similarity: {result['similarity']:.2f}")


st.title("Fashion Search App")

# process_database() yields (None, None) when no catalogue image could be
# embedded (and shows an error); halt the script here rather than letting a
# search fail on the missing index.
if database_embeddings is None:
    st.stop()

search_type = st.radio("Search by:", ("Image URL", "Text"))

if search_type == "Image URL":
    query_image_url = st.text_input("Enter image URL:").strip()
    if st.button("Search by Image"):
        if not query_image_url:
            st.warning("Please enter an image URL.")
        else:
            # NOTE(review): this user-supplied URL is fetched server-side; if the
            # app is exposed publicly, restrict allowed hosts to prevent SSRF.
            query_embedding = get_image_embedding_from_url(query_image_url)
            if query_embedding is None:
                st.error("Failed to process the image. Please try another URL.")
            else:
                similar_images = find_similar_images(query_embedding)
                st.image(query_image_url, caption="Query Image", use_column_width=True)
                _render_results(similar_images)
else:
    query_text = st.text_input("Enter search text:").strip()
    if st.button("Search by Text"):
        if not query_text:
            st.warning("Please enter a search text.")
        else:
            text_embedding = get_text_embedding(query_text)
            _render_results(find_similar_images(text_embedding))