import io
import logging
from concurrent.futures import ThreadPoolExecutor

import chromadb
import numpy as np
import open_clip
import requests
import streamlit as st
import torch
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seconds to wait for any outbound image download, so a stalled server
# cannot hang the Streamlit script forever.
REQUEST_TIMEOUT = 10

# Length every stored embedding is padded/truncated to.
EMBEDDING_DIM = 768

# Collection looked up first (kept for databases created by other tooling),
# and the collection this app creates/reuses when that one is absent.
LEGACY_COLLECTION_NAME = "fashion_multimodal"
COLLECTION_NAME = "clothes_multimodal"


def _fetch_image_bytes(url):
    """Download *url* and return the response body.

    Raises:
        requests.RequestException: on network errors or a non-2xx status.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    return response.content


class CustomFashionEmbeddingFunction:
    """ChromaDB embedding function mixing SigLIP features with a color histogram.

    Each output row has EMBEDDING_DIM components: the L2-normalized
    Marqo fashion-SigLIP image embedding weighted by 0.7, plus the
    L2-normalized, bin-repeated HSV histogram weighted by 0.3, re-normalized
    to unit length.
    """

    def __init__(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Inference only: put dropout/normalization layers in eval mode.
        self.model = self.model.to(self.device).eval()

    @staticmethod
    def _to_pil(img):
        """Turn a URL string, raw bytes or ndarray into an RGB PIL image.

        Any other object (e.g. an existing PIL image) is returned unchanged.
        """
        if isinstance(img, str):
            return Image.open(io.BytesIO(_fetch_image_bytes(img))).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        return img

    @staticmethod
    def _combine(clip_emb, color_feat):
        """Fuse one CLIP embedding and one 24-bin histogram into a unit vector."""
        # Pad with zeros or truncate the CLIP embedding to EMBEDDING_DIM.
        if clip_emb.shape[0] < EMBEDDING_DIM:
            padding = np.zeros(EMBEDDING_DIM - clip_emb.shape[0])
            clip_emb = np.concatenate([clip_emb, padding])
        else:
            clip_emb = clip_emb[:EMBEDDING_DIM]

        # Stretch the histogram to 768 values by repeating each bin (24 * 32).
        color_expanded = np.repeat(color_feat, 32)

        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)

        combined = clip_emb * 0.7 + color_expanded * 0.3
        return combined / (np.linalg.norm(combined) + 1e-8)

    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: sequence of images given as URL strings, raw bytes,
                ndarrays or PIL images.

        Returns:
            ndarray of shape (len(input), EMBEDDING_DIM).

        Raises:
            Exception: any loading/inference error is logged and re-raised.
        """
        try:
            # Decode (and download) each image exactly once; both the CLIP
            # branch and the color branch reuse the decoded image.
            images = [self._to_pil(img) for img in input]
            if not images:
                return np.zeros((0, EMBEDDING_DIM))

            batch = torch.cat(
                [self.preprocess(img).unsqueeze(0) for img in images]
            ).to(self.device)
            with torch.no_grad():
                clip_features = self.model.encode_image(batch).cpu().numpy()

            return np.array([
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ])
        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Return a 24-value HSV histogram (8 bins per channel, each channel summing to ~1).

        Accepts a URL string, raw bytes, ndarray or PIL image. On any error the
        error is logged and a zero vector of length 24 is returned.
        """
        try:
            if isinstance(image, str):
                image = Image.open(io.BytesIO(_fetch_image_bytes(image)))
            elif isinstance(image, bytes):
                image = Image.open(io.BytesIO(image))

            img_array = image if isinstance(image, np.ndarray) else np.array(image)

            # Go through RGB first: PIL cannot convert RGBA/L/P straight to HSV.
            img_hsv = Image.fromarray(img_array.astype('uint8')).convert('RGB').convert('HSV')
            hsv_pixels = np.array(img_hsv)

            histograms = []
            for channel in range(3):
                hist = np.histogram(hsv_pixels[:, :, channel], bins=8, range=(0, 256))[0]
                histograms.append(hist / (hist.sum() + 1e-8))
            return np.concatenate(histograms)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(24)


# Initialize session state. The script reruns on every interaction, so only
# keys that are still missing are populated.
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


@st.cache_resource
def load_segmentation_model():
    """Load (once per process) the SegFormer clothes model and its processor.

    Returns:
        (model, image_processor); the model is moved to CUDA when available.
    """
    try:
        model_name = "mattmdjaga/segformer_b2_clothes"
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
        if torch.cuda.is_available():
            model = model.to('cuda')
        return model, image_processor
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise


@st.cache_resource
def setup_multimodal_collection():
    """Return the multimodal ChromaDB collection, creating it if necessary.

    Cached with ``st.cache_resource`` so the SigLIP model inside the embedding
    function is loaded once per process rather than on every search.
    Prefers LEGACY_COLLECTION_NAME when it exists; otherwise reuses or creates
    COLLECTION_NAME.
    """
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")
        embedding_function = CustomFashionEmbeddingFunction()
        data_loader = ImageLoader()

        try:
            collection = client.get_collection(
                name=LEGACY_COLLECTION_NAME,
                embedding_function=embedding_function,
                data_loader=data_loader,
            )
            logger.info(f"Successfully connected to existing {LEGACY_COLLECTION_NAME} collection")
            return collection
        except Exception as e:
            logger.info(f"{LEGACY_COLLECTION_NAME} not available ({e}); using {COLLECTION_NAME}")

        # get_or_create: later runs must reuse the collection an earlier run
        # created instead of failing with "collection already exists".
        collection = client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )
        logger.info(f"Using {COLLECTION_NAME} collection")
        return collection
    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise


def process_segmentation(image):
    """Segment *image* (a PIL image) into per-class masks.

    Returns:
        List of dicts with keys 'label' (model class name, or ``Item_<id>``
        when unknown), 'score' (always 1.0: argmax gives no per-region
        confidence) and 'mask' (float array, 1.0 where the pixel belongs to
        the class), one per non-background class present. Returns [] on
        failure.
    """
    try:
        model, image_processor = load_segmentation_model()

        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Upsample logits back to the input resolution before the argmax.
        logits = outputs.logits.cpu()
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (width, height)
            mode="bilinear",
            align_corners=False,
        )
        seg_mask = upsampled_logits.argmax(dim=1)[0].numpy()

        id2label = getattr(model.config, "id2label", None) or {}
        processed_items = []
        for label_idx in np.unique(seg_mask):
            if label_idx == 0:  # background
                continue
            processed_items.append({
                'label': id2label.get(int(label_idx), f"Item_{label_idx}"),
                'score': 1.0,
                'mask': (seg_mask == label_idx).astype(float),
            })

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        logger.exception(f"Segmentation error: {str(e)}")
        return []


def _distance_to_cosine(distance, space):
    """Convert a ChromaDB distance between unit vectors to cosine similarity.

    Chroma's 'l2' space reports *squared* Euclidean distance, so for unit
    vectors d = 2 * (1 - cos). The 'cosine' and 'ip' spaces report 1 - dot.
    """
    if space in ("cosine", "ip"):
        return 1 - distance
    return 1 - distance / 2


def search_similar_items(image, mask=None, top_k=10):
    """Query the multimodal collection with *image*, optionally masked.

    Args:
        image: PIL image to search with.
        mask: optional (H, W) array; pixels where it is 0 are blacked out.
        top_k: number of neighbours to request.

    Returns:
        List of metadata dicts, each extended with 'similarity_score'
        (0-100), sorted best first. Returns [] on any error.
    """
    try:
        collection = setup_multimodal_collection()

        if mask is not None:
            mask_3d = np.stack([mask] * 3, axis=-1)
            masked_image = np.array(image) * mask_3d
            query_image = Image.fromarray(masked_image.astype(np.uint8))
        else:
            query_image = image

        results = collection.query(
            query_images=[np.array(query_image)],
            n_results=top_k,
            include=['metadatas', 'distances'],
        )
        if not results or not results.get('metadatas'):
            return []

        space = (collection.metadata or {}).get("hnsw:space", "l2")

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            cosine_similarity = _distance_to_cosine(distance, space)
            # Map cosine similarity from [-1, 1] onto a 0-100 score.
            item_data = dict(metadata or {})
            item_data['similarity_score'] = ((cosine_similarity + 1) / 2) * 100
            similar_items.append(item_data)

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []


def _download_rgb_array(url):
    """Download *url* and return it as an RGB uint8 ndarray."""
    img = Image.open(io.BytesIO(_fetch_image_bytes(url))).convert('RGB')
    return np.array(img)


def update_db_with_multimodal(batch_size=100, max_workers=8):
    """Re-embed every item of the legacy 'clothes' collection into the multimodal one.

    Images are downloaded concurrently (*max_workers* threads) in batches of
    *batch_size*, and written with ``upsert`` under stable IDs (the metadata
    'id' when present, else the legacy collection's ID), so re-running the
    update refreshes entries instead of duplicating them. Progress is shown
    with Streamlit widgets; those are only touched from this thread.

    Returns:
        True on completion, False if a non-per-image error occurred.
    """
    try:
        collection = setup_multimodal_collection()

        client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
        old_collection = client.get_collection("clothes")
        old_data = old_collection.get(include=['metadatas'])
        old_ids = old_data['ids']
        old_metadatas = old_data['metadatas']
        total_items = len(old_metadatas)

        progress_bar = st.progress(0)
        status_text = st.empty()

        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for start in range(0, total_items, batch_size):
                stop = min(start + batch_size, total_items)
                batch = [
                    (old_id, metadata)
                    for old_id, metadata in zip(old_ids[start:stop], old_metadatas[start:stop])
                    if metadata and 'image_url' in metadata
                ]
                futures = [
                    executor.submit(_download_rgb_array, metadata['image_url'])
                    for _, metadata in batch
                ]

                images, valid_metadatas, valid_ids = [], [], []
                for (old_id, metadata), future in zip(batch, futures):
                    try:
                        images.append(future.result())
                    except Exception as e:
                        logger.error(f"Error processing image: {e}")
                        failed_updates += 1
                        continue
                    valid_metadatas.append(metadata)
                    # Chroma IDs must be strings.
                    valid_ids.append(str(metadata.get('id', old_id)))

                if images:
                    collection.upsert(
                        ids=valid_ids,
                        images=images,
                        metadatas=valid_metadatas,
                    )
                    successful_updates += len(images)

                progress_bar.progress(stop / total_items)
                status_text.text(f"Processing: {stop}/{total_items} items. "
                                 f"Success: {successful_updates}, Failed: {failed_updates}")

        status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
                         f"Failed: {failed_updates}")
        return True
    except Exception as e:
        logger.error(f"Multimodal DB update error: {e}")
        return False


def _format_price(value):
    """Format a price: numbers get thousands separators and '원'; others pass through."""
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Render search results two per row with image, score and product info."""
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        row = similar_items[row_start:row_start + items_per_row]
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")

                    name = item.get('name', 'Unknown')
                    if len(name) > 50:
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    st.write(f"Price: {_format_price(item.get('price', 0))}")

                    if 'discount' in item and item['discount']:
                        st.write(f"Discount: {item['discount']}%")
                    if 'original_price' in item:
                        # Original price may be stored as text; don't crash on it.
                        st.write(f"Original: {_format_price(item['original_price'])}")

                    st.divider()
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")


def process_search(image, mask, num_results):
    """Run a similar-item search under a spinner; returns None on unexpected errors."""
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None


def handle_file_upload():
    """file_uploader on_change callback: store the uploaded image in session state.

    Streamlit reruns the script automatically after a callback, so no
    explicit ``st.rerun()`` is needed (it is a no-op inside callbacks).
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logger.error(f"Could not open uploaded image: {e}")
        st.error("Could not read the uploaded file as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'


def handle_detection():
    """Detect-button on_click callback: segment the stored image.

    Streamlit reruns the script automatically after a callback.
    """
    if st.session_state.image is not None:
        st.session_state.detected_items = process_segmentation(st.session_state.image)
        st.session_state.upload_state = 'items_detected'


def handle_search():
    """Mark that a search was requested."""
    st.session_state.search_clicked = True


def _render_detected_items(detected_items):
    """Show each detected item as a masked crop with its label and confidence."""
    image_array = np.array(st.session_state.image)
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                label = item.get('label', 'Unknown')
                if item.get('mask') is not None:
                    masked_img = image_array * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {label}")
                st.write(f"Item {idx + 1}: {label}")

                score = item.get('score')
                if isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_search(detected_items):
    """Render item selection, search controls and (cached) results."""
    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector',
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items", key='search_button', type="primary")
    with search_col2:
        num_results = st.slider("Number of results:", min_value=1, max_value=20,
                                value=5, key='num_results')

    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    # Results are cached across reruns, but must be refreshed when the
    # selected item or result count changes, or the button is pressed again.
    search_params = (selected_idx, num_results)
    if (search_clicked
            or 'search_results' not in st.session_state
            or st.session_state.get('search_params') != search_params):
        st.session_state.search_results = process_search(
            st.session_state.image, detected_items[selected_idx]['mask'], num_results
        )
        st.session_state.search_params = search_params

    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Streamlit entry point: upload -> detect items -> search similar items."""
    st.title("Fashion Search App")

    # Admin controls in the sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        if st.sidebar.button("Update Database (Multimodal)"):
            with st.spinner("Updating database with multimodal support..."):
                success = update_db_with_multimodal()
            if success:
                st.sidebar.success("Database updated successfully!")
            else:
                st.sidebar.error("Failed to update database")

    st.divider()

    # File uploader, shown until an image has been accepted.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    if st.session_state.image is None:
        return

    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif detected_items:
        _render_detected_items(detected_items)
        _render_search(detected_items)
    else:
        st.warning("No items detected in this image.")

    # Reset everything and start over with a new image.
    if st.button("Start New Search", key='new_search'):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()


if __name__ == "__main__":
    print('시작')
    main()