Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -63,23 +63,30 @@ def process_segmentation(image):
|
|
| 63 |
|
| 64 |
processed_items = []
|
| 65 |
for segment in output:
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
processed_items.append(
|
| 79 |
-
'mask': mask,
|
| 80 |
-
'label': segment.get('label', 'Unknown'),
|
| 81 |
-
'score': segment.get('score', 0.0)
|
| 82 |
-
})
|
| 83 |
|
| 84 |
logger.info(f"Successfully processed {len(processed_items)} segments")
|
| 85 |
return processed_items
|
|
@@ -380,17 +387,35 @@ def main():
|
|
| 380 |
cols = st.columns(2)
|
| 381 |
for idx, item in enumerate(st.session_state.detected_items):
|
| 382 |
with cols[idx % 2]:
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
# ์์ดํ
์ ํ
|
| 390 |
selected_idx = st.selectbox(
|
| 391 |
"Select item to search:",
|
| 392 |
-
|
| 393 |
-
format_func=lambda i: f"{st.session_state.detected_items[i]
|
| 394 |
key='item_selector'
|
| 395 |
)
|
| 396 |
|
|
@@ -410,11 +435,15 @@ def main():
|
|
| 410 |
# ๊ฒ์ ๊ฒฐ๊ณผ ์ฒ๋ฆฌ
|
| 411 |
if search_clicked or st.session_state.get('search_clicked', False):
|
| 412 |
st.session_state.search_clicked = True
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
# ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ธ์
์ํ์ ์ ์ฅ
|
| 416 |
if 'search_results' not in st.session_state:
|
| 417 |
-
similar_items = process_search(st.session_state.image,
|
| 418 |
st.session_state.search_results = similar_items
|
| 419 |
|
| 420 |
# ์ ์ฅ๋ ๊ฒ์ ๊ฒฐ๊ณผ ํ์
|
|
|
|
| 63 |
|
| 64 |
processed_items = []
|
| 65 |
for segment in output:
|
| 66 |
+
# ๊ธฐ๋ณธ๊ฐ์ ํฌํจํ์ฌ ๋์
๋๋ฆฌ ์์ฑ
|
| 67 |
+
processed_segment = {
|
| 68 |
+
'label': segment.get('label', 'Unknown'),
|
| 69 |
+
'score': segment.get('score', 1.0), # score๊ฐ ์์ผ๋ฉด 1.0์ ๊ธฐ๋ณธ๊ฐ์ผ๋ก ์ฌ์ฉ
|
| 70 |
+
'mask': None
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
mask = segment.get('mask')
|
| 74 |
+
if mask is not None:
|
| 75 |
+
# ๋ง์คํฌ๊ฐ numpy array๊ฐ ์๋ ๊ฒฝ์ฐ ๋ณํ
|
| 76 |
+
if not isinstance(mask, np.ndarray):
|
| 77 |
+
mask = np.array(mask)
|
| 78 |
|
| 79 |
+
# ๋ง์คํฌ๊ฐ 2D๊ฐ ์๋ ๊ฒฝ์ฐ ์ฒซ ๋ฒ์งธ ์ฑ๋ ์ฌ์ฉ
|
| 80 |
+
if len(mask.shape) > 2:
|
| 81 |
+
mask = mask[:, :, 0]
|
| 82 |
|
| 83 |
+
# bool ๋ง์คํฌ๋ฅผ float๋ก ๋ณํ
|
| 84 |
+
processed_segment['mask'] = mask.astype(float)
|
| 85 |
+
else:
|
| 86 |
+
logger.warning(f"No mask found for segment with label {processed_segment['label']}")
|
| 87 |
+
continue # ๋ง์คํฌ๊ฐ ์๋ ์ธ๊ทธ๋จผํธ๋ ๊ฑด๋๋
|
| 88 |
|
| 89 |
+
processed_items.append(processed_segment)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
logger.info(f"Successfully processed {len(processed_items)} segments")
|
| 92 |
return processed_items
|
|
|
|
| 387 |
cols = st.columns(2)
|
| 388 |
for idx, item in enumerate(st.session_state.detected_items):
|
| 389 |
with cols[idx % 2]:
|
| 390 |
+
try:
|
| 391 |
+
if item.get('mask') is not None:
|
| 392 |
+
masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
|
| 393 |
+
st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")
|
| 394 |
+
|
| 395 |
+
st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
|
| 396 |
+
|
| 397 |
+
# score ๊ฐ์ด ์๊ณ ์ซ์์ธ ๊ฒฝ์ฐ์๋ง ํ์
|
| 398 |
+
score = item.get('score')
|
| 399 |
+
if score is not None and isinstance(score, (int, float)):
|
| 400 |
+
st.write(f"Confidence: {score*100:.1f}%")
|
| 401 |
+
else:
|
| 402 |
+
st.write("Confidence: N/A")
|
| 403 |
+
except Exception as e:
|
| 404 |
+
logger.error(f"Error displaying item {idx}: {str(e)}")
|
| 405 |
+
st.error(f"Error displaying item {idx}")
|
| 406 |
+
|
| 407 |
+
valid_items = [i for i in range(len(st.session_state.detected_items))
|
| 408 |
+
if st.session_state.detected_items[i].get('mask') is not None]
|
| 409 |
|
| 410 |
+
if not valid_items:
|
| 411 |
+
st.warning("No valid items detected for search.")
|
| 412 |
+
return
|
| 413 |
+
|
| 414 |
# ์์ดํ
์ ํ
|
| 415 |
selected_idx = st.selectbox(
|
| 416 |
"Select item to search:",
|
| 417 |
+
valid_items,
|
| 418 |
+
format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
|
| 419 |
key='item_selector'
|
| 420 |
)
|
| 421 |
|
|
|
|
| 435 |
# ๊ฒ์ ๊ฒฐ๊ณผ ์ฒ๋ฆฌ
|
| 436 |
if search_clicked or st.session_state.get('search_clicked', False):
|
| 437 |
st.session_state.search_clicked = True
|
| 438 |
+
selected_item = st.session_state.detected_items[selected_idx]
|
| 439 |
+
|
| 440 |
+
if selected_item.get('mask') is None:
|
| 441 |
+
st.error("Selected item has no valid mask for search.")
|
| 442 |
+
return
|
| 443 |
|
| 444 |
# ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ธ์
์ํ์ ์ ์ฅ
|
| 445 |
if 'search_results' not in st.session_state:
|
| 446 |
+
similar_items = process_search(st.session_state.image, selected_item['mask'], num_results)
|
| 447 |
st.session_state.search_results = similar_items
|
| 448 |
|
| 449 |
# ์ ์ฅ๋ ๊ฒ์ ๊ฒฐ๊ณผ ํ์
|