"""Build a ChromaDB multimodal collection of fashion products.

Product records are read from JSON files, their images are downloaded and
segmented, a garment-focused image embedding is computed, and the results are
stored in a persistent ChromaDB collection together with normalised metadata.
"""
import chromadb
import logging
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import requests
import io
import json
import uuid
import hashlib
import time
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os
from io import BytesIO
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

# Logging: write to a log file and echo to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# open_clip identifier of the image/text encoder. It is used both for the stored
# image embeddings and for the collection's query-side embedding function, so
# stored vectors and query vectors come from the same model.
CLIP_MODEL_NAME = 'hf-hub:Marqo/marqo-fashionSigLIP'

# Hugging Face model used for clothes segmentation.
SEGMENTER_MODEL_NAME = "mattmdjaga/segformer_b2_clothes"

# Segment labels (compared case-insensitively) that are never selected as the
# garment. NOTE(review): assumes the segmenter emits a "Background" label —
# confirm against the model's config.id2label.
EXCLUDED_SEGMENT_LABELS = frozenset({"background"})

# Blend weights for the (original, masked-on-white, cropped) image views.
FEATURE_WEIGHTS = (0.2, 0.3, 0.5)

# Maximum number of records sent to ChromaDB in one add() call; a single
# very large batch can exceed the client's maximum batch size.
DB_BATCH_SIZE = 500


def load_models():
    """Load the image encoder, the clothes segmenter and a helper transform.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device, resize_transform)``.
        ``model`` is the open_clip model, moved to ``device`` and set to eval
        mode. ``preprocess_val`` is its evaluation preprocessing transform.
        ``segmenter`` is a transformers segmentation pipeline on the same
        device. ``device`` is the torch device in use. ``resize_transform`` is
        a 224x224 Resize + ToTensor transform; it is passed along but not used
        by the feature extraction in this file.

    Raises:
        Exception: any loading error is logged and re-raised.
    """
    try:
        logger.info("Loading models...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # CLIP-style image/text encoder.
        model, _, preprocess_val = open_clip.create_model_and_transforms(CLIP_MODEL_NAME)
        model.to(device)
        # open_clip creates models in training mode; eval mode keeps dropout /
        # stochastic layers from perturbing the embeddings.
        model.eval()

        # Clothes segmentation model, placed on the same device as the encoder.
        segmenter = pipeline(model=SEGMENTER_MODEL_NAME, device=device)

        # Resize + tensor conversion (224 matches the CLIP input size).
        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


def process_segmentation(image, segmenter):
    """Segment clothes in ``image`` and return the mask of the largest garment.

    Args:
        image: PIL image to segment.
        segmenter: transformers image-segmentation pipeline. Its output is a
            list of dicts with ``'label'`` and ``'mask'`` entries.

    Returns:
        A 2-D numpy float array with values in {0.0, 1.0}, where 1.0 marks the
        chosen segment. Returns None when no usable segment is found or
        segmentation fails.
    """
    try:
        output = segmenter(image)
        if not output:
            logger.warning("No segments found in image")
            return None

        # Ignore background-like segments. On product shots the background is
        # often the largest region and would otherwise be picked as "clothing".
        candidates = [
            seg for seg in output
            if str(seg.get('label', '')).strip().lower() not in EXCLUDED_SEGMENT_LABELS
        ]
        if not candidates:
            logger.warning("Only background segments found in image")
            return None

        binary_masks = []
        for seg in candidates:
            seg_mask = np.asarray(seg['mask'])
            if seg_mask.ndim > 2:
                seg_mask = seg_mask[:, :, 0]
            # Pipeline masks are 8-bit images (typically 0/255). Binarise them
            # to {0, 1} so later blending arithmetic stays in the 0-255 range.
            binary_masks.append((seg_mask > 0).astype(float))

        # Largest segment by pixel count; ties resolve to the first, like argmax.
        mask = max(binary_masks, key=np.count_nonzero)
        if not mask.any():
            logger.warning("No segments found in image")
            return None

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None


def load_image_from_url(url, max_retries=3):
    """Download ``url`` and return it as an RGB PIL image.

    Each attempt uses a 10 s timeout. Failed attempts are retried after a
    1 s pause, up to ``max_retries`` attempts in total.

    Returns:
        The RGB PIL image, or None when every attempt fails.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
    return None


def extract_features(image, mask, model, preprocess_val, device, weights=FEATURE_WEIGHTS):
    """Compute a garment-focused embedding from three views of ``image``.

    The three views are:
    1. the original image;
    2. the image with everything outside ``mask`` painted white;
    3. a square crop around the masked region, also on white.

    Each view is encoded and L2-normalised. The views are blended with
    ``weights``, and the blend is L2-normalised again.

    Args:
        image: RGB PIL image.
        mask: 2-D array with the same height/width as ``image``, values in [0, 1].
        model: open_clip model exposing ``encode_image``.
        preprocess_val: transform mapping a PIL image to a CHW tensor.
        device: torch device the model lives on.
        weights: three blend weights for (original, masked, cropped).

    Returns:
        The normalised embedding as a 1-D numpy array, or None on error.
    """
    try:
        img_array = np.array(image).astype(float)
        mask = np.asarray(mask, dtype=float)
        mask_3channel = np.repeat(mask[:, :, np.newaxis], 3, axis=2)

        # View 2: keep the garment pixels and paint the rest white.
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(
            np.clip(masked_img_white, 0, 255).astype(np.uint8)
        )

        # View 3: square crop around the garment. The white background
        # matches the white padding that crop_and_resize adds.
        bbox = get_bbox_from_mask(mask)
        image_cropped = Image.fromarray(crop_and_resize(masked_img_white, bbox))

        # Encode all three views in one forward pass.
        batch = torch.stack([
            preprocess_val(image),
            preprocess_val(image_masked_white),
            preprocess_val(image_cropped),
        ]).to(device)

        with torch.no_grad():
            features = model.encode_image(batch)
            # Normalise each view first so the weights are not skewed by
            # differing embedding norms.
            features = features / features.norm(dim=-1, keepdim=True)
            blend = torch.tensor(
                weights, dtype=features.dtype, device=features.device
            ).unsqueeze(1)
            combined_features = (blend * features).sum(dim=0, keepdim=True)
            combined_features = combined_features / combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None


def get_bbox_from_mask(mask, padding=10):
    """Return the padded bounding box of the non-zero region of ``mask``.

    Args:
        mask: 2-D array. Higher-rank input is reduced over trailing axes.
        padding: number of extra pixels added on every side, clipped to the
            mask bounds.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as ints, usable directly as slice bounds
        (``rmax`` and ``cmax`` are exclusive). An empty mask yields the full
        extent of the mask.
    """
    mask2d = np.asarray(mask)
    if mask2d.ndim > 2:
        mask2d = mask2d.any(axis=tuple(range(2, mask2d.ndim)))
    height, width = mask2d.shape[:2]

    rows = np.flatnonzero(mask2d.any(axis=1))
    cols = np.flatnonzero(mask2d.any(axis=0))
    if rows.size == 0 or cols.size == 0:
        return 0, height, 0, width

    rmin = max(int(rows[0]) - padding, 0)
    rmax = min(int(rows[-1]) + 1 + padding, height)
    cmin = max(int(cols[0]) - padding, 0)
    cmax = min(int(cols[-1]) + 1 + padding, width)
    return rmin, rmax, cmin, cmax


def crop_and_resize(image, bbox, fill_value=255):
    """Crop ``image`` to ``bbox`` and centre the crop on a square canvas.

    No resizing happens here. The caller's preprocessing transform is expected
    to resize the result.

    Args:
        image: H x W (x C) array with pixel values in the 0-255 range.
        bbox: ``(rmin, rmax, cmin, cmax)`` slice bounds.
        fill_value: value for the padding around the crop (255 = white).

    Returns:
        A uint8 array of shape ``(S, S, ...)``, where S is the larger of the
        crop's height and width.
    """
    rmin, rmax, cmin, cmax = bbox
    cropped = np.clip(np.asarray(image)[rmin:rmax, cmin:cmax], 0, 255).astype(np.uint8)
    height, width = cropped.shape[:2]
    size = max(height, width)

    square_img = np.full((size, size) + cropped.shape[2:], fill_value, dtype=np.uint8)
    top = (size - height) // 2
    left = (size - width) // 2
    square_img[top:top + height, left:left + width] = cropped
    return square_img


def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Turn one JSON product record into a collection entry.

    Args:
        item: product dict. The image URL is read from ``'이미지 링크'`` or
            ``'이미지 URL'``.
        model, preprocess_val, segmenter, device: as returned by load_models().
        resize_transform: accepted for interface compatibility; unused.

    Returns:
        A dict with ``id``, ``embedding`` (list of floats), ``metadata`` and
        ``image_uri``. Returns None if any step fails; failures are logged.
    """
    try:
        if '이미지 링크' in item:
            image_url = item['이미지 링크']
        elif '이미지 URL' in item:
            image_url = item['이미지 URL']
        else:
            logger.warning("No image URL found in item")
            return None

        metadata = create_metadata(item)

        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.error(f"Feature extraction failed for {image_url}: Feature extraction failed")
            return None
        # Optional debugging aid:
        # save_debug_images(image, mask, image_url)

        return {
            'id': metadata['product_id'],
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url,
        }
    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None


def save_debug_images(image, mask, url):
    """Best-effort: save the original image and its mask under ./debug_images.

    ``mask`` is expected to hold values in [0, 1]. Errors are logged as
    warnings and never raised.
    """
    try:
        debug_dir = "debug_images"
        os.makedirs(debug_dir, exist_ok=True)

        # File name = last URL path segment without the query string.
        filename = url.split('/')[-1].split('?')[0]

        image.save(f"{debug_dir}/original_{filename}")
        mask_img = Image.fromarray((np.asarray(mask) * 255).astype(np.uint8))
        mask_img.save(f"{debug_dir}/mask_{filename}")
    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")


def _stable_id(text):
    """Return a hex ID derived from ``text`` that is identical across runs and processes."""
    return hashlib.sha256(text.encode('utf-8')).hexdigest()


def _to_chroma_metadata_value(value):
    """Coerce ``value`` to a scalar type ChromaDB accepts as a metadata value."""
    if value is None:
        return 'unknown'
    if isinstance(value, (str, int, float, bool)):
        return value
    return json.dumps(value, ensure_ascii=False, default=str)


def create_metadata(item):
    """Build standardised metadata from a Musinsa, 11st or Gmarket JSON record.

    Returns:
        A dict with the keys product_id (str), brand, name, price, discount,
        category, image_url and source. Every value is a str/int/float/bool;
        None becomes 'unknown'.
    """
    metadata = {}

    if '상품 ID' in item:
        # Musinsa records already carry a product ID; ChromaDB ids must be str.
        metadata['product_id'] = str(item['상품 ID'])
    else:
        # 11st / Gmarket: derive the ID from product name + image URL. The
        # built-in hash() is salted per process, so it is not reproducible.
        unique_string = f"{item.get('상품명', '')}{item.get('이미지 URL', '')}"
        metadata['product_id'] = _stable_id(unique_string)

    metadata['brand'] = item.get('브랜드명', 'unknown')
    metadata['name'] = item.get('제품명') or item.get('상품명', 'unknown')
    metadata['price'] = (item.get('정가') or item.get('가격') or item.get('판매가', 'unknown'))
    metadata['discount'] = item.get('할인율', 'unknown')

    if '카테고리' in item:
        category = item['카테고리']
        if isinstance(category, list):
            metadata['category'] = '/'.join(str(part) for part in category)
        else:
            metadata['category'] = category
    else:
        # 11st / Gmarket have no category field: infer it from keywords in the name.
        name = str(metadata['name'] or '').lower()
        categories = ['원피스', '셔츠', '블라우스', '니트', '가디건', '스커트', '팬츠', '셋업', '아우터', '자켓']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('이미지 링크') or item.get('이미지 URL', 'unknown'))

    # Shopping-mall source of the record.
    image_url = str(metadata['image_url'])
    if '이미지 링크' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in image_url:
        metadata['source'] = '11st'
    elif 'gmarket' in image_url:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return {key: _to_chroma_metadata_value(value) for key, value in metadata.items()}


def _load_json_items(json_file):
    """Read the list of product records in ``json_file``; return None if unusable."""
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"Failed to read {json_file}: {str(e)}")
        return None
    if not isinstance(data, list):
        logger.error(f"Expected a JSON list in {json_file}, got {type(data).__name__}")
        return None
    return data


def _add_to_collection(collection, items, seen_ids, stats, batch_size=DB_BATCH_SIZE):
    """Insert processed items into ``collection`` in batches.

    Items whose ID was already added, or that repeat within ``items``, are
    skipped and counted in ``stats['duplicates']``. ``seen_ids`` and ``stats``
    are updated in place. A failed batch counts its items as failed and does
    not stop later batches.
    """
    unique = {}
    for entry in items:
        if entry['id'] in seen_ids or entry['id'] in unique:
            stats['duplicates'] += 1
            continue
        unique[entry['id']] = entry
    pending = list(unique.values())

    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        try:
            collection.add(
                ids=[entry['id'] for entry in batch],
                embeddings=[entry['embedding'] for entry in batch],
                metadatas=[entry['metadata'] for entry in batch],
                uris=[entry['image_uri'] for entry in batch],
            )
        except Exception as e:
            logger.error(f"Failed to add batch to collection: {str(e)}")
            stats['failed'] += len(batch)
        else:
            stats['successful'] += len(batch)
            seen_ids.update(entry['id'] for entry in batch)


def create_multimodal_fashion_db(json_files):
    """(Re)build the "fashion_multimodal" ChromaDB collection from ``json_files``.

    Any existing collection with that name is deleted first. Unreadable JSON
    files are logged and skipped.

    Returns:
        True if at least one item was stored, otherwise False. Unexpected
        errors are logged and also yield False.
    """
    try:
        logger.info("Starting multimodal fashion database creation")

        model, preprocess_val, segmenter, device, resize_transform = load_models()

        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        # The query-side embedding function must use the same encoder as the
        # stored embeddings; otherwise query vectors are in a different space
        # (and dimension). hf-hub models take no separate pretrained tag.
        embedding_function = OpenCLIPEmbeddingFunction(model_name=CLIP_MODEL_NAME, checkpoint=None)
        data_loader = ImageLoader()

        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"},
        )

        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'duplicates': 0,
        }
        seen_ids = set()

        with ThreadPoolExecutor(max_workers=4) as executor:
            for json_file in json_files:
                data = _load_json_items(json_file)
                if data is None:
                    continue

                logger.info(f"Processing {len(data)} items from {json_file}")
                futures = [
                    executor.submit(
                        process_item, item, model, preprocess_val,
                        segmenter, device, resize_transform,
                    )
                    for item in data
                ]

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()
                    if result is None:
                        stats['failed'] += 1
                    else:
                        processed_items.append(result)

                if processed_items:
                    _add_to_collection(collection, processed_items, seen_ids, stats)

        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")
        logger.info(f"Duplicates skipped: {stats['duplicates']}")

        return stats['successful'] > 0

    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False


if __name__ == "__main__":
    json_files = [
        './musinsa_ranking_images_category_0920.json',
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json',
    ]

    success = create_multimodal_fashion_db(json_files)
    if success:
        print("Successfully created multimodal fashion database!")
    else:
        print("Failed to create database. Check the logs for details.")