# JoJosmin's picture
# Update app.py
# 4e8b6e2 verified
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
# Logging configuration: INFO level on the root logger, plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CustomFashionEmbeddingFunction:
    """Embedding function that blends CLIP image features with a colour histogram.

    Each input image is encoded with the ``Marqo/marqo-fashionSigLIP`` model
    loaded through ``open_clip``. The resulting vector is padded/truncated to
    ``EMBEDDING_DIM``, L2-normalised and mixed (``CLIP_WEIGHT`` /
    ``COLOR_WEIGHT``) with an HSV colour histogram stretched to the same
    length. The mixed vector is L2-normalised again before it is returned.
    """

    EMBEDDING_DIM = 768      # length of every vector returned by __call__
    HISTOGRAM_BINS = 8       # bins per HSV channel -> 3 * 8 = 24 colour features
    CLIP_WEIGHT = 0.7        # share of the CLIP embedding in the mixed vector
    COLOR_WEIGHT = 0.3       # share of the colour histogram in the mixed vector
    REQUEST_TIMEOUT = 10     # seconds to wait when downloading an image by URL

    def __init__(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def _load_image(self, source):
        """Return *source* as an RGB PIL image.

        Strings are fetched over HTTP as URLs, ``bytes`` are decoded as an
        encoded image file and NumPy arrays are cast to ``uint8``. Any other
        object is returned unchanged.

        Raises:
            requests.RequestException: the download failed, timed out or
                returned an HTTP error status.
        """
        if isinstance(source, str):
            response = requests.get(source, timeout=self.REQUEST_TIMEOUT)
            # Fail here rather than handing an HTML error page to PIL.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        if isinstance(source, bytes):
            return Image.open(io.BytesIO(source)).convert('RGB')
        if isinstance(source, np.ndarray):
            return Image.fromarray(source.astype('uint8')).convert('RGB')
        return source

    def _combine(self, clip_emb, color_feat):
        """Mix one CLIP embedding with one colour histogram into a unit vector."""
        dim = self.EMBEDDING_DIM

        # Pad with zeros or truncate so the CLIP part is exactly `dim` long.
        if clip_emb.shape[0] < dim:
            padding = np.zeros(dim - clip_emb.shape[0])
            clip_emb = np.concatenate([clip_emb, padding])
        else:
            clip_emb = clip_emb[:dim]

        # Stretch the histogram to `dim` values by repeating every bin
        # (24 bins * 32 repeats = 768 with the default settings).
        color_expanded = np.repeat(color_feat, dim // color_feat.shape[0])

        # Normalise both parts so the weights act on comparable scales.
        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)

        combined = clip_emb * self.CLIP_WEIGHT + color_expanded * self.COLOR_WEIGHT
        return combined / (np.linalg.norm(combined) + 1e-8)

    # NOTE: the parameter is named `input` (shadowing the builtin) because
    # callers pass it under that name; it is kept for interface compatibility.
    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: iterable of images; each item may be a URL string, encoded
                image bytes, a NumPy array or an object accepted by the
                open_clip preprocessing transform.

        Returns:
            ``np.ndarray`` with one row of length ``EMBEDDING_DIM`` per input.

        Raises:
            Exception: any failure is logged and re-raised unchanged.
        """
        try:
            # Load every input exactly once; the same decoded image feeds both
            # the CLIP encoder and the colour histogram.
            images = [self._load_image(item) for item in input]

            # Batch the preprocessed tensors for a single forward pass.
            batch = torch.cat(
                [self.preprocess(img).unsqueeze(0).to(self.device) for img in images]
            )

            # CLIP embedding extraction
            with torch.no_grad():
                clip_features = self.model.encode_image(batch)
            clip_features = clip_features.cpu().numpy()

            # Feature combination
            combined_embeddings = [
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ]
            return np.array(combined_embeddings)
        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Extract a normalised HSV colour histogram from the image.

        Args:
            image: URL string, encoded image bytes, NumPy array or PIL image.

        Returns:
            1-D array of ``3 * HISTOGRAM_BINS`` values (H, S and V histograms,
            each summing to ~1). On any error the problem is logged and an
            all-zero vector of the same length is returned instead.
        """
        try:
            if isinstance(image, (str, bytes)):
                image = self._load_image(image)
            pixels = image if isinstance(image, np.ndarray) else np.array(image)

            # Go through RGB first so grayscale/RGBA inputs can be converted to
            # HSV instead of failing and silently yielding a zero histogram.
            hsv_image = Image.fromarray(pixels.astype('uint8')).convert('RGB').convert('HSV')
            hsv_pixels = np.array(hsv_image)

            channel_histograms = []
            for channel in range(3):
                hist = np.histogram(
                    hsv_pixels[:, :, channel],
                    bins=self.HISTOGRAM_BINS,
                    range=(0, 256),
                )[0]
                # Normalise to relative frequencies; epsilon guards empty images.
                channel_histograms.append(hist / (hist.sum() + 1e-8))
            return np.concatenate(channel_histograms)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(3 * self.HISTOGRAM_BINS)
# Initialize session state: seed each key with its default the first time
# the script runs in a session; existing values are left untouched.
_SESSION_DEFAULTS = (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Load segmentation model
@st.cache_resource
def load_segmentation_model():
    """Load the clothes segmentation model and its image processor.

    Returns:
        ``(model, image_processor)`` for ``mattmdjaga/segformer_b2_clothes``;
        the model is moved to CUDA when a GPU is available.

    Raises:
        Exception: any loading failure is logged and re-raised.
    """
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        segmenter = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
        if not torch.cuda.is_available():
            return segmenter, processor
        return segmenter.to('cuda'), processor
    except Exception as exc:
        logger.error(f"Error loading segmentation model: {exc}")
        raise
# ChromaDB setup
def setup_multimodal_collection():
    """Open the multimodal image collection, creating it if it cannot be opened.

    The same collection name is used for lookup and for creation. Previously
    the fallback created a differently named collection, so the next call
    missed it on lookup and then failed on creation because it already existed.

    Returns:
        The ChromaDB collection, wired to ``CustomFashionEmbeddingFunction``
        and ``ImageLoader``.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    # NOTE(review): each call builds a new CustomFashionEmbeddingFunction,
    # which reloads the embedding model; consider caching the result.
    collection_name = "fashion_multimodal"
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db_original")
        embedding_function = CustomFashionEmbeddingFunction()
        data_loader = ImageLoader()

        # Prefer the existing collection
        try:
            collection = client.get_collection(
                name=collection_name,
                embedding_function=embedding_function,
                data_loader=data_loader
            )
            logger.info(f"Successfully connected to existing {collection_name} collection")
            return collection
        except Exception as e:
            logger.error(f"Error getting existing collection: {e}")
            # Create it only when it could not be opened
            collection = client.create_collection(
                name=collection_name,
                embedding_function=embedding_function,
                data_loader=data_loader
            )
            logger.info(f"Created new {collection_name} collection")
            return collection
    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise
def process_segmentation(image):
    """Segment *image* and return one entry per non-background label.

    Args:
        image: PIL image; its ``size`` is used to upsample the logits back to
            the original resolution.

    Returns:
        List of dicts with keys ``'label'`` (``"Item_<label index>"``),
        ``'score'`` (always ``1.0``) and ``'mask'`` (float array, 1.0 where
        the label is present). An empty list is returned on any error, after
        logging the message and the traceback.
    """
    try:
        model, image_processor = load_segmentation_model()

        # Preprocess the image and move the tensors to the GPU if there is one
        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {name: tensor.to('cuda') for name, tensor in inputs.items()}

        # Inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Upsample the logits to the input resolution
        target_size = image.size[::-1]  # PIL gives (width, height); torch wants (height, width)
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits.cpu(),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

        # Per-pixel label map
        seg_masks = upsampled_logits.argmax(dim=1).numpy()

        # One entry per label present in the map; label 0 is the background.
        # Label-name mapping and a real confidence score could be added here.
        processed_items = [
            {
                'label': f"Item_{label_idx}",
                'score': 1.0,
                'mask': (seg_masks[0] == label_idx).astype(float),
            }
            for label_idx in np.unique(seg_masks)
            if label_idx != 0
        ]

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def _distance_to_similarity_score(distance, space="l2"):
    """Convert a Chroma distance between unit vectors to a 0-100 score.

    For the default ``"l2"`` space Chroma reports the *squared* L2 distance,
    so for unit vectors ``d = 2 * (1 - cos)`` and ``cos = 1 - d / 2``.
    For ``"cosine"`` and ``"ip"`` the distance is ``1 - similarity``.
    """
    if space == "l2":
        cosine_similarity = 1.0 - distance / 2.0
    else:
        cosine_similarity = 1.0 - distance
    # Clamp away floating-point overshoot so the score stays within 0-100.
    cosine_similarity = max(-1.0, min(1.0, cosine_similarity))
    # Map cosine similarity from [-1, 1] to [0, 100].
    return ((cosine_similarity + 1.0) / 2.0) * 100.0


def search_similar_items(image, mask=None, top_k=10):
    """Run a multimodal similarity search for *image*.

    Args:
        image: PIL image to search with.
        mask: optional 2-D array; when given, the image is multiplied by it
            (broadcast over the three colour channels) before querying.
        top_k: number of neighbours to request.

    Returns:
        List of metadata dicts, each extended with ``'similarity_score'``
        (0-100), sorted from most to least similar. An empty list is returned
        when nothing is found or when any error occurs (the error is logged).
    """
    try:
        collection = setup_multimodal_collection()

        # Apply the mask
        if mask is not None:
            mask_3d = np.stack([mask] * 3, axis=-1)
            masked_image = np.array(image) * mask_3d
            query_image = Image.fromarray(masked_image.astype(np.uint8))
        else:
            query_image = image

        # Run the query
        results = collection.query(
            query_images=[np.array(query_image)],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results.get('metadatas') or not results.get('distances'):
            return []

        # Use the collection's configured distance space when its metadata
        # records one; otherwise assume Chroma's default, squared L2.
        collection_metadata = getattr(collection, 'metadata', None) or {}
        space = collection_metadata.get('hnsw:space', 'l2')

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            # Entries stored without metadata come back as None.
            item_data = dict(metadata) if metadata else {}
            item_data['similarity_score'] = _distance_to_similarity_score(distance, space)
            similar_items.append(item_data)

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []
def update_db_with_multimodal():
    """Copy items from the legacy ``clothes`` collection into the multimodal one.

    Every metadata entry that has an ``'image_url'`` is downloaded and added
    to the multimodal collection in batches of 100, which re-embeds it.
    Progress is shown through a Streamlit progress bar and status text.

    Returns:
        ``True`` when the whole run finished, ``False`` if an error aborted it
        (the error is logged). Per-image failures are logged and counted but
        do not abort the run.
    """
    # Local import: only this migration needs hashlib.
    import hashlib

    request_timeout = 10  # seconds; stops one dead host from stalling the run

    try:
        # Target collection
        collection = setup_multimodal_collection()

        # Read the source data from the legacy collection
        client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
        old_collection = client.get_collection("clothes")
        old_data = old_collection.get(include=['metadatas'])

        total_items = len(old_data['metadatas'])
        progress_bar = st.progress(0)
        status_text = st.empty()

        batch_size = 100
        successful_updates = 0
        failed_updates = 0

        for i in range(0, total_items, batch_size):
            batch = old_data['metadatas'][i:i + batch_size]
            images = []
            valid_metadatas = []
            valid_ids = []

            for metadata in batch:
                try:
                    if 'image_url' in metadata:
                        image_url = metadata['image_url']
                        response = requests.get(image_url, timeout=request_timeout)
                        # Count HTTP errors as failures instead of decoding an error page.
                        response.raise_for_status()
                        img = Image.open(io.BytesIO(response.content)).convert('RGB')

                        # Fall back to a digest of the URL when there is no id.
                        # The builtin hash() is salted per process, so it would
                        # give the same item a different id on every run.
                        if 'id' in metadata:
                            item_id = metadata['id']
                        else:
                            item_id = hashlib.sha256(image_url.encode('utf-8')).hexdigest()

                        images.append(np.array(img))
                        valid_metadatas.append(metadata)
                        valid_ids.append(item_id)
                        successful_updates += 1
                except Exception as e:
                    logger.error(f"Error processing image: {e}")
                    failed_updates += 1
                    continue

            if images:
                collection.add(
                    ids=valid_ids,
                    images=images,
                    metadatas=valid_metadatas
                )

            # Update progress
            processed = i + len(batch)
            progress_bar.progress(processed / total_items)
            status_text.text(f"Processing: {processed}/{total_items} items. "
                             f"Success: {successful_updates}, Failed: {failed_updates}")

        status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
                         f"Failed: {failed_updates}")
        return True
    except Exception as e:
        logger.error(f"Multimodal DB update error: {e}")
        return False
def _format_price(value):
    """Format a price: numbers get thousands separators and the won suffix."""
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Display similar items in a two-column grid with similarity scores.

    Args:
        similar_items: list of metadata dicts carrying ``'similarity_score'``
            and optionally ``'image_url'``, ``'brand'``, ``'name'``,
            ``'price'``, ``'discount'`` and ``'original_price'``.

    A warning is shown when the list is empty. A failure while rendering one
    item is logged and shown as an error in that item's column only.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        row_items = similar_items[row_start:row_start + items_per_row]
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row_items):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")

                    # Keep long product names to 50 characters.
                    name = item.get('name', 'Unknown')
                    if len(name) > 50:
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    st.write(f"Price: {_format_price(item.get('price', 0))}")

                    if item.get('discount'):
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            # Formatted like the price above: a non-numeric
                            # value used to raise on the ":," format spec.
                            st.write(f"Original: {_format_price(item['original_price'])}")

                    st.divider()
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Run the similar-item search behind a spinner.

    Returns:
        The list produced by ``search_similar_items``, or ``None`` when an
        exception was raised (the exception is logged).
    """
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as exc:
        logger.error(f"Search processing error: {exc}")
        return None
def handle_file_upload():
    """``on_change`` callback of the uploader: store the uploaded image.

    Reads the file from ``st.session_state.uploaded_file``, converts it to an
    RGB PIL image and switches ``upload_state`` to ``'image_uploaded'``.
    """
    if st.session_state.uploaded_file is not None:
        image = Image.open(st.session_state.uploaded_file).convert('RGB')
        st.session_state.image = image
        st.session_state.upload_state = 'image_uploaded'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback anyway, and calling it inside a callback is a no-op that
        # only shows a warning.
def handle_detection():
    """``on_click`` callback of the detect button: segment the stored image.

    Stores the result of ``process_segmentation`` in
    ``st.session_state.detected_items`` and switches ``upload_state`` to
    ``'items_detected'``.
    """
    if st.session_state.image is not None:
        detected_items = process_segmentation(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'
        # No st.rerun() here: Streamlit reruns the script after a widget
        # callback anyway, and calling it inside a callback is a no-op that
        # only shows a warning.
def handle_search():
    """Record in the session state that a search was requested."""
    st.session_state['search_clicked'] = True
def main():
    """Render the page: admin sidebar, upload, item detection and search.

    The flow is driven by ``st.session_state``: ``upload_state`` decides
    whether the uploader is shown, ``image`` and ``detected_items`` hold the
    current work, and ``search_results`` / ``search_params`` cache the last
    search together with the item index and result count it was run for.
    """
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        if st.sidebar.button("Update Database (Multimodal)"):
            with st.spinner("Updating database with multimodal support..."):
                success = update_db_with_multimodal()
                if success:
                    st.sidebar.success("Database updated successfully!")
                else:
                    st.sidebar.error("Failed to update database")
        st.divider()

    # File uploader; the callback reads the file from st.session_state
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # An image has been uploaded
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

        detected_items = st.session_state.detected_items

        if detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)
        elif len(detected_items) == 0:
            # Detection ran but found nothing (or failed); say so.
            st.warning("No items were detected in this image.")
        else:
            # Show the detected items
            cols = st.columns(2)
            for idx, item in enumerate(detected_items):
                with cols[idx % 2]:
                    try:
                        if item.get('mask') is not None:
                            masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
                            st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                        st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
                        score = item.get('score')
                        if score is not None and isinstance(score, (int, float)):
                            st.write(f"Confidence: {score*100:.1f}%")
                        else:
                            st.write("Confidence: N/A")
                    except Exception as e:
                        logger.error(f"Error displaying item {idx}: {str(e)}")
                        st.error(f"Error displaying item {idx}")

            # Only items with a mask can be searched
            valid_items = [i for i, detected in enumerate(detected_items)
                           if detected.get('mask') is not None]

            if not valid_items:
                st.warning("No valid items detected for search.")
                return

            selected_idx = st.selectbox(
                "Select item to search:",
                valid_items,
                format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
                key='item_selector'
            )

            search_col1, search_col2 = st.columns([1, 2])
            with search_col1:
                search_clicked = st.button("Search Similar Items",
                                           key='search_button',
                                           type="primary")
            with search_col2:
                num_results = st.slider("Number of results:",
                                        min_value=1,
                                        max_value=20,
                                        value=5,
                                        key='num_results')

            if search_clicked or st.session_state.get('search_clicked', False):
                st.session_state.search_clicked = True
                selected_item = detected_items[selected_idx]

                if selected_item.get('mask') is None:
                    st.error("Selected item has no valid mask for search.")
                    return

                # Search again on an explicit click, when nothing is cached,
                # or when the cached results belong to another item or result
                # count. Otherwise stale results would be shown.
                search_params = (selected_idx, num_results)
                needs_search = (
                    search_clicked
                    or 'search_results' not in st.session_state
                    or st.session_state.get('search_params') != search_params
                )
                if needs_search:
                    st.session_state.search_results = process_search(
                        st.session_state.image,
                        selected_item['mask'],
                        num_results,
                    )
                    st.session_state.search_params = search_params

                if st.session_state.search_results:
                    show_similar_items(st.session_state.search_results)
                else:
                    st.warning("No similar items found.")

        # New-search button: wipe the whole session state and start over
        if st.button("Start New Search", key='new_search'):
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            st.rerun()
# Script entry point: print a start marker, then build the Streamlit page.
if __name__ == "__main__":
    # NOTE(review): the printed literal looks like mis-encoded Korean text — confirm the file's encoding.
    print('์‹œ์ž‘')
    main()