leedoming commited on
Commit
8b3d9ea
ยท
verified ยท
1 Parent(s): dca0895

Rename app-yolo.py to app.py

Browse files
Files changed (1) hide show
  1. app-yolo.py โ†’ app.py +85 -45
app-yolo.py โ†’ app.py RENAMED
@@ -29,6 +29,15 @@ def load_yolo_model():
29
 
30
  yolo_model = load_yolo_model()
31
 
 
 
 
 
 
 
 
 
 
32
  # Helper functions
33
  def load_image_from_url(url, max_retries=3):
34
  for attempt in range(max_retries):
@@ -42,9 +51,6 @@ def load_image_from_url(url, max_retries=3):
42
  time.sleep(1)
43
  else:
44
  return None
45
- #Load chromaDB
46
- client = chromadb.PersistentClient(path="./clothesDB")
47
- collection = client.get_collection(name="clothes_items_ver3")
48
 
49
  def get_image_embedding(image):
50
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
@@ -60,32 +66,56 @@ def get_text_embedding(text):
60
  text_features /= text_features.norm(dim=-1, keepdim=True)
61
  return text_features.cpu().numpy()
62
 
63
- def get_all_embeddings_from_collection(collection):
64
- all_embeddings = collection.get(include=['embeddings'])['embeddings']
65
- return np.array(all_embeddings)
66
-
67
- def get_metadata_from_ids(collection, ids):
68
- results = collection.get(ids=ids)
69
- return results['metadatas']
70
 
71
- def find_similar_images(query_embedding, collection, top_k=5):
72
- database_embeddings = get_all_embeddings_from_collection(collection)
73
- similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
74
- top_indices = np.argsort(similarities)[::-1][:top_k]
75
 
76
- all_data = collection.get(include=['metadatas'])['metadatas']
 
 
 
 
77
 
78
- top_metadatas = [all_data[idx] for idx in top_indices]
 
 
 
 
 
 
 
 
 
 
79
 
80
- results = []
81
- for idx, metadata in enumerate(top_metadatas):
82
- results.append({
83
  'info': metadata,
84
- 'similarity': similarities[top_indices[idx]]
85
  })
86
- return results
87
-
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def detect_clothing(image):
91
  results = yolo_model(image)
@@ -118,6 +148,12 @@ if 'selected_category' not in st.session_state:
118
  # Streamlit app
119
  st.title("Advanced Fashion Search App")
120
 
 
 
 
 
 
 
121
  # ๋‹จ๊ณ„๋ณ„ ์ฒ˜๋ฆฌ
122
  if st.session_state.step == 'input':
123
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
@@ -136,7 +172,6 @@ if st.session_state.step == 'input':
136
  else:
137
  st.warning("Please enter an image URL.")
138
 
139
- # Update the 'select_category' step
140
  elif st.session_state.step == 'select_category':
141
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
142
  st.subheader("Detected Clothing Items:")
@@ -179,6 +214,13 @@ elif st.session_state.step == 'show_results':
179
  st.write(f"Price: {img['info']['price']}")
180
  st.write(f"Discount: {img['info']['discount']}%")
181
  st.write(f"Similarity: {img['similarity']:.2f}")
 
 
 
 
 
 
 
182
 
183
  if st.button("Start New Search"):
184
  st.session_state.step = 'input'
@@ -186,25 +228,23 @@ elif st.session_state.step == 'show_results':
186
  st.session_state.detections = []
187
  st.session_state.selected_category = None
188
 
189
- else: # Text search
190
- query_text = st.text_input("Enter search text:")
191
- if st.button("Search by Text"):
192
- if query_text:
193
- text_embedding = get_text_embedding(query_text)
194
- similar_images = find_similar_images(text_embedding, collection)
195
- st.subheader("Similar Items:")
196
- for img in similar_images:
197
- col1, col2 = st.columns(2)
198
- with col1:
199
- st.image(img['info']['image_url'], use_column_width=True)
200
- with col2:
201
- st.write(f"Name: {img['info']['name']}")
202
- st.write(f"Brand: {img['info']['brand']}")
203
- category = img['info'].get('category')
204
- if category:
205
- st.write(f"Category: {category}")
206
- st.write(f"Price: {img['info']['price']}")
207
- st.write(f"Discount: {img['info']['discount']}%")
208
- st.write(f"Similarity: {img['similarity']:.2f}")
209
- else:
210
- st.warning("Please enter a search text.")
 
29
 
30
  yolo_model = load_yolo_model()
31
 
32
+ # Load ChromaDB
33
+ @st.cache_resource
34
+ def load_chromadb():
35
+ client = chromadb.PersistentClient(path="./chromadb_new")
36
+ collection = client.get_collection(name="clothes_items_musinsa_ver1")
37
+ return collection
38
+
39
+ collection = load_chromadb()
40
+
41
  # Helper functions
42
  def load_image_from_url(url, max_retries=3):
43
  for attempt in range(max_retries):
 
51
  time.sleep(1)
52
  else:
53
  return None
 
 
 
54
 
55
  def get_image_embedding(image):
56
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
 
66
  text_features /= text_features.norm(dim=-1, keepdim=True)
67
  return text_features.cpu().numpy()
68
 
69
+ def get_average_embedding(main_image_url, additional_image_urls):
70
+ embeddings = []
 
 
 
 
 
71
 
72
+ # ๋ฉ”์ธ ์ด๋ฏธ์ง€ ์ž„๋ฒ ๋”ฉ
73
+ main_image = load_image_from_url(main_image_url)
74
+ if main_image:
75
+ embeddings.append(get_image_embedding(main_image))
76
 
77
+ # ์ถ”๊ฐ€ ์ด๋ฏธ์ง€ ์ž„๋ฒ ๋”ฉ
78
+ for url in additional_image_urls:
79
+ img = load_image_from_url(url)
80
+ if img:
81
+ embeddings.append(get_image_embedding(img))
82
 
83
+ if embeddings:
84
+ return np.mean(embeddings, axis=0)
85
+ else:
86
+ return None
87
+
88
+ def find_similar_images(query_embedding, collection, top_k=5):
89
+ results = collection.query(
90
+ query_embeddings=[query_embedding.squeeze().tolist()],
91
+ n_results=top_k,
92
+ include=["metadatas", "distances"]
93
+ )
94
 
95
+ similar_items = []
96
+ for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
97
+ similar_items.append({
98
  'info': metadata,
99
+ 'similarity': 1 - distance # ๊ฑฐ๋ฆฌ๋ฅผ ์œ ์‚ฌ๋„๋กœ ๋ณ€ํ™˜
100
  })
101
+
102
+ return similar_items
103
 
104
+ def update_collection_embeddings():
105
+ all_ids = collection.get(include=['metadatas'])['ids']
106
+ all_metadata = collection.get(include=['metadatas'])['metadatas']
107
+
108
+ for id, metadata in zip(all_ids, all_metadata):
109
+ main_image_url = metadata['image_url']
110
+ additional_image_urls = metadata.get('additional_images', [])
111
+
112
+ avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
113
+
114
+ if avg_embedding is not None:
115
+ collection.update(
116
+ ids=[id],
117
+ embeddings=[avg_embedding.tolist()]
118
+ )
119
 
120
  def detect_clothing(image):
121
  results = yolo_model(image)
 
148
  # Streamlit app
149
  st.title("Advanced Fashion Search App")
150
 
151
+ # ์ปฌ๋ ‰์…˜ ์ž„๋ฒ ๋”ฉ ์—…๋ฐ์ดํŠธ (์ฒซ ์‹คํ–‰ ์‹œ ํ•œ ๋ฒˆ๋งŒ)
152
+ if 'embeddings_updated' not in st.session_state:
153
+ with st.spinner("Updating collection embeddings... This may take a while."):
154
+ update_collection_embeddings()
155
+ st.session_state.embeddings_updated = True
156
+
157
  # ๋‹จ๊ณ„๋ณ„ ์ฒ˜๋ฆฌ
158
  if st.session_state.step == 'input':
159
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
 
172
  else:
173
  st.warning("Please enter an image URL.")
174
 
 
175
  elif st.session_state.step == 'select_category':
176
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
177
  st.subheader("Detected Clothing Items:")
 
214
  st.write(f"Price: {img['info']['price']}")
215
  st.write(f"Discount: {img['info']['discount']}%")
216
  st.write(f"Similarity: {img['similarity']:.2f}")
217
+
218
+ # ์ถ”๊ฐ€ ์ด๋ฏธ์ง€ ํ‘œ์‹œ
219
+ additional_images = img['info'].get('additional_images', [])
220
+ if additional_images:
221
+ st.write("Additional Images:")
222
+ for add_img_url in additional_images[:3]: # ์ตœ๋Œ€ 3๊ฐœ๊นŒ์ง€๋งŒ ํ‘œ์‹œ
223
+ st.image(add_img_url, width=100)
224
 
225
  if st.button("Start New Search"):
226
  st.session_state.step = 'input'
 
228
  st.session_state.detections = []
229
  st.session_state.selected_category = None
230
 
231
+ # Text search
232
+ st.sidebar.title("Text Search")
233
+ query_text = st.sidebar.text_input("Enter search text:")
234
+ if st.sidebar.button("Search by Text"):
235
+ if query_text:
236
+ text_embedding = get_text_embedding(query_text)
237
+ similar_images = find_similar_images(text_embedding, collection)
238
+ st.sidebar.subheader("Similar Items:")
239
+ for img in similar_images:
240
+ st.sidebar.image(img['info']['image_url'], use_column_width=True)
241
+ st.sidebar.write(f"Name: {img['info']['name']}")
242
+ st.sidebar.write(f"Brand: {img['info']['brand']}")
243
+ category = img['info'].get('category')
244
+ if category:
245
+ st.sidebar.write(f"Category: {category}")
246
+ st.sidebar.write(f"Price: {img['info']['price']}")
247
+ st.sidebar.write(f"Discount: {img['info']['discount']}%")
248
+ st.sidebar.write(f"Similarity: {img['similarity']:.2f}")
249
+ else:
250
+ st.sidebar.warning("Please enter a search text.")