leedoming commited on
Commit
55609ac
Β·
verified Β·
1 Parent(s): 5815464

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +517 -0
app.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ import chromadb
8
+ import logging
9
+ import io
10
+ import requests
11
+ from concurrent.futures import ThreadPoolExecutor
12
+
13
+ # λ‘œκΉ… μ„€μ •
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Initialize session state
18
+ if 'image' not in st.session_state:
19
+ st.session_state.image = None
20
+ if 'detected_items' not in st.session_state:
21
+ st.session_state.detected_items = None
22
+ if 'selected_item_index' not in st.session_state:
23
+ st.session_state.selected_item_index = None
24
+ if 'upload_state' not in st.session_state:
25
+ st.session_state.upload_state = 'initial'
26
+ if 'search_clicked' not in st.session_state:
27
+ st.session_state.search_clicked = False
28
+
29
+ # Load models
30
+ @st.cache_resource
31
+ def load_models():
32
+ try:
33
+ # CLIP λͺ¨λΈ
34
+ model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
35
+
36
+ # μ„Έκ·Έλ©˜ν…Œμ΄μ…˜ λͺ¨λΈ
37
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ model.to(device)
41
+
42
+ return model, preprocess_val, segmenter, device
43
+ except Exception as e:
44
+ logger.error(f"Error loading models: {e}")
45
+ raise
46
+
47
+ # λͺ¨λΈ λ‘œλ“œ
48
+ clip_model, preprocess_val, segmenter, device = load_models()
49
+
50
+ # ChromaDB μ„€μ •
51
+ client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
52
+ collection = client.get_collection(name="clothes")
53
+
54
+ def extract_color_histogram(image, mask=None):
55
+ """Extract color histogram from the image, considering the mask if provided"""
56
+ try:
57
+ img_array = np.array(image)
58
+ if mask is not None:
59
+ # Apply mask
60
+ mask = np.expand_dims(mask, axis=2)
61
+ img_array = img_array * mask
62
+ # Only consider pixels that are part of the clothing item
63
+ valid_pixels = img_array[mask[:,:,0] > 0]
64
+ else:
65
+ valid_pixels = img_array.reshape(-1, 3)
66
+
67
+ # Convert to HSV color space for better color representation
68
+ if len(valid_pixels) > 0:
69
+ img_hsv = Image.fromarray(valid_pixels.reshape(1, -1, 3).astype(np.uint8)).convert('HSV')
70
+ hsv_pixels = np.array(img_hsv)
71
+
72
+ # Calculate histogram for each HSV channel
73
+ h_hist = np.histogram(hsv_pixels[:,:,0], bins=10, range=(0, 256))[0]
74
+ s_hist = np.histogram(hsv_pixels[:,:,1], bins=10, range=(0, 256))[0]
75
+ v_hist = np.histogram(hsv_pixels[:,:,2], bins=10, range=(0, 256))[0]
76
+
77
+ # Normalize histograms
78
+ h_hist = h_hist / h_hist.sum() if h_hist.sum() > 0 else h_hist
79
+ s_hist = s_hist / s_hist.sum() if s_hist.sum() > 0 else s_hist
80
+ v_hist = v_hist / v_hist.sum() if v_hist.sum() > 0 else v_hist
81
+
82
+ return np.concatenate([h_hist, s_hist, v_hist])
83
+ return np.zeros(30) # Return zero histogram if no valid pixels
84
+ except Exception as e:
85
+ logger.error(f"Color histogram extraction error: {e}")
86
+ return np.zeros(30)
87
+
88
+ def process_segmentation(image):
89
+ """Segmentation processing"""
90
+ try:
91
+ # pipeline 좜λ ₯ κ²°κ³Ό 직접 처리
92
+ output = segmenter(image)
93
+
94
+ if not output or len(output) == 0:
95
+ logger.warning("No segments found in image")
96
+ return []
97
+
98
+ processed_items = []
99
+ for segment in output:
100
+ # 기본값을 ν¬ν•¨ν•˜μ—¬ λ”•μ…”λ„ˆλ¦¬ 생성
101
+ processed_segment = {
102
+ 'label': segment.get('label', 'Unknown'),
103
+ 'score': segment.get('score', 1.0), # scoreκ°€ μ—†μœΌλ©΄ 1.0을 κΈ°λ³Έκ°’μœΌλ‘œ μ‚¬μš©
104
+ 'mask': None
105
+ }
106
+
107
+ mask = segment.get('mask')
108
+ if mask is not None:
109
+ # λ§ˆμŠ€ν¬κ°€ numpy arrayκ°€ μ•„λ‹Œ 경우 λ³€ν™˜
110
+ if not isinstance(mask, np.ndarray):
111
+ mask = np.array(mask)
112
+
113
+ # λ§ˆμŠ€ν¬κ°€ 2Dκ°€ μ•„λ‹Œ 경우 첫 번째 채널 μ‚¬μš©
114
+ if len(mask.shape) > 2:
115
+ mask = mask[:, :, 0]
116
+
117
+ # bool 마슀크λ₯Ό float둜 λ³€ν™˜
118
+ processed_segment['mask'] = mask.astype(float)
119
+ else:
120
+ logger.warning(f"No mask found for segment with label {processed_segment['label']}")
121
+ continue # λ§ˆμŠ€ν¬κ°€ μ—†λŠ” μ„Έκ·Έλ¨ΌνŠΈλŠ” κ±΄λ„ˆλœ€
122
+
123
+ processed_items.append(processed_segment)
124
+
125
+ logger.info(f"Successfully processed {len(processed_items)} segments")
126
+ return processed_items
127
+
128
+ except Exception as e:
129
+ logger.error(f"Segmentation error: {str(e)}")
130
+ import traceback
131
+ logger.error(traceback.format_exc())
132
+ return []
133
+
134
+ def extract_features(image, mask=None):
135
+ """Extract both CLIP features and color features with segmentation mask"""
136
+ try:
137
+ # Extract CLIP features
138
+ if mask is not None:
139
+ img_array = np.array(image)
140
+ mask = np.expand_dims(mask, axis=2)
141
+ masked_img = img_array * mask
142
+ masked_img[mask[:,:,0] == 0] = 255 # Set background to white
143
+ image = Image.fromarray(masked_img.astype(np.uint8))
144
+
145
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
146
+ with torch.no_grad():
147
+ clip_features = clip_model.encode_image(image_tensor)
148
+ clip_features /= clip_features.norm(dim=-1, keepdim=True)
149
+ clip_features = clip_features.cpu().numpy().flatten()
150
+
151
+ # Extract color features
152
+ color_features = extract_color_histogram(image, mask)
153
+
154
+ # Combine features
155
+ # Note: We normalize and weight the features to balance their influence
156
+ clip_features_normalized = clip_features / np.linalg.norm(clip_features)
157
+ color_features_normalized = color_features / np.linalg.norm(color_features)
158
+
159
+ # Adjust these weights to control the influence of each feature type
160
+ clip_weight = 0.7 # CLIP features weight
161
+ color_weight = 0.3 # Color features weight
162
+
163
+ combined_features = np.concatenate([
164
+ clip_features_normalized * clip_weight,
165
+ color_features_normalized * color_weight
166
+ ])
167
+
168
+ return combined_features
169
+ except Exception as e:
170
+ logger.error(f"Feature extraction error: {e}")
171
+ raise
172
+
173
+ def download_and_process_image(image_url, metadata_id):
174
+ """Download image from URL and apply segmentation"""
175
+ try:
176
+ response = requests.get(image_url, timeout=10)
177
+ if response.status_code != 200:
178
+ logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
179
+ return None
180
+
181
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
182
+ logger.info(f"Successfully downloaded image {metadata_id}")
183
+
184
+ processed_items = process_segmentation(image)
185
+ if processed_items and len(processed_items) > 0:
186
+ # κ°€μž₯ 큰 μ„Έκ·Έλ¨ΌνŠΈμ˜ 마슀크 μ‚¬μš©
187
+ largest_mask = max(processed_items, key=lambda x: np.sum(x['mask']))['mask']
188
+ features = extract_features(image, largest_mask)
189
+ logger.info(f"Successfully extracted features for image {metadata_id}")
190
+ return features
191
+
192
+ logger.warning(f"No valid mask found for image {metadata_id}")
193
+ return None
194
+
195
+ except Exception as e:
196
+ logger.error(f"Error processing image {metadata_id}: {str(e)}")
197
+ import traceback
198
+ logger.error(traceback.format_exc())
199
+ return None
200
+
201
+ def update_db_with_segmentation():
202
+ """DB의 λͺ¨λ“  이미지에 λŒ€ν•΄ segmentation을 μ μš©ν•˜κ³  featureλ₯Ό μ—…λ°μ΄νŠΈ"""
203
+ try:
204
+ logger.info("Starting database update with segmentation and color features")
205
+
206
+ # μƒˆλ‘œμš΄ collection 생성
207
+ try:
208
+ client.delete_collection("clothes_segmented")
209
+ logger.info("Deleted existing segmented collection")
210
+ except:
211
+ logger.info("No existing segmented collection to delete")
212
+
213
+ new_collection = client.create_collection(
214
+ name="clothes_segmented",
215
+ metadata={"description": "Clothes collection with segmentation and color features"}
216
+ )
217
+ logger.info("Created new segmented collection")
218
+
219
+ # κΈ°μ‘΄ collectionμ—μ„œ λ©”νƒ€λ°μ΄ν„°λ§Œ κ°€μ Έμ˜€κΈ°
220
+ try:
221
+ all_items = collection.get(include=['metadatas'])
222
+ total_items = len(all_items['metadatas'])
223
+ logger.info(f"Found {total_items} items in database")
224
+ except Exception as e:
225
+ logger.error(f"Error getting items from collection: {str(e)}")
226
+ all_items = {'metadatas': []}
227
+ total_items = 0
228
+
229
+ # μ§„ν–‰ 상황 ν‘œμ‹œλ₯Ό μœ„ν•œ progress bar
230
+ progress_bar = st.progress(0)
231
+ status_text = st.empty()
232
+
233
+ successful_updates = 0
234
+ failed_updates = 0
235
+
236
+ with ThreadPoolExecutor(max_workers=4) as executor:
237
+ futures = []
238
+ # 이미지 URL이 μžˆλŠ” ν•­λͺ©λ§Œ 처리
239
+ valid_items = [m for m in all_items['metadatas'] if 'image_url' in m]
240
+
241
+ for metadata in valid_items:
242
+ future = executor.submit(
243
+ download_and_process_image,
244
+ metadata['image_url'],
245
+ metadata.get('id', 'unknown')
246
+ )
247
+ futures.append((metadata, future))
248
+
249
+ # κ²°κ³Ό 처리 및 μƒˆ DB에 μ €μž₯
250
+ for idx, (metadata, future) in enumerate(futures):
251
+ try:
252
+ new_features = future.result()
253
+ if new_features is not None:
254
+ item_id = metadata.get('id', str(hash(metadata['image_url'])))
255
+ try:
256
+ new_collection.add(
257
+ embeddings=[new_features.tolist()],
258
+ metadatas=[metadata],
259
+ ids=[item_id]
260
+ )
261
+ successful_updates += 1
262
+ logger.info(f"Successfully added item {item_id}")
263
+ except Exception as e:
264
+ logger.error(f"Error adding item to new collection: {str(e)}")
265
+ failed_updates += 1
266
+ else:
267
+ failed_updates += 1
268
+
269
+ # μ§„ν–‰ 상황 μ—…λ°μ΄νŠΈ
270
+ progress = (idx + 1) / len(futures)
271
+ progress_bar.progress(progress)
272
+ status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")
273
+
274
+ except Exception as e:
275
+ logger.error(f"Error processing item: {str(e)}")
276
+ failed_updates += 1
277
+ continue
278
+
279
+ # μ΅œμ’… κ²°κ³Ό ν‘œμ‹œ
280
+ status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
281
+ logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")
282
+
283
+ # μ„±κ³΅μ μœΌλ‘œ 처리된 ν•­λͺ©μ΄ μžˆλŠ”μ§€ 확인
284
+ if successful_updates > 0:
285
+ return True
286
+ else:
287
+ logger.error("No items were successfully processed")
288
+ return False
289
+
290
+ except Exception as e:
291
+ logger.error(f"Database update error: {str(e)}")
292
+ import traceback
293
+ logger.error(traceback.format_exc())
294
+ return False
295
+
296
+ def search_similar_items(features, top_k=10):
297
+ """Search similar items using combined features"""
298
+ try:
299
+ # μ„Έκ·Έλ©˜ν…Œμ΄μ…˜μ΄ 적용된 collection이 μžˆλŠ”μ§€ 확인
300
+ try:
301
+ search_collection = client.get_collection("clothes_segmented")
302
+ logger.info("Using segmented collection for search")
303
+ except:
304
+ # μ—†μœΌλ©΄ κΈ°μ‘΄ collection μ‚¬μš©
305
+ search_collection = collection
306
+ logger.info("Using original collection for search")
307
+
308
+ results = search_collection.query(
309
+ query_embeddings=[features.tolist()],
310
+ n_results=top_k,
311
+ include=['metadatas', 'distances']
312
+ )
313
+
314
+ if not results or not results['metadatas'] or not results['distances']:
315
+ logger.warning("No results returned from ChromaDB")
316
+ return []
317
+
318
+ similar_items = []
319
+ for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
320
+ try:
321
+ similarity_score = 1 / (1 + float(distance))
322
+ item_data = metadata.copy()
323
+ item_data['similarity_score'] = similarity_score
324
+ similar_items.append(item_data)
325
+ except Exception as e:
326
+ logger.error(f"Error processing search result: {str(e)}")
327
+ continue
328
+
329
+ similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
330
+ return similar_items
331
+ except Exception as e:
332
+ logger.error(f"Search error: {str(e)}")
333
+ return []
334
+
335
+ def show_similar_items(similar_items):
336
+ """Display similar items in a structured format with similarity scores"""
337
+ if not similar_items:
338
+ st.warning("No similar items found.")
339
+ return
340
+
341
+ st.subheader("Similar Items:")
342
+
343
+ # κ²°κ³Όλ₯Ό 2μ—΄λ‘œ ν‘œμ‹œ
344
+ items_per_row = 2
345
+ for i in range(0, len(similar_items), items_per_row):
346
+ cols = st.columns(items_per_row)
347
+ for j, col in enumerate(cols):
348
+ if i + j < len(similar_items):
349
+ item = similar_items[i + j]
350
+ with col:
351
+ try:
352
+ if 'image_url' in item:
353
+ st.image(item['image_url'], use_column_width=True)
354
+
355
+ # μœ μ‚¬λ„ 점수λ₯Ό νΌμ„ΌνŠΈλ‘œ ν‘œμ‹œ
356
+ similarity_percent = item['similarity_score'] * 100
357
+ st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
358
+
359
+ st.write(f"Brand: {item.get('brand', 'Unknown')}")
360
+ name = item.get('name', 'Unknown')
361
+ if len(name) > 50: # κΈ΄ 이름은 μ€„μž„
362
+ name = name[:47] + "..."
363
+ st.write(f"Name: {name}")
364
+
365
+ # 가격 정보 ν‘œμ‹œ
366
+ price = item.get('price', 0)
367
+ if isinstance(price, (int, float)):
368
+ st.write(f"Price: {price:,}원")
369
+ else:
370
+ st.write(f"Price: {price}")
371
+
372
+ # 할인 정보가 μžˆλŠ” 경우
373
+ if 'discount' in item and item['discount']:
374
+ st.write(f"Discount: {item['discount']}%")
375
+ if 'original_price' in item:
376
+ st.write(f"Original: {item['original_price']:,}원")
377
+
378
+ st.divider() # ꡬ뢄선 μΆ”κ°€
379
+
380
+ except Exception as e:
381
+ logger.error(f"Error displaying item: {e}")
382
+ st.error("Error displaying this item")
383
+
384
+ def process_search(image, mask, num_results):
385
+ """μœ μ‚¬ μ•„μ΄ν…œ 검색 처리"""
386
+ try:
387
+ with st.spinner("Extracting features..."):
388
+ features = extract_features(image, mask)
389
+
390
+ with st.spinner("Finding similar items..."):
391
+ similar_items = search_similar_items(features, top_k=num_results)
392
+
393
+ return similar_items
394
+ except Exception as e:
395
+ logger.error(f"Search processing error: {e}")
396
+ return None
397
+
398
+ def handle_file_upload():
399
+ if st.session_state.uploaded_file is not None:
400
+ image = Image.open(st.session_state.uploaded_file).convert('RGB')
401
+ st.session_state.image = image
402
+ st.session_state.upload_state = 'image_uploaded'
403
+ st.rerun()
404
+
405
+ def handle_detection():
406
+ if st.session_state.image is not None:
407
+ detected_items = process_segmentation(st.session_state.image)
408
+ st.session_state.detected_items = detected_items
409
+ st.session_state.upload_state = 'items_detected'
410
+ st.rerun()
411
+
412
+ def handle_search():
413
+ st.session_state.search_clicked = True
414
+
415
+ def main():
416
+ st.title("Fashion Search App")
417
+
418
+ # Admin controls in sidebar
419
+ st.sidebar.title("Admin Controls")
420
+ if st.sidebar.checkbox("Show Admin Interface"):
421
+ # Admin interface κ΅¬ν˜„ (ν•„μš”ν•œ 경우)
422
+ st.sidebar.warning("Admin interface is not implemented yet.")
423
+ st.divider()
424
+
425
+ # 파일 μ—…λ‘œλ”
426
+ if st.session_state.upload_state == 'initial':
427
+ uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
428
+ key='uploaded_file', on_change=handle_file_upload)
429
+
430
+ # 이미지가 μ—…λ‘œλ“œλœ μƒνƒœ
431
+ if st.session_state.image is not None:
432
+ st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
433
+
434
+ if st.session_state.detected_items is None:
435
+ if st.button("Detect Items", key='detect_button', on_click=handle_detection):
436
+ pass
437
+
438
+ # κ²€μΆœλœ μ•„μ΄ν…œ ν‘œμ‹œ
439
+ if st.session_state.detected_items is not None and len(st.session_state.detected_items) > 0:
440
+ # κ°μ§€λœ μ•„μ΄ν…œλ“€μ„ 2μ—΄λ‘œ ν‘œμ‹œ
441
+ cols = st.columns(2)
442
+ for idx, item in enumerate(st.session_state.detected_items):
443
+ with cols[idx % 2]:
444
+ try:
445
+ if item.get('mask') is not None:
446
+ masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
447
+ st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")
448
+
449
+ st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
450
+
451
+ # score 값이 있고 숫자인 κ²½μš°μ—λ§Œ ν‘œμ‹œ
452
+ score = item.get('score')
453
+ if score is not None and isinstance(score, (int, float)):
454
+ st.write(f"Confidence: {score*100:.1f}%")
455
+ else:
456
+ st.write("Confidence: N/A")
457
+ except Exception as e:
458
+ logger.error(f"Error displaying item {idx}: {str(e)}")
459
+ st.error(f"Error displaying item {idx}")
460
+
461
+ valid_items = [i for i in range(len(st.session_state.detected_items))
462
+ if st.session_state.detected_items[i].get('mask') is not None]
463
+
464
+ if not valid_items:
465
+ st.warning("No valid items detected for search.")
466
+ return
467
+
468
+ # μ•„μ΄ν…œ 선택
469
+ selected_idx = st.selectbox(
470
+ "Select item to search:",
471
+ valid_items,
472
+ format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
473
+ key='item_selector'
474
+ )
475
+
476
+ # 검색 컨트둀
477
+ search_col1, search_col2 = st.columns([1, 2])
478
+ with search_col1:
479
+ search_clicked = st.button("Search Similar Items",
480
+ key='search_button',
481
+ type="primary")
482
+ with search_col2:
483
+ num_results = st.slider("Number of results:",
484
+ min_value=1,
485
+ max_value=20,
486
+ value=5,
487
+ key='num_results')
488
+
489
+ # 검색 κ²°κ³Ό 처리
490
+ if search_clicked or st.session_state.get('search_clicked', False):
491
+ st.session_state.search_clicked = True
492
+ selected_item = st.session_state.detected_items[selected_idx]
493
+
494
+ if selected_item.get('mask') is None:
495
+ st.error("Selected item has no valid mask for search.")
496
+ return
497
+
498
+ # 검색 κ²°κ³Όλ₯Ό μ„Έμ…˜ μƒνƒœμ— μ €μž₯
499
+ if 'search_results' not in st.session_state:
500
+ similar_items = process_search(st.session_state.image, selected_item['mask'], num_results)
501
+ st.session_state.search_results = similar_items
502
+
503
+ # μ €μž₯된 검색 κ²°κ³Ό ν‘œμ‹œ
504
+ if st.session_state.search_results:
505
+ show_similar_items(st.session_state.search_results)
506
+ else:
507
+ st.warning("No similar items found.")
508
+
509
+ # μƒˆ 검색 λ²„νŠΌ
510
+ if st.button("Start New Search", key='new_search'):
511
+ # λͺ¨λ“  μƒνƒœ μ΄ˆκΈ°ν™”
512
+ for key in list(st.session_state.keys()):
513
+ del st.session_state[key]
514
+ st.rerun()
515
+
516
+ if __name__ == "__main__":
517
+ main()