leedoming committed on
Commit b473cc2 • 1 Parent(s): 460f00d

Create app.py

Files changed (1)
  1. app.py +512 -0
app.py ADDED
@@ -0,0 +1,512 @@
+ import streamlit as st
+ import open_clip
+ import torch
+ from PIL import Image
+ import numpy as np
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+ import chromadb
+ import logging
+ import io
+ import requests
+ from concurrent.futures import ThreadPoolExecutor
+ from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
+ from chromadb.utils.data_loaders import ImageLoader
+
+ # Logging setup
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
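+ # Custom ChromaDB embedding function: combines Marqo FashionSigLIP image
+ # embeddings with an HSV color histogram (weighted 0.7 / 0.3) into a single
+ # 768-dimensional vector.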
+ class CustomFashionEmbeddingFunction:
+     def __init__(self):
+         self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = self.model.to(self.device)
+
+     def __call__(self, input):
+         try:
+             # Handle inputs given as URLs, raw bytes, or arrays
+             processed_images = []
+             for img in input:
+                 if isinstance(img, (str, bytes)):
+                     if isinstance(img, str):
+                         response = requests.get(img)
+                         img = Image.open(io.BytesIO(response.content)).convert('RGB')
+                     else:
+                         img = Image.open(io.BytesIO(img)).convert('RGB')
+                 elif isinstance(img, np.ndarray):
+                     img = Image.fromarray(img.astype('uint8')).convert('RGB')
+
+                 processed_img = self.preprocess(img).unsqueeze(0).to(self.device)
+                 processed_images.append(processed_img)
+
+             # Batch the preprocessed images
+             batch = torch.cat(processed_images)
+
+             # Extract CLIP embeddings
+             with torch.no_grad():
+                 clip_features = self.model.encode_image(batch)
+                 clip_features = clip_features.cpu().numpy()
+
+             # Extract color features
+             color_features_list = []
+             for img in input:
+                 if isinstance(img, (str, bytes)):
+                     if isinstance(img, str):
+                         response = requests.get(img)
+                         img = Image.open(io.BytesIO(response.content)).convert('RGB')
+                     else:
+                         img = Image.open(io.BytesIO(img)).convert('RGB')
+                 elif isinstance(img, np.ndarray):
+                     img = Image.fromarray(img.astype('uint8')).convert('RGB')
+
+                 color_features = self.extract_color_histogram(img)
+                 color_features_list.append(color_features)
+
+             # Combine features
+             combined_embeddings = []
+             for clip_emb, color_feat in zip(clip_features, color_features_list):
+                 # Pad the CLIP embedding to 768 dimensions
+                 if clip_emb.shape[0] < 768:
+                     padding = np.zeros(768 - clip_emb.shape[0])
+                     clip_emb = np.concatenate([clip_emb, padding])
+                 else:
+                     clip_emb = clip_emb[:768]  # truncate to 768 dimensions
+
+                 # Expand the color features to 768 dimensions
+                 color_features_expanded = np.repeat(color_feat, 32)  # 24 * 32 = 768
+
+                 # Normalize
+                 clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
+                 color_features_expanded = color_features_expanded / (np.linalg.norm(color_features_expanded) + 1e-8)
+
+                 # Weighted combination
+                 combined = clip_emb * 0.7 + color_features_expanded * 0.3
+                 combined = combined / (np.linalg.norm(combined) + 1e-8)
+
+                 combined_embeddings.append(combined)
+
+             return np.array(combined_embeddings)
+
+         except Exception as e:
+             logger.error(f"Error in embedding function: {e}")
+             raise
+
+     def extract_color_histogram(self, image):
+         """Extract an HSV color histogram (8 bins per channel, 24 values) from the image"""
+         try:
+             if isinstance(image, (str, bytes)):
+                 if isinstance(image, str):
+                     response = requests.get(image)
+                     image = Image.open(io.BytesIO(response.content))
+                 else:
+                     image = Image.open(io.BytesIO(image))
+
+             if not isinstance(image, np.ndarray):
+                 img_array = np.array(image)
+             else:
+                 img_array = image
+
+             # Convert to HSV and compute per-channel histograms
+             img_hsv = Image.fromarray(img_array.astype('uint8')).convert('HSV')
+             hsv_pixels = np.array(img_hsv)
+
+             h_hist = np.histogram(hsv_pixels[:, :, 0], bins=8, range=(0, 256))[0]
+             s_hist = np.histogram(hsv_pixels[:, :, 1], bins=8, range=(0, 256))[0]
+             v_hist = np.histogram(hsv_pixels[:, :, 2], bins=8, range=(0, 256))[0]
+
+             # Normalize
+             h_hist = h_hist / (h_hist.sum() + 1e-8)
+             s_hist = s_hist / (s_hist.sum() + 1e-8)
+             v_hist = v_hist / (v_hist.sum() + 1e-8)
+
+             return np.concatenate([h_hist, s_hist, v_hist])
+         except Exception as e:
+             logger.error(f"Color histogram extraction error: {e}")
+             return np.zeros(24)
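+
+ # Note: the combined embedding is fixed at 768 dimensions: the image embedding is
+ # padded or truncated to 768, and the 24-value HSV histogram is repeated 32 times
+ # (24 * 32 = 768) before the 0.7 / 0.3 weighted sum. If the backbone's embedding
+ # size changes, the padding and the repeat factor must be kept in sync.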
+
+ # Initialize session state
+ if 'image' not in st.session_state:
+     st.session_state.image = None
+ if 'detected_items' not in st.session_state:
+     st.session_state.detected_items = None
+ if 'selected_item_index' not in st.session_state:
+     st.session_state.selected_item_index = None
+ if 'upload_state' not in st.session_state:
+     st.session_state.upload_state = 'initial'
+ if 'search_clicked' not in st.session_state:
+     st.session_state.search_clicked = False
+
+ # Load segmentation model
+ @st.cache_resource
+ def load_segmentation_model():
+     try:
+         model_name = "mattmdjaga/segformer_b2_clothes"
+         image_processor = AutoImageProcessor.from_pretrained(model_name)
+         model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
+
+         if torch.cuda.is_available():
+             model = model.to('cuda')
+
+         return model, image_processor
+     except Exception as e:
+         logger.error(f"Error loading segmentation model: {e}")
+         raise
+
+ # ChromaDB setup
+ def setup_multimodal_collection():
+     """Set up the multimodal collection"""
+     try:
+         client = chromadb.PersistentClient(path="./fashion_multimodal_db")
+         embedding_function = CustomFashionEmbeddingFunction()
+         data_loader = ImageLoader()
+
+         # Try to fetch the existing collection
+         try:
+             collection = client.get_collection(
+                 name="fashion_multimodal",
+                 embedding_function=embedding_function,
+                 data_loader=data_loader
+             )
+             logger.info("Successfully connected to existing fashion_multimodal collection")
+             return collection
+
+         except Exception as e:
+             logger.error(f"Error getting existing collection: {e}")
+             # Create the collection only if it does not exist yet
+             collection = client.create_collection(
+                 name="fashion_multimodal",
+                 embedding_function=embedding_function,
+                 data_loader=data_loader
+             )
+             logger.info("Created new fashion_multimodal collection")
+             return collection
+
+     except Exception as e:
+         logger.error(f"Error setting up multimodal collection: {e}")
+         raise
+
+ def process_segmentation(image):
+     """Segmentation processing"""
+     try:
+         model, image_processor = load_segmentation_model()
+
+         # Preprocess the image
+         inputs = image_processor(image, return_tensors="pt")
+
+         if torch.cuda.is_available():
+             inputs = {k: v.to('cuda') for k, v in inputs.items()}
+
+         # Inference
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # Logits and post-processing
+         logits = outputs.logits.cpu()
+         upsampled_logits = torch.nn.functional.interpolate(
+             logits,
+             size=image.size[::-1],  # (height, width)
+             mode="bilinear",
+             align_corners=False,
+         )
+
+         # Build the segmentation masks
+         seg_masks = upsampled_logits.argmax(dim=1).numpy()
+
+         processed_items = []
+         unique_labels = np.unique(seg_masks)
+
+         for label_idx in unique_labels:
+             if label_idx == 0:  # background
+                 continue
+
+             mask = (seg_masks[0] == label_idx).astype(float)
+
+             processed_segment = {
+                 'label': f"Item_{label_idx}",  # map to a readable label here if needed (e.g. model.config.id2label)
+                 'score': 1.0,  # add a real confidence score here if needed
+                 'mask': mask
+             }
+
+             processed_items.append(processed_segment)
+
+         logger.info(f"Successfully processed {len(processed_items)} segments")
+         return processed_items
+
+     except Exception as e:
+         logger.error(f"Segmentation error: {str(e)}")
+         import traceback
+         logger.error(traceback.format_exc())
+         return []
+
+ def search_similar_items(image, mask=None, top_k=10):
+     """Run the multimodal similarity search"""
+     try:
+         collection = setup_multimodal_collection()
+
+         # Apply the mask, if any
+         if mask is not None:
+             mask_3d = np.stack([mask] * 3, axis=-1)
+             masked_image = np.array(image) * mask_3d
+             query_image = Image.fromarray(masked_image.astype(np.uint8))
+         else:
+             query_image = image
+
+         # Run the query
+         results = collection.query(
+             query_images=[np.array(query_image)],
+             n_results=top_k,
+             include=['metadatas', 'distances']
+         )
+
+         if not results or 'metadatas' not in results:
+             return []
+
+         similar_items = []
+         for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
+             similarity_score = (1 - distance) * 100
+             item_data = metadata.copy()
+             item_data['similarity_score'] = similarity_score
+             similar_items.append(item_data)
+
+         similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
+         return similar_items
+
+     except Exception as e:
+         logger.error(f"Multimodal search error: {e}")
+         return []
+
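+ # One-off migration: reads metadata from the existing "clothes" collection in
+ # ./clothesDB_11GmarketMusinsa, downloads each item's image_url, and re-embeds
+ # the images into the new multimodal collection in batches.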
+ def update_db_with_multimodal():
+     """Migrate the DB to the multimodal collection"""
+     try:
+         # Get (or create) the multimodal collection
+         collection = setup_multimodal_collection()
+
+         # Fetch data from the existing collection
+         client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
+         old_collection = client.get_collection("clothes")
+         old_data = old_collection.get(include=['metadatas'])
+
+         total_items = len(old_data['metadatas'])
+         progress_bar = st.progress(0)
+         status_text = st.empty()
+
+         batch_size = 100
+         successful_updates = 0
+         failed_updates = 0
+
+         for i in range(0, total_items, batch_size):
+             batch = old_data['metadatas'][i:i + batch_size]
+
+             images = []
+             valid_metadatas = []
+             valid_ids = []
+
+             for metadata in batch:
+                 try:
+                     if 'image_url' in metadata:
+                         response = requests.get(metadata['image_url'])
+                         img = Image.open(io.BytesIO(response.content)).convert('RGB')
+                         images.append(np.array(img))
+                         valid_metadatas.append(metadata)
+                         valid_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
+                         successful_updates += 1
+                 except Exception as e:
+                     logger.error(f"Error processing image: {e}")
+                     failed_updates += 1
+                     continue
+
+             if images:
+                 collection.add(
+                     ids=valid_ids,
+                     images=images,
+                     metadatas=valid_metadatas
+                 )
+
+             # Update progress
+             progress = (i + len(batch)) / total_items
+             progress_bar.progress(progress)
+             status_text.text(f"Processing: {i + len(batch)}/{total_items} items. "
+                              f"Success: {successful_updates}, Failed: {failed_updates}")
+
+         status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
+                          f"Failed: {failed_updates}")
+         return True
+
+     except Exception as e:
+         logger.error(f"Multimodal DB update error: {e}")
+         return False
+
+ def show_similar_items(similar_items):
+     """Display similar items in a structured format with similarity scores"""
+     if not similar_items:
+         st.warning("No similar items found.")
+         return
+
+     st.subheader("Similar Items:")
+
+     items_per_row = 2
+     for i in range(0, len(similar_items), items_per_row):
+         cols = st.columns(items_per_row)
+         for j, col in enumerate(cols):
+             if i + j < len(similar_items):
+                 item = similar_items[i + j]
+                 with col:
+                     try:
+                         if 'image_url' in item:
+                             st.image(item['image_url'], use_column_width=True)
+
+                         st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")
+
+                         st.write(f"Brand: {item.get('brand', 'Unknown')}")
+                         name = item.get('name', 'Unknown')
+                         if len(name) > 50:
+                             name = name[:47] + "..."
+                         st.write(f"Name: {name}")
+
+                         price = item.get('price', 0)
+                         if isinstance(price, (int, float)):
+                             st.write(f"Price: {price:,}원")
+                         else:
+                             st.write(f"Price: {price}")
+
+                         if 'discount' in item and item['discount']:
+                             st.write(f"Discount: {item['discount']}%")
+                             if 'original_price' in item:
+                                 st.write(f"Original: {item['original_price']:,}원")
+
+                         st.divider()
+
+                     except Exception as e:
+                         logger.error(f"Error displaying item: {e}")
+                         st.error("Error displaying this item")
+
+ def process_search(image, mask, num_results):
+     """Run the similar-item search"""
+     try:
+         with st.spinner("Finding similar items..."):
+             similar_items = search_similar_items(image, mask, num_results)
+
+         return similar_items
+     except Exception as e:
+         logger.error(f"Search processing error: {e}")
+         return None
+
+ def handle_file_upload():
+     if st.session_state.uploaded_file is not None:
+         image = Image.open(st.session_state.uploaded_file).convert('RGB')
+         st.session_state.image = image
+         st.session_state.upload_state = 'image_uploaded'
+         st.rerun()
+
+ def handle_detection():
+     if st.session_state.image is not None:
+         detected_items = process_segmentation(st.session_state.image)
+         st.session_state.detected_items = detected_items
+         st.session_state.upload_state = 'items_detected'
+         st.rerun()
+
+ def handle_search():
+     st.session_state.search_clicked = True
+
+ def main():
+     st.title("Fashion Search App")
+
+     # Admin controls in sidebar
+     st.sidebar.title("Admin Controls")
+     if st.sidebar.checkbox("Show Admin Interface"):
+         if st.sidebar.button("Update Database (Multimodal)"):
+             with st.spinner("Updating database with multimodal support..."):
+                 success = update_db_with_multimodal()
+                 if success:
+                     st.sidebar.success("Database updated successfully!")
+                 else:
+                     st.sidebar.error("Failed to update database")
+         st.divider()
+
+     # File uploader
+     if st.session_state.upload_state == 'initial':
+         uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
+                                          key='uploaded_file', on_change=handle_file_upload)
+
+     # An image has been uploaded
+     if st.session_state.image is not None:
+         st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
+
+         if st.session_state.detected_items is None:
+             if st.button("Detect Items", key='detect_button', on_click=handle_detection):
+                 pass
+
+         # Display detected items and run the search
+         if st.session_state.detected_items is not None and len(st.session_state.detected_items) > 0:
+             cols = st.columns(2)
+             for idx, item in enumerate(st.session_state.detected_items):
+                 with cols[idx % 2]:
+                     try:
+                         if item.get('mask') is not None:
+                             masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
+                             st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")
+
+                         st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
+                         score = item.get('score')
+                         if score is not None and isinstance(score, (int, float)):
+                             st.write(f"Confidence: {score*100:.1f}%")
+                         else:
+                             st.write("Confidence: N/A")
+                     except Exception as e:
+                         logger.error(f"Error displaying item {idx}: {str(e)}")
+                         st.error(f"Error displaying item {idx}")
+
+             valid_items = [i for i in range(len(st.session_state.detected_items))
+                            if st.session_state.detected_items[i].get('mask') is not None]
+
+             if not valid_items:
+                 st.warning("No valid items detected for search.")
+                 return
+
+             selected_idx = st.selectbox(
+                 "Select item to search:",
+                 valid_items,
+                 format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
+                 key='item_selector'
+             )
+
+             search_col1, search_col2 = st.columns([1, 2])
+             with search_col1:
+                 search_clicked = st.button("Search Similar Items",
+                                            key='search_button',
+                                            type="primary")
+             with search_col2:
+                 num_results = st.slider("Number of results:",
+                                         min_value=1,
+                                         max_value=20,
+                                         value=5,
+                                         key='num_results')
+
+             if search_clicked or st.session_state.get('search_clicked', False):
+                 st.session_state.search_clicked = True
+                 selected_item = st.session_state.detected_items[selected_idx]
+
+                 if selected_item.get('mask') is None:
+                     st.error("Selected item has no valid mask for search.")
+                     return
+
+                 if 'search_results' not in st.session_state:
+                     similar_items = process_search(st.session_state.image,
+                                                    selected_item['mask'],
+                                                    num_results)
+                     st.session_state.search_results = similar_items
+
+                 if st.session_state.search_results:
+                     show_similar_items(st.session_state.search_results)
+                 else:
+                     st.warning("No similar items found.")
+
+         # New search button
+         if st.button("Start New Search", key='new_search'):
+             for key in list(st.session_state.keys()):
+                 del st.session_state[key]
+             st.rerun()
+
+ if __name__ == "__main__":
+     print('Start')
+     main()