import chromadb
import logging
import open_clip
import torch
import numpy as np
import requests
import json
import os
import time      # used by load_image_from_url's retry delay
import hashlib   # used by create_metadata for run-stable product IDs
from io import BytesIO
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from transformers import pipeline
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_models():
    """Load the fashionSigLIP encoder, the clothes segmenter, and transforms."""
    try:
        logger.info("Loading models...")

        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')

        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()  # inference only: freeze dropout/normalization behavior

        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


def process_segmentation(image, segmenter):
    """Segment the image and return the largest clothing mask, scaled to [0, 1]."""
    try:
        # The image-segmentation pipeline returns a list of
        # {'score', 'label', 'mask'} dicts, where 'mask' is a PIL image.
        output = segmenter(image)

        if not output:
            logger.warning("No segments found in image")
            return None

        segment_sizes = [np.sum(seg['mask']) for seg in output]

        if not segment_sizes:
            return None

        # Keep the largest segment, assumed to be the main garment.
        largest_idx = np.argmax(segment_sizes)
        mask = output[largest_idx]['mask']

        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        mask = mask.astype(float)
        # Pipeline masks are typically 0/255; rescale to [0, 1] so the
        # blending arithmetic in extract_features behaves as intended.
        if mask.max() > 1.0:
            mask /= 255.0

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None


def load_image_from_url(url, max_retries=3):
    """Download an image as RGB, retrying up to max_retries times."""
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            return img
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
                return None


def extract_features(image, mask, model, preprocess_val, device):
    """Advanced feature extraction with mask-based attention."""
    try:
        img_array = np.array(image)
        mask = np.expand_dims(mask, axis=2)
        mask_3channel = np.repeat(mask, 3, axis=2)

        # View 1: the full, unmodified image.
        image_tensor_original = preprocess_val(image).unsqueeze(0).to(device)

        # View 2: the garment isolated on a white background.
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(masked_img_white.astype(np.uint8))
        image_tensor_masked = preprocess_val(image_masked_white).unsqueeze(0).to(device)

        # View 3: a tight crop around the garment, padded to a square.
        bbox = get_bbox_from_mask(mask)
        cropped_img = crop_and_resize(img_array * mask_3channel, bbox)
        image_cropped = Image.fromarray(cropped_img.astype(np.uint8))
        image_tensor_cropped = preprocess_val(image_cropped).unsqueeze(0).to(device)

        with torch.no_grad():
            features_original = model.encode_image(image_tensor_original)
            features_masked = model.encode_image(image_tensor_masked)
            features_cropped = model.encode_image(image_tensor_cropped)

            # Blend the three views, weighting the cropped garment highest;
            # the weights sum to 1, and the result is re-normalized below.
            combined_features = (
                0.2 * features_original +
                0.3 * features_masked +
                0.5 * features_cropped
            )

        combined_features /= combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()

    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None


def get_bbox_from_mask(mask):
    """Extract bounding-box coordinates from the mask."""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    # Pad the box by 10 px on each side, clamped to the image bounds.
    padding = 10
    rmin = max(rmin - padding, 0)
    rmax = min(rmax + padding, mask.shape[0])
    cmin = max(cmin - padding, 0)
    cmax = min(cmax + padding, mask.shape[1])
    return rmin, rmax, cmin, cmax


def crop_and_resize(image, bbox):
    """Crop the image to the bounding box and pad it to a white square."""
    rmin, rmax, cmin, cmax = bbox
    cropped = image[rmin:rmax, cmin:cmax]

    # Center the crop on a white canvas sized to its longer edge.
    size = max(cropped.shape[:2])
    square_img = np.full((size, size, 3), 255, dtype=np.uint8)
    start_h = (size - cropped.shape[0]) // 2
    start_w = (size - cropped.shape[1]) // 2
    square_img[start_h:start_h+cropped.shape[0],
               start_w:start_w+cropped.shape[1]] = cropped
    return square_img
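

# Worked example of the two helpers above (hypothetical numbers): for a mask
# whose nonzero region spans rows 40-180 and columns 60-140, get_bbox_from_mask
# returns (30, 190, 50, 150) after the 10 px padding, and crop_and_resize then
# centers the resulting 160x100 crop on a 160x160 white square.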


def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Process single item from JSON data."""
    try:
        # The source JSON files use different Korean keys for the image URL.
        if '이미지 링크' in item:
            image_url = item['이미지 링크']
        elif '이미지 URL' in item:
            image_url = item['이미지 URL']
        else:
            logger.warning("No image URL found in item")
            return None

        metadata = create_metadata(item)

        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        try:
            features = extract_features(image, mask, model, preprocess_val, device)
            if features is None:
                raise ValueError("Feature extraction failed")
        except Exception as e:
            logger.error(f"Feature extraction failed for {image_url}: {str(e)}")
            return None

        return {
            'id': metadata['product_id'],
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url
        }

    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None


def save_debug_images(image, mask, url):
    """Save the original image and its mask under debug_images/ for inspection."""
    try:
        debug_dir = "debug_images"
        os.makedirs(debug_dir, exist_ok=True)

        # Derive a filename from the URL, dropping any query string.
        filename = url.split('/')[-1].split('?')[0]

        image.save(f"{debug_dir}/original_{filename}")

        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask_img.save(f"{debug_dir}/mask_{filename}")

    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")


def create_metadata(item):
    """Create standardized metadata from different JSON formats."""
    metadata = {}

    # Some source files prefix this key with a UTF-8 BOM.
    if '\ufeff상품 ID' in item:
        metadata['product_id'] = item['\ufeff상품 ID']
    else:
        # Derive a stable ID from name + URL. hashlib is used instead of the
        # built-in hash(), which is salted differently on every run.
        unique_string = f"{item.get('상품명', '')}{item.get('이미지 URL', '')}"
        metadata['product_id'] = hashlib.md5(unique_string.encode('utf-8')).hexdigest()

    metadata['brand'] = item.get('브랜드명', 'unknown')
    metadata['name'] = item.get('제품명') or item.get('상품명', 'unknown')
    metadata['price'] = (item.get('정가') or item.get('가격') or
                         item.get('판매가', 'unknown'))
    metadata['discount'] = item.get('할인율', 'unknown')

    if '카테고리' in item:
        if isinstance(item['카테고리'], list):
            metadata['category'] = '/'.join(item['카테고리'])
        else:
            metadata['category'] = item['카테고리']
    else:
        # No explicit category: fall back to keyword matching on the product
        # name (dress, shirt, blouse, knit, cardigan, skirt, pants, set-up,
        # outerwear, jacket).
        name = metadata['name'].lower()
        categories = ['원피스', '셔츠', '블라우스', '니트', '가디건',
                      '스커트', '팬츠', '셋업', '아우터', '자켓']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('이미지 링크') or
                             item.get('이미지 URL', 'unknown'))

    # Infer the source marketplace from the key format or the CDN host.
    if '이미지 링크' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in metadata['image_url']:
        metadata['source'] = '11st'
    elif 'gmarket' in metadata['image_url']:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return metadata


def create_multimodal_fashion_db(json_files):
    try:
        logger.info("Starting multimodal fashion database creation")

        model, preprocess_val, segmenter, device, resize_transform = load_models()

        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        embedding_function = OpenCLIPEmbeddingFunction()
        data_loader = ImageLoader()

        # Start from a clean slate: drop any previous collection.
        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"}
        )

        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'feature_extraction_failed': 0
        }

        for json_file in json_files:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Processing {len(data)} items from {json_file}")

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for item in data:
                    future = executor.submit(
                        process_item,
                        item, model, preprocess_val, segmenter, device, resize_transform
                    )
                    futures.append(future)

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()

                    if result is not None:
                        processed_items.append(result)
                        stats['successful'] += 1
                    else:
                        stats['failed'] += 1

            # Add this file's items to the collection in a single batch.
            if processed_items:
                try:
                    collection.add(
                        ids=[item['id'] for item in processed_items],
                        embeddings=[item['embedding'] for item in processed_items],
                        metadatas=[item['metadata'] for item in processed_items],
                        uris=[item['image_uri'] for item in processed_items]
                    )
                except Exception as e:
                    logger.error(f"Failed to add batch to collection: {str(e)}")
                    stats['failed'] += len(processed_items)
                    stats['successful'] -= len(processed_items)

        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")

        return stats['successful'] > 0

    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False
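

# A minimal text-search sketch against the finished collection, assuming
# queries are encoded with the same Marqo/marqo-fashionSigLIP checkpoint that
# produced the stored image embeddings. The prompt and n_results are
# illustrative; search_by_text itself is not part of the build pipeline.
def search_by_text(query, n_results=5):
    """Encode a text query with fashionSigLIP and search the collection."""
    model, _, _ = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
    tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
    model.eval()

    with torch.no_grad():
        text_features = model.encode_text(tokenizer([query]))
        # Unit-normalize so distances are comparable to the stored embeddings.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    client = chromadb.PersistentClient(path="./fashion_multimodal_db")
    collection = client.get_collection("fashion_multimodal")
    return collection.query(
        query_embeddings=text_features.cpu().numpy().tolist(),
        n_results=n_results,
        include=["metadatas", "distances"]
    )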


if __name__ == "__main__":
    json_files = [
        './musinsa_ranking_images_category_0920.json',
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json'
    ]

    success = create_multimodal_fashion_db(json_files)

    if success:
        print("Successfully created multimodal fashion database!")
    else:
        print("Failed to create database. Check the logs for details.")