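"""Build a segmentation-aware ChromaDB collection of clothing embeddings.

Reads items from an existing "clothes" collection, downloads each item's
image, isolates the garment with a SegFormer clothes-segmentation model,
embeds the masked image with Marqo's fashion SigLIP model, and writes the
results to a new "clothes_segmented" collection.
"""
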
import io
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import chromadb
import numpy as np
import open_clip
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import pipeline

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_models():
    """Load the CLIP and segmentation models."""
    try:
        logger.info("Loading models...")
        # CLIP model (Marqo fashion SigLIP checkpoint)
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Clothes segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model.to(device)
        model.eval()  # inference only; disable training-mode behavior
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


def process_segmentation(image, segmenter):
    """Apply segmentation to an image and return the largest garment mask."""
    try:
        segments = segmenter(image)
        if not segments:
            return None
        # Skip the background segment so the largest *garment* region wins
        # (the segformer_b2_clothes checkpoint labels it "Background").
        segments = [s for s in segments if s.get('label') != 'Background']
        if not segments:
            return None
        # Select the largest segment by mask area
        largest_segment = max(segments, key=lambda s: np.sum(np.array(s['mask'])))
        mask = np.array(largest_segment['mask'])
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None


def extract_features(image, mask, model, preprocess_val, device):
    """Extract CLIP features from an image, restricted to the masked region."""
    try:
        if mask is not None:
            img_array = np.array(image)
            # Pipeline masks arrive with values 0/255; binarize to 0/1 so the
            # multiplication below doesn't wrap foreground pixel values.
            mask = (np.array(mask) > 0).astype(np.uint8)
            mask = np.expand_dims(mask, axis=2)
            masked_img = img_array * mask
            masked_img[mask[:, :, 0] == 0] = 255  # paint the background white
            image = Image.fromarray(masked_img.astype(np.uint8))
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
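

# Because extract_features L2-normalizes its output, cosine similarity between
# two embeddings reduces to a plain dot product. The helper below is an
# illustrative addition, not part of the original pipeline:
def embedding_similarity(a, b):
    """Cosine similarity of two already-normalized feature vectors."""
    return float(np.dot(a, b))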


def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Download a single image and compute its segmented CLIP embedding."""
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        # Apply segmentation
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None
        # Extract features
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
            return None
        return features
    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {e}")
        return None


def create_segmented_db(source_path, target_path, batch_size=100):
    """Create a new segmented database from an existing one."""
    try:
        logger.info("Loading models...")
        model, preprocess_val, segmenter, device = load_models()
        # Connect to the source DB
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")
        # Create the target DB
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)
        try:
            target_client.delete_collection("clothes_segmented")
        except Exception:
            # The collection may not exist yet; nothing to delete.
            pass
        target_collection = target_client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )
        # Count the items to process
        all_items = source_collection.get(include=['metadatas'])
        total_items = len(all_items['metadatas'])
        logger.info(f"Found {total_items} items in source database")
        # Counters for batch processing
        successful_updates = 0
        failed_updates = 0
        # ThreadPoolExecutor setup (all workers share one model and device)
        max_workers = min(10, os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Process the data in batches
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = min(batch_start + batch_size, total_items)
                batch_items = all_items['metadatas'][batch_start:batch_end]
                # Submit a future for every image in the batch
                futures = []
                for metadata in batch_items:
                    if 'image_url' in metadata:
                        future = executor.submit(
                            download_and_process_image,
                            metadata['image_url'],
                            metadata.get('id', 'unknown'),
                            model, preprocess_val, segmenter, device
                        )
                        futures.append((metadata, future))
                # Collect batch results
                batch_embeddings = []
                batch_metadatas = []
                batch_ids = []
                for metadata, future in futures:
                    try:
                        features = future.result()
                        if features is not None:
                            batch_embeddings.append(features.tolist())
                            batch_metadatas.append(metadata)
                            batch_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
                            successful_updates += 1
                        else:
                            failed_updates += 1
                    except Exception as e:
                        logger.error(f"Error processing batch item: {e}")
                        failed_updates += 1
                        continue
                # Persist the batch
                if batch_embeddings:
                    try:
                        target_collection.add(
                            embeddings=batch_embeddings,
                            metadatas=batch_metadatas,
                            ids=batch_ids
                        )
                        logger.info(f"Added batch of {len(batch_embeddings)} items")
                    except Exception as e:
                        logger.error(f"Error adding batch to collection: {e}")
                        failed_updates += len(batch_embeddings)
                        successful_updates -= len(batch_embeddings)
        # Report the final results
        logger.info("Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        if total_items:
            logger.info(f"Total completion rate: {(successful_updates / total_items) * 100:.2f}%")
        return True
    except Exception as e:
        logger.error(f"Database creation error: {e}")
        return False


if __name__ == "__main__":
    # Configuration
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"            # source DB path
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # new DB path
    BATCH_SIZE = 50  # number of items to process at a time

    # Run the DB creation
    success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)
    if success:
        logger.info("Successfully created segmented database!")
    else:
        logger.error("Failed to create segmented database.")