# itda-multimodal-segmentation / db_multimodal_create.py
# leedoming's picture
# Create db_multimodal_create.py
# 466ea14 verified
# raw
# history blame
# 15.1 kB
import hashlib
import io
import json
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import chromadb
import numpy as np
import open_clip
import requests
import torch
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from PIL import Image
from tqdm import tqdm
from transformers import pipeline
# Logging configuration: records at INFO level and above are sent both to the
# 'fashion_db_creation.log' file and to the default console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def load_models():
    """Load the embedding model, its preprocessing and the segmentation pipeline.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device, resize_transform)``:
            the open_clip model (moved to ``device`` and switched to eval
            mode), the validation preprocessing transform returned by
            open_clip, the Hugging Face segmentation pipeline, the selected
            ``torch.device`` and a torchvision resize-to-224 + ``ToTensor``
            transform.

    Raises:
        Exception: any error raised while loading is logged and re-raised.
    """
    try:
        logger.info("Loading models...")

        # CLIP-style image encoder.
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )

        # Clothes segmentation model.
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # The model is only used for inference here; eval mode disables
        # train-time behaviour such as dropout so embeddings are deterministic.
        model.eval()

        # Extra transform for image preprocessing (resize to the 224x224 input
        # size, then convert to a tensor).
        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
def process_segmentation(image, segmenter):
    """Run the segmenter and return the largest segment as a binary mask.

    Args:
        image: image object accepted by ``segmenter`` (a PIL image in this file).
        segmenter: callable returning a sequence of dicts that each carry a
            ``'mask'`` entry convertible to a NumPy array.

    Returns:
        A 2-D float ``numpy.ndarray`` whose values are exactly 0.0 or 1.0, or
        ``None`` when no segment is produced or any error occurs (the error is
        logged, not raised).
    """
    try:
        output = segmenter(image)
        if not output:
            logger.warning("No segments found in image")
            return None

        masks = [np.asarray(seg['mask']) for seg in output]

        # NOTE(review): the largest segment is chosen regardless of its label.
        # If the segmenter also emits a background segment, that one may win;
        # confirm this is the intended behaviour.
        largest_idx = int(np.argmax([np.count_nonzero(m) for m in masks]))
        mask = masks[largest_idx]

        if mask.ndim > 2:
            mask = mask[:, :, 0]

        # Binarise instead of a plain float cast: segmentation pipelines
        # typically deliver masks as 0/255 images, while the downstream
        # compositing code (img * mask, (1 - mask) * 255) needs 0/1 values.
        mask = (mask > 0).astype(float)

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None
def load_image_from_url(url, max_retries=3, timeout=10, retry_delay=1):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: address of the image to download.
        max_retries: maximum number of download attempts.
        timeout: per-request timeout in seconds passed to ``requests.get``.
        retry_delay: seconds to sleep between failed attempts.

    Returns:
        The decoded image converted to ``'RGB'``, or ``None`` when every
        attempt failed. Failures are logged, never raised.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                # Requires the module-level ``import time``; without it this
                # line raised NameError and aborted the retry loop.
                time.sleep(retry_delay)
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
    return None
def extract_features(image, mask, model, preprocess_val, device):
    """Embed an image as a weighted blend of three views of the garment.

    The three views are the original image, the image with everything outside
    the mask painted white, and the masked image cropped to the mask's
    bounding box. Their embeddings are combined with weights 0.2 / 0.3 / 0.5
    and the result is L2-normalised.

    Args:
        image: PIL image to embed.
        mask: 2-D array (a trailing channel axis is tolerated) marking the
            region of interest. Values may be on a 0..1 or a 0..255 scale.
        model: object exposing ``encode_image``.
        preprocess_val: transform turning a PIL image into a tensor.
        device: torch device the input tensors are moved to.

    Returns:
        A flat ``numpy.ndarray`` embedding, or ``None`` if anything fails
        (the error is logged, not raised).
    """
    try:
        img_array = np.array(image)

        mask_2d = np.asarray(mask, dtype=float)
        if mask_2d.ndim > 2:
            mask_2d = mask_2d[:, :, 0]
        # Bring 0/255 masks onto the 0..1 scale; otherwise the compositing
        # below produces values far outside 0..255 that wrap around when cast
        # to uint8.
        if mask_2d.size and mask_2d.max() > 1:
            mask_2d = mask_2d / 255.0
        mask_3channel = np.repeat(mask_2d[:, :, np.newaxis], 3, axis=2)

        # 1. Features from the original image.
        image_tensor_original = preprocess_val(image).unsqueeze(0).to(device)

        # 2. Features from the masked image on a white background.
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(masked_img_white.astype(np.uint8))
        image_tensor_masked = preprocess_val(image_masked_white).unsqueeze(0).to(device)

        # 3. Features from the version cropped to the masked region only.
        bbox = get_bbox_from_mask(mask_2d)
        cropped_img = crop_and_resize(img_array * mask_3channel, bbox)
        image_cropped = Image.fromarray(cropped_img.astype(np.uint8))
        image_tensor_cropped = preprocess_val(image_cropped).unsqueeze(0).to(device)

        with torch.no_grad():
            # Encode the three views.
            features_original = model.encode_image(image_tensor_original)
            features_masked = model.encode_image(image_tensor_masked)
            features_cropped = model.encode_image(image_tensor_cropped)

            # Weighted combination of the three embeddings.
            combined_features = (
                0.2 * features_original +
                0.3 * features_masked +
                0.5 * features_cropped
            )

            # Normalise to unit length.
            combined_features /= combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
def get_bbox_from_mask(mask, padding=10):
    """Return the padded bounding box of the non-zero region of ``mask``.

    Args:
        mask: array of at least two dimensions (rows, columns). Any further
            axes are collapsed, so an ``(H, W, 1)`` mask is accepted.
        padding: number of pixels added on every side, clipped to the mask
            bounds.

    Returns:
        tuple: ``(rmin, rmax, cmin, cmax)`` as Python ints. ``rmax`` and
        ``cmax`` are exclusive, so the box can be used directly as
        ``array[rmin:rmax, cmin:cmax]``.

    Raises:
        ValueError: if ``mask`` has no non-zero element.
    """
    mask = np.asarray(mask)
    if mask.ndim > 2:
        mask = mask.any(axis=tuple(range(2, mask.ndim)))

    row_idx = np.flatnonzero(mask.any(axis=1))
    col_idx = np.flatnonzero(mask.any(axis=0))
    if row_idx.size == 0 or col_idx.size == 0:
        raise ValueError("mask has no non-zero pixels; cannot compute a bounding box")

    height, width = mask.shape[:2]

    # The "+ 1" converts the last occupied index into an exclusive slice end,
    # so the padding is the same on both sides of the region.
    rmin = max(int(row_idx[0]) - padding, 0)
    rmax = min(int(row_idx[-1]) + 1 + padding, height)
    cmin = max(int(col_idx[0]) - padding, 0)
    cmax = min(int(col_idx[-1]) + 1 + padding, width)

    return rmin, rmax, cmin, cmax
def crop_and_resize(image, bbox):
    """Crop ``image`` to ``bbox`` and centre the crop on a white square canvas.

    Args:
        image: array indexed as ``[row, column, channel]`` with three channels.
        bbox: ``(rmin, rmax, cmin, cmax)`` slice bounds for the crop.

    Returns:
        A ``uint8`` array of shape ``(side, side, 3)`` where ``side`` is the
        longer edge of the crop; pixels not covered by the crop are 255.
        No rescaling is performed.
    """
    top, bottom, left, right = bbox
    patch = image[top:bottom, left:right]
    patch_h, patch_w = patch.shape[:2]

    # Square canvas sized to the longer edge, pre-filled with white.
    side = max(patch_h, patch_w)
    canvas = np.full((side, side, 3), 255, dtype=np.uint8)

    # Offsets that centre the patch along the shorter dimension.
    offset_y = (side - patch_h) // 2
    offset_x = (side - patch_w) // 2
    canvas[offset_y:offset_y + patch_h, offset_x:offset_x + patch_w] = patch

    return canvas
def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Turn one JSON record into a database entry.

    Steps: resolve the image URL, build the metadata, download the image,
    segment it, and extract the combined embedding.

    Args:
        item: dict describing one product (one element of a JSON file).
        model, preprocess_val, device: forwarded to ``extract_features``.
        segmenter: forwarded to ``process_segmentation``.
        resize_transform: accepted for interface compatibility; not used here.

    Returns:
        dict with keys ``'id'``, ``'embedding'``, ``'metadata'`` and
        ``'image_uri'``, or ``None`` when any step fails (failures are logged).
    """
    try:
        # Resolve the image URL the same way create_metadata does, so that an
        # empty value under the first key still falls back to the second one.
        image_url = item.get('์ด๋ฏธ์ง€ ๋งํฌ') or item.get('์ด๋ฏธ์ง€ URL')
        if not image_url:
            logger.warning("No image URL found in item")
            return None

        # Build the metadata.
        metadata = create_metadata(item)

        # Download the image.
        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        # Run segmentation.
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        # Extract the combined embedding. For debugging, the inputs can be
        # dumped here with save_debug_images(image, mask, image_url).
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.error(f"Feature extraction failed for {image_url}: Feature extraction failed")
            return None

        return {
            # The id is passed to the collection as an identifier; force it to
            # a string in case the source JSON stores a numeric product id.
            'id': str(metadata['product_id']),
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url
        }
    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None
# Optional helper that writes debug images to disk.
def save_debug_images(image, mask, url):
    """Save the original image and its mask under ``debug_images/``.

    The file name is the last path component of ``url`` with any query string
    removed. Errors are logged as warnings and never raised.

    Args:
        image: object with a ``save(path)`` method (a PIL image in this file).
        mask: NumPy array; it is multiplied by 255 and cast to ``uint8``
            before saving, i.e. it is assumed to be on a 0..1 scale.
        url: source URL of the image, used only to derive the file name.
    """
    try:
        debug_dir = "debug_images"
        os.makedirs(debug_dir, exist_ok=True)

        # Derive the file name from the URL.
        last_component = url.rsplit('/', 1)[-1]
        filename = last_component.partition('?')[0]

        # Save the original image, then the mask rendered as a grey image.
        image.save(f"{debug_dir}/original_{filename}")
        Image.fromarray((mask * 255).astype(np.uint8)).save(f"{debug_dir}/mask_{filename}")
    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")
def create_metadata(item):
    """Create standardized metadata from the different JSON record formats.

    Args:
        item: dict describing one product. Missing fields fall back to
            ``'unknown'``.

    Returns:
        dict with keys ``'product_id'``, ``'brand'``, ``'name'``, ``'price'``,
        ``'discount'``, ``'category'``, ``'image_url'`` and ``'source'``.
        When the record carries no product id, ``'product_id'`` is the SHA-1
        hex digest of the product name concatenated with the image URL.
    """
    metadata = {}

    # Product id.
    if '๏ปฟ์ƒํ’ˆ ID' in item:  # record format that carries its own product id
        metadata['product_id'] = item['๏ปฟ์ƒํ’ˆ ID']
    else:
        # Other formats: derive a unique id from product name + image URL.
        # A cryptographic digest is used instead of the built-in hash(),
        # whose value for strings changes between interpreter runs.
        unique_string = f"{item.get('์ƒํ’ˆ๋ช…', '')}{item.get('์ด๋ฏธ์ง€ URL', '')}"
        metadata['product_id'] = hashlib.sha1(unique_string.encode('utf-8')).hexdigest()

    # Remaining metadata fields.
    metadata['brand'] = item.get('๋ธŒ๋žœ๋“œ๋ช…', 'unknown')
    metadata['name'] = item.get('์ œํ’ˆ๋ช…') or item.get('์ƒํ’ˆ๋ช…', 'unknown')
    metadata['price'] = (item.get('์ •๊ฐ€') or item.get('๊ฐ€๊ฒฉ') or
                         item.get('ํŒ๋งค๊ฐ€', 'unknown'))
    metadata['discount'] = item.get('ํ• ์ธ์œจ', 'unknown')

    if '์นดํ…Œ๊ณ ๋ฆฌ' in item:
        if isinstance(item['์นดํ…Œ๊ณ ๋ฆฌ'], list):
            metadata['category'] = '/'.join(item['์นดํ…Œ๊ณ ๋ฆฌ'])
        else:
            metadata['category'] = item['์นดํ…Œ๊ณ ๋ฆฌ']
    else:
        # No category field: try to infer categories from the product name.
        # str() guards against a non-string name in the source data.
        name = str(metadata['name']).lower()
        categories = ['์›ํ”ผ์Šค', '์…”์ธ ', '๋ธ”๋ผ์šฐ์Šค', '๋‹ˆํŠธ', '๊ฐ€๋””๊ฑด',
                      '์Šค์ปคํŠธ', 'ํŒฌ์ธ ', '์…‹์—…', '์•„์šฐํ„ฐ', '์ž์ผ“']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('์ด๋ฏธ์ง€ ๋งํฌ') or
                             item.get('์ด๋ฏธ์ง€ URL', 'unknown'))

    # Shopping-mall source, inferred from the record format / image host.
    if '์ด๋ฏธ์ง€ ๋งํฌ' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in metadata['image_url']:
        metadata['source'] = '11st'
    elif 'gmarket' in metadata['image_url']:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return metadata
def create_multimodal_fashion_db(json_files):
    """Build the ``fashion_multimodal`` ChromaDB collection from JSON files.

    Any existing collection with that name is deleted first. Every record of
    every file is processed with ``process_item`` on a small thread pool and
    the successful results are added to the collection in batches.

    Args:
        json_files: iterable of paths to JSON files, each containing a list of
            product records.

    Returns:
        bool: ``True`` if at least one item was stored, ``False`` otherwise
        (including when an unexpected error aborts the run; it is logged).
    """
    # Upper bound on the number of items sent in one collection.add() call.
    add_batch_size = 500

    try:
        logger.info("Starting multimodal fashion database creation")

        # Load the models.
        model, preprocess_val, segmenter, device, resize_transform = load_models()

        # ChromaDB setup.
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        # Create the multimodal collection.
        embedding_function = OpenCLIPEmbeddingFunction()
        data_loader = ImageLoader()

        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            # Deleting is best effort; a missing collection is the expected case.
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"}
        )

        # Processing statistics.
        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'duplicates': 0
        }

        # Ids already queued for insertion. Repeated ids inside one add() call
        # make the whole call fail, so later duplicates are skipped.
        seen_ids = set()

        # Process the JSON files one by one.
        for json_file in json_files:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Processing {len(data)} items from {json_file}")

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = [
                    executor.submit(
                        process_item,
                        item, model, preprocess_val, segmenter, device, resize_transform
                    )
                    for item in data
                ]

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()
                    if result is None:
                        stats['failed'] += 1
                    elif result['id'] in seen_ids:
                        stats['duplicates'] += 1
                    else:
                        seen_ids.add(result['id'])
                        processed_items.append(result)
                        stats['successful'] += 1

            # Add the results to the database in bounded batches so that one
            # failing batch does not discard the whole file.
            for start in range(0, len(processed_items), add_batch_size):
                batch = processed_items[start:start + add_batch_size]
                try:
                    collection.add(
                        ids=[entry['id'] for entry in batch],
                        embeddings=[entry['embedding'] for entry in batch],
                        metadatas=[entry['metadata'] for entry in batch],
                        uris=[entry['image_uri'] for entry in batch]
                    )
                except Exception as e:
                    logger.error(f"Failed to add batch to collection: {str(e)}")
                    stats['failed'] += len(batch)
                    stats['successful'] -= len(batch)

        # Final statistics.
        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")
        logger.info(f"Skipped duplicates: {stats['duplicates']}")

        return stats['successful'] > 0
    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False
# Script entry point: build the database from the bundled JSON exports and
# report the outcome on stdout.
if __name__ == "__main__":
    json_files = [
        # Musinsa ranking export
        './musinsa_ranking_images_category_0920.json',
        # 11st exports
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        # Gmarket exports
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json',
    ]

    success = create_multimodal_fashion_db(json_files)
    print(
        "Successfully created multimodal fashion database!"
        if success
        else "Failed to create database. Check the logs for details."
    )