|
import streamlit as st |
|
import open_clip |
|
import torch |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
import time |
|
import numpy as np |
|
from ultralytics import YOLO |
|
import chromadb |
|
from transformers import pipeline |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
@st.cache_resource
def load_segmenter():
    """Build the clothes-segmentation pipeline once per server process.

    Streamlit re-executes this whole script on every user interaction, so a
    bare module-level ``pipeline(...)`` call would reload the model weights
    each time. ``st.cache_resource`` keeps one shared instance instead.
    """
    return pipeline(model="mattmdjaga/segformer_b2_clothes")


segmenter = load_segmenter()
|
|
|
|
|
@st.cache_resource
def load_clip_model():
    """Load the Marqo FashionSigLIP model with its inference transform and tokenizer.

    Cached with ``st.cache_resource`` so the weights are loaded once per
    server process rather than on every Streamlit rerun.

    Returns:
        Tuple ``(model, preprocess_val, tokenizer, device)``:
        - the open_clip model, moved to ``device`` and set to eval mode;
        - the validation (inference) image transform;
        - the matching text tokenizer;
        - ``torch.device("cuda")`` when available, otherwise CPU.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    # The training transform (second element) is not needed for retrieval.
    model, _, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # open_clip hands the model back in train mode; switch to eval so
    # dropout / normalisation layers behave deterministically at inference.
    model.eval()

    return model, preprocess_val, tokenizer, device
|
|
|
clip_model, preprocess_val, tokenizer, device = load_clip_model()


@st.cache_resource
def load_collection(path="./clothesDB_202410_2", name="clothes"):
    """Open the persistent ChromaDB store and return ``(client, collection)``.

    Cached with ``st.cache_resource`` so the database connection is opened
    once per server process instead of on every Streamlit rerun.

    Args:
        path: Directory of the persistent ChromaDB store.
        name: Name of the collection to open.

    Raises:
        Whatever ``client.get_collection`` raises when ``name`` does not
        exist — this function never creates the collection (presumably it is
        populated by a separate indexing step; not visible here).
    """
    db_client = chromadb.PersistentClient(path=path)
    return db_client, db_client.get_collection(name=name)


client, collection = load_collection()
|
|
|
|
|
|
|
|
|
def load_image_from_url(url, max_retries=3):
    """Download ``url`` and return it decoded as an RGB ``PIL.Image``.

    Network failures are retried up to ``max_retries`` attempts in total,
    sleeping one second between attempts. Failures that cannot change on a
    retry give up immediately:
    - HTTP 4xx responses other than 429;
    - payloads Pillow cannot decode (unidentified, truncated or
      decompression-bomb images).

    Args:
        url: Address of the image to fetch.
        max_retries: Total number of download attempts; values below 1 mean
            no attempt is made.

    Returns:
        The decoded RGB image, or ``None`` when it could not be fetched or
        decoded.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as err:
            # ``err.response`` is None for connection errors/timeouts.
            status = getattr(err.response, "status_code", None)
            # A client error (bad URL, forbidden, not found) will not change
            # on retry; 429 "Too Many Requests" is worth waiting out.
            if status is not None and 400 <= status < 500 and status != 429:
                return None
            if attempt < max_retries - 1:
                time.sleep(1)
            continue

        # Decoding is deterministic, so a decode failure is final. Pillow
        # decodes lazily: truncated data surfaces as a plain OSError from
        # ``convert`` (UnidentifiedImageError is an OSError subclass), and
        # oversized images raise DecompressionBombError.
        try:
            return Image.open(BytesIO(response.content)).convert('RGB')
        except (OSError, Image.DecompressionBombError):
            return None

    return None
|
|
|
def get_image_embedding(image):
    """Encode ``image`` into a unit-length 1-D numpy vector.

    Relies on the module-level ``preprocess_val``, ``clip_model`` and
    ``device`` produced by ``load_clip_model``.
    """
    batch = preprocess_val(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = clip_model.encode_image(batch)
        # L2-normalise along the feature axis so dot products are cosines.
        features = features / features.norm(dim=-1, keepdim=True)

    embedding = features.cpu().numpy()
    return embedding.flatten()
|
|
|
def segment_clothing(
    img,
    clothes=("Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt",
             "Left-shoe", "Right-shoe", "Scarf"),
    background=(255, 255, 255),
):
    """Keep only the requested clothing regions of ``img``.

    Runs the module-level ``segmenter`` pipeline and keeps every segment
    whose label is in ``clothes``. Their masks are merged, and each pixel
    outside the merged mask is painted with ``background``.

    Args:
        img: Input ``PIL.Image``.
        clothes: Segment labels to keep (any container supporting ``in``).
            A tuple is used as the default to avoid a shared mutable default.
        background: RGB fill colour for pixels outside the kept segments.

    Returns:
        Tuple ``(segmented_image, final_mask, detected_categories)``:
        - ``segmented_image``: RGB image showing only the kept clothing;
        - ``final_mask``: mode-"L" image, 255 inside the kept segments and
          0 elsewhere;
        - ``detected_categories``: labels of the kept segments, in
          segmenter order.
    """
    segments = segmenter(img)

    kept_masks = []
    detected_categories = []
    for segment in segments:
        if segment['label'] in clothes:
            # assumes each mask is a single-channel image the size of img,
            # non-zero inside the segment — TODO confirm for other models
            kept_masks.append(np.asarray(segment['mask']) > 0)
            detected_categories.append(segment['label'])

    combined = np.zeros((img.height, img.width), dtype=bool)
    for mask in kept_masks:
        combined |= mask

    # Build 0/255 from booleans. Multiplying an already-0/255 uint8 mask by
    # 255 would wrap around to 1.
    final_mask = Image.fromarray(combined.astype(np.uint8) * 255)

    # RGBA -> RGB conversion simply drops alpha, which would discard the
    # segmentation; composite onto a solid background instead.
    rgb = img.convert("RGB")
    backdrop = Image.new("RGB", rgb.size, background)
    segmented_image = Image.composite(rgb, backdrop, final_mask)

    return segmented_image, final_mask, detected_categories
|
|
|
def find_similar_images(query_embedding, collection, top_k=5):
    """Return the ``top_k`` catalogue entries nearest to ``query_embedding``.

    Args:
        query_embedding: One embedding vector (numpy array or any array-like
            that reshapes to a single row).
        collection: ChromaDB collection (anything exposing a compatible
            ``query`` method).
        top_k: Number of neighbours to request.

    Returns:
        List of ``{'info': metadata, 'similarity': 1 - distance}`` dicts, in
        the order returned by the collection.
    """
    # chromadb releases before 0.5 reject numpy arrays for query_embeddings;
    # a plain nested list (one row) is accepted by every version.
    query = np.asarray(query_embedding).reshape(1, -1).tolist()

    results = collection.query(
        query_embeddings=query,
        n_results=top_k,
        include=['metadatas', 'distances'],
    )

    # Results are batched per query row; only one row was sent.
    top_metadatas = results['metadatas'][0]
    top_distances = results['distances'][0]

    # NOTE(review): ``1 - distance`` is cosine similarity only if the
    # collection was created with the "cosine" space — confirm hnsw:space.
    return [
        {'info': metadata, 'similarity': 1 - distance}
        for metadata, distance in zip(top_metadatas, top_distances)
    ]
|
|
|
|
|
|
|
|
|
|
|
# Seed per-session UI state on the first run only. Streamlit preserves
# st.session_state across reruns, so existing values are never overwritten.
_SESSION_DEFAULTS = {
    'step': 'input',
    'query_image_url': '',
    'detections': [],
    'segmented_image': None,
    'selected_category': None,
}

for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
|
|
|
st.title("Advanced Fashion Search App")


def _go_to(step):
    """Switch the wizard to ``step`` and rerun the script immediately.

    Streamlit picks the branch to render at the top of each run. Without an
    immediate rerun, the new step would only appear after the user's next
    interaction.
    """
    st.session_state.step = step
    st.rerun()


def _reset_search():
    """Clear every trace of the previous query and return to URL input."""
    st.session_state.query_image_url = ''
    st.session_state.detections = []
    st.session_state.segmented_image = None
    st.session_state.selected_category = None
    if 'query_image' in st.session_state:
        del st.session_state['query_image']
    _go_to('input')


def _render_input_step():
    """Ask for an image URL, segment it, and advance once clothing is found."""
    st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
    if not st.button("Detect Clothing"):
        return

    if not st.session_state.query_image_url:
        st.warning("Please enter an image URL.")
        return

    query_image = load_image_from_url(st.session_state.query_image_url)
    if query_image is None:
        st.error("Failed to load the image. Please try another URL.")
        return

    st.session_state.query_image = query_image
    segmented_image, _final_mask, detected_categories = segment_clothing(query_image)
    st.session_state.segmented_image = segmented_image
    st.session_state.detections = detected_categories

    if not detected_categories:
        st.image(segmented_image, caption="Segmented Image", use_column_width=True)
        st.warning("No clothing items detected in the image.")
        return

    _go_to('select_category')


def _render_category_step():
    """Let the user choose which detected garment to search for."""
    st.image(st.session_state.segmented_image, caption="Segmented Image with Detected Categories", use_column_width=True)
    st.subheader("Detected Clothing Categories:")

    if not st.session_state.detections:
        st.warning("No categories detected.")
        return

    selected_category = st.selectbox("Select a category to search:", st.session_state.detections)
    if st.button("Search Similar Items"):
        st.session_state.selected_category = selected_category
        _go_to('show_results')


def _render_results_step():
    """Embed the chosen garment and list the most similar catalogue items."""
    query_image = st.session_state.query_image
    st.image(query_image.convert("RGB"), caption="Original Image", use_column_width=True)

    # The stored segmented image contains every detected garment, so
    # re-segment for just the category the user picked.
    category = st.session_state.selected_category
    if category:
        search_image, _, _ = segment_clothing(query_image, clothes=[category])
    else:
        search_image = st.session_state.segmented_image

    query_embedding = get_image_embedding(search_image)
    similar_images = find_similar_images(query_embedding, collection)

    st.subheader("Similar Items:")
    for item in similar_images:
        info = item['info']
        col1, col2 = st.columns(2)
        with col1:
            st.image(info['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {info['name']}")
            st.write(f"Brand: {info['brand']}")
            item_category = info.get('category')
            if item_category:
                st.write(f"Category: {item_category}")
            st.write(f"Price: {info['price']}")
            st.write(f"Discount: {info['discount']}%")
            st.write(f"Similarity: {item['similarity']:.2f}")

    if st.button("Start New Search"):
        _reset_search()


# Unknown step values render nothing, matching the original if/elif chain.
_STEP_RENDERERS = {
    'input': _render_input_step,
    'select_category': _render_category_step,
    'show_results': _render_results_step,
}

_render_current_step = _STEP_RENDERERS.get(st.session_state.step)
if _render_current_step is not None:
    _render_current_step()