leedoming committed on
Commit
466ea14
•
1 Parent(s): 6ae0afe

Create db_multimodal_create.py

Files changed (1)
  1. db_multimodal_create.py +398 -0
db_multimodal_create.py ADDED
import chromadb
import logging
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import requests
import time  # used by the retry backoff in load_image_from_url
import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import os
from io import BytesIO
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

# Logging setup: write to a log file and to the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('fashion_db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def load_models():
    try:
        logger.info("Loading models...")
        # CLIP model (fashion-tuned SigLIP checkpoint)
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')

        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # Extra transforms for image preprocessing
        from torchvision import transforms
        resize_transform = transforms.Compose([
            transforms.Resize((224, 224)),  # match the CLIP input size
            transforms.ToTensor(),
        ])

        return model, preprocess_val, segmenter, device, resize_transform
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise

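# Note: the transformers image-segmentation pipeline used below returns a list
# of dicts per image, each with 'score', 'label', and a PIL 'mask' for one
# detected class. process_segmentation keeps only the largest mask, on the
# assumption that the biggest segment is the main garment in a product photo.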
def process_segmentation(image, segmenter):
    """Segmentation processing: return a float mask in [0, 1] for the largest segment."""
    try:
        output = segmenter(image)

        if not output:
            logger.warning("No segments found in image")
            return None

        segment_sizes = [np.sum(seg['mask']) for seg in output]

        if not segment_sizes:
            return None

        largest_idx = np.argmax(segment_sizes)
        mask = output[largest_idx]['mask']

        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        # The pipeline yields 0-255 masks; normalize to [0, 1] so the blending
        # arithmetic in extract_features stays in a valid pixel range.
        mask = mask.astype(float)
        if mask.max() > 1:
            mask /= 255.0

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        return None

def load_image_from_url(url, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            return img
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(1)  # brief backoff before retrying
            else:
                logger.error(f"Failed to load image from {url} after {max_retries} attempts")
                return None

def extract_features(image, mask, model, preprocess_val, device):
    """Advanced feature extraction with mask-based attention"""
    try:
        img_array = np.array(image)
        mask = np.expand_dims(mask, axis=2)
        mask_3channel = np.repeat(mask, 3, axis=2)

        # 1. Features from the original image
        image_tensor_original = preprocess_val(image).unsqueeze(0).to(device)

        # 2. Features from the masked image (white background)
        masked_img_white = img_array * mask_3channel + (1 - mask_3channel) * 255
        image_masked_white = Image.fromarray(masked_img_white.astype(np.uint8))
        image_tensor_masked = preprocess_val(image_masked_white).unsqueeze(0).to(device)

        # 3. Features from a version cropped to just the clothing
        bbox = get_bbox_from_mask(mask)  # bounding box from the mask
        cropped_img = crop_and_resize(img_array * mask_3channel, bbox)
        image_cropped = Image.fromarray(cropped_img.astype(np.uint8))
        image_tensor_cropped = preprocess_val(image_cropped).unsqueeze(0).to(device)

        with torch.no_grad():
            # Encode all three views
            features_original = model.encode_image(image_tensor_original)
            features_masked = model.encode_image(image_tensor_masked)
            features_cropped = model.encode_image(image_tensor_cropped)

            # Weighted combination of the three views
            combined_features = (
                0.2 * features_original +
                0.3 * features_masked +
                0.5 * features_cropped
            )

            # L2 normalization
            combined_features /= combined_features.norm(dim=-1, keepdim=True)

        return combined_features.cpu().numpy().flatten()

    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None

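# Note on the blend above: the final embedding is
#   0.2 * f(original) + 0.3 * f(white-background) + 0.5 * f(garment crop)
# so the weights sum to 1 and the garment crop dominates; the L2
# renormalization keeps the combined vector unit-length so cosine /
# inner-product comparisons in the vector store remain meaningful. The
# 0.2/0.3/0.5 split is this script's tuning choice, not a requirement.
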
def get_bbox_from_mask(mask):
    """Extract bounding-box coordinates from the mask"""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # Add some padding around the box
    padding = 10
    rmin = max(rmin - padding, 0)
    rmax = min(rmax + padding, mask.shape[0])
    cmin = max(cmin - padding, 0)
    cmax = min(cmax + padding, mask.shape[1])
    return rmin, rmax, cmin, cmax

def crop_and_resize(image, bbox):
    """Crop the image to the bounding box and pad it to a white square"""
    rmin, rmax, cmin, cmax = bbox
    cropped = image[rmin:rmax, cmin:cmax]
    # Center the crop on a white square canvas so the aspect ratio is
    # preserved when CLIP preprocessing later resizes it
    size = max(cropped.shape[:2])
    square_img = np.full((size, size, 3), 255, dtype=np.uint8)
    start_h = (size - cropped.shape[0]) // 2
    start_w = (size - cropped.shape[1]) // 2
    square_img[start_h:start_h+cropped.shape[0],
               start_w:start_w+cropped.shape[1]] = cropped
    return square_img

def process_item(item, model, preprocess_val, segmenter, device, resize_transform):
    """Process a single item from the JSON data"""
    try:
        # Extract the image URL
        if '이미지 링크' in item:
            image_url = item['이미지 링크']
        elif '이미지 URL' in item:
            image_url = item['이미지 URL']
        else:
            logger.warning("No image URL found in item")
            return None

        # Build the metadata
        metadata = create_metadata(item)

        # Download the image
        image = load_image_from_url(image_url)
        if image is None:
            logger.warning(f"Failed to load image from {image_url}")
            return None

        # Run segmentation
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"Failed to create segmentation mask for {image_url}")
            return None

        # Apply the new feature extraction scheme
        try:
            features = extract_features(image, mask, model, preprocess_val, device)
            if features is None:
                raise ValueError("Feature extraction failed")

            # Save debug images (optional)
            # save_debug_images(image, mask, image_url)

        except Exception as e:
            logger.error(f"Feature extraction failed for {image_url}: {str(e)}")
            return None

        return {
            'id': metadata['product_id'],
            'embedding': features.tolist(),
            'metadata': metadata,
            'image_uri': image_url
        }

    except Exception as e:
        logger.error(f"Error processing item: {str(e)}")
        return None

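# Each successful item becomes a record shaped for collection.add():
#   {'id': <product_id string>, 'embedding': <list[float] from extract_features>,
#    'metadata': <dict from create_metadata>, 'image_uri': <source image URL>}
# The URI goes into ChromaDB's uris parameter, pairing each embedding with its
# image location for later retrieval.
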
# Debug-image saver (optional)
def save_debug_images(image, mask, url):
    try:
        debug_dir = "debug_images"
        os.makedirs(debug_dir, exist_ok=True)

        # Derive a filename from the URL
        filename = url.split('/')[-1].split('?')[0]

        # Save the original image and the mask
        image.save(f"{debug_dir}/original_{filename}")

        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        mask_img.save(f"{debug_dir}/mask_{filename}")

    except Exception as e:
        logger.warning(f"Failed to save debug images: {str(e)}")

def create_metadata(item):
    """Create standardized metadata from different JSON formats"""
    metadata = {}

    # Product-ID handling
    if '\ufeff상품 ID' in item:  # Musinsa format (the key carries a UTF-8 BOM)
        metadata['product_id'] = str(item['\ufeff상품 ID'])
    else:
        # For 11st/Gmarket, build a unique ID from the product name and URL
        unique_string = f"{item.get('상품명', '')}{item.get('이미지 URL', '')}"
        metadata['product_id'] = str(hash(unique_string))

    # Remaining metadata
    metadata['brand'] = item.get('브랜드명', 'unknown')
    metadata['name'] = item.get('제품명') or item.get('상품명', 'unknown')
    metadata['price'] = (item.get('정가') or item.get('가격') or
                         item.get('판매가', 'unknown'))
    metadata['discount'] = item.get('할인율', 'unknown')

    if '카테고리' in item:
        if isinstance(item['카테고리'], list):
            metadata['category'] = '/'.join(item['카테고리'])
        else:
            metadata['category'] = item['카테고리']
    else:
        # For 11st/Gmarket, try to infer the category from the product name
        name = metadata['name'].lower()
        categories = ['원피스', '셔츠', '블라우스', '니트', '가디건',
                      '스커트', '팬츠', '셋업', '아우터', '자켓']
        found_categories = [cat for cat in categories if cat in name]
        metadata['category'] = '/'.join(found_categories) if found_categories else 'unknown'

    metadata['image_url'] = (item.get('이미지 링크') or
                             item.get('이미지 URL', 'unknown'))

    # Record the source shopping mall
    if '이미지 링크' in item:
        metadata['source'] = 'musinsa'
    elif 'cdn.011st.com' in metadata['image_url']:
        metadata['source'] = '11st'
    elif 'gmarket' in metadata['image_url']:
        metadata['source'] = 'gmarket'
    else:
        metadata['source'] = 'unknown'

    return metadata

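# Hypothetical examples of the two crawler formats create_metadata accepts
# (field values are illustrative, not real records):
#   Musinsa:      {'\ufeff상품 ID': 12345, '브랜드명': '...', '제품명': '...',
#                  '정가': '...', '카테고리': ['니트'], '이미지 링크': 'https://...'}
#   11st/Gmarket: {'상품명': '여성 니트 가디건', '가격': '...',
#                  '이미지 URL': 'https://cdn.011st.com/...'}
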
def create_multimodal_fashion_db(json_files):
    try:
        logger.info("Starting multimodal fashion database creation")

        # Load models
        model, preprocess_val, segmenter, device, resize_transform = load_models()

        # ChromaDB setup
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")

        # Create the multimodal collection
        embedding_function = OpenCLIPEmbeddingFunction()
        data_loader = ImageLoader()

        try:
            client.delete_collection("fashion_multimodal")
            logger.info("Deleted existing collection")
        except Exception:
            logger.info("No existing collection to delete")

        collection = client.create_collection(
            name="fashion_multimodal",
            embedding_function=embedding_function,
            data_loader=data_loader,
            metadata={"description": "Fashion multimodal collection with advanced feature extraction"}
        )

        # Processing statistics
        stats = {
            'total_processed': 0,
            'successful': 0,
            'failed': 0,
            'feature_extraction_failed': 0
        }

        # Process each JSON file
        for json_file in json_files:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            logger.info(f"Processing {len(data)} items from {json_file}")

            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for item in data:
                    future = executor.submit(
                        process_item,
                        item, model, preprocess_val, segmenter, device, resize_transform
                    )
                    futures.append(future)

                processed_items = []
                for future in tqdm(futures, desc=f"Processing {json_file}"):
                    stats['total_processed'] += 1
                    result = future.result()

                    if result is not None:
                        processed_items.append(result)
                        stats['successful'] += 1
                    else:
                        stats['failed'] += 1

            # Add the batch to the database
            if processed_items:
                try:
                    collection.add(
                        ids=[item['id'] for item in processed_items],
                        embeddings=[item['embedding'] for item in processed_items],
                        metadatas=[item['metadata'] for item in processed_items],
                        uris=[item['image_uri'] for item in processed_items]
                    )
                except Exception as e:
                    logger.error(f"Failed to add batch to collection: {str(e)}")
                    stats['failed'] += len(processed_items)
                    stats['successful'] -= len(processed_items)

        # Final statistics
        logger.info("Processing completed:")
        logger.info(f"Total processed: {stats['total_processed']}")
        logger.info(f"Successful: {stats['successful']}")
        logger.info(f"Failed: {stats['failed']}")

        return stats['successful'] > 0

    except Exception as e:
        logger.error(f"Database creation error: {str(e)}")
        return False

if __name__ == "__main__":
    json_files = [
        './musinsa_ranking_images_category_0920.json',
        './11st/11st_bagaccessory_20241017_172846.json',
        './11st/11st_best_abroad_bagaccessory_20241017_173300.json',
        './11st/11st_best_abroad_fashion_20241017_173144.json',
        './11st/11st_best_abroad_luxury_20241017_173343.json',
        './11st/11st_best_men_20241017_172534.json',
        './11st/11st_best_women_20241017_172127.json',
        './gmarket/gmarket_best_accessory_20241015_155921.json',
        './gmarket/gmarket_best_bag_20241015_155811.json',
        './gmarket/gmarket_best_brand_20241015_155530.json',
        './gmarket/gmarket_best_casual_20241015_155421.json',
        './gmarket/gmarket_best_men_20241015_155025.json',
        './gmarket/gmarket_best_shoe_20241015_155613.json',
        './gmarket/gmarket_best_women_20241015_154206.json'
    ]

    success = create_multimodal_fashion_db(json_files)

    if success:
        print("Successfully created multimodal fashion database!")
    else:
        print("Failed to create database. Check the logs for details.")