JoJosmin commited on
Commit
2c7cf10
·
verified ·
1 Parent(s): 4688a56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -208
app.py CHANGED
@@ -1,209 +1,224 @@
1
- import streamlit as st
2
- import open_clip
3
- import torch
4
- import requests
5
- from PIL import Image
6
- from io import BytesIO
7
- import time
8
- import json
9
- import numpy as np
10
- import onnxruntime as ort
11
- import cv2
12
- import chromadb
13
-
14
- @st.cache_resource
15
- def load_clip_model():
16
- model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
17
- tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model.to(device)
20
- return model, preprocess_val, tokenizer, device
21
-
22
- clip_model, preprocess_val, tokenizer, device = load_clip_model()
23
-
24
- @st.cache_resource
25
- def load_onnx_model():
26
- session = ort.InferenceSession("./accessary_weights.onnx")
27
- return session
28
-
29
- onnx_session = load_onnx_model()
30
-
31
- def load_image_from_url(url, max_retries=3):
32
- for attempt in range(max_retries):
33
- try:
34
- response = requests.get(url, timeout=10)
35
- response.raise_for_status()
36
- img = Image.open(BytesIO(response.content)).convert('RGB')
37
- return img
38
- except (requests.RequestException, Image.UnidentifiedImageError) as e:
39
- if attempt < max_retries - 1:
40
- time.sleep(1)
41
- else:
42
- return None
43
-
44
- client = chromadb.PersistentClient(path="./accessaryDB")
45
- collection = client.get_collection(name="accessary_items_ver2")
46
-
47
- def get_image_embedding(image):
48
- image_tensor = preprocess_val(image).unsqueeze(0).to(device)
49
- with torch.no_grad():
50
- image_features = clip_model.encode_image(image_tensor)
51
- image_features /= image_features.norm(dim=-1, keepdim=True)
52
- return image_features.cpu().numpy()
53
-
54
- def get_text_embedding(text):
55
- text_tokens = tokenizer([text]).to(device)
56
- with torch.no_grad():
57
- text_features = clip_model.encode_text(text_tokens)
58
- text_features /= text_features.norm(dim=-1, keepdim=True)
59
- return text_features.cpu().numpy()
60
-
61
- def get_all_embeddings_from_collection(collection):
62
- all_embeddings = collection.get(include=['embeddings'])['embeddings']
63
- return np.array(all_embeddings)
64
-
65
- def get_metadata_from_ids(collection, ids):
66
- results = collection.get(ids=ids)
67
- return results['metadatas']
68
-
69
- def find_similar_images(query_embedding, collection, top_k=5):
70
- database_embeddings = get_all_embeddings_from_collection(collection)
71
- similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
72
- top_indices = np.argsort(similarities)[::-1][:top_k]
73
-
74
- all_data = collection.get(include=['metadatas'])['metadatas']
75
-
76
- top_metadatas = [all_data[idx] for idx in top_indices]
77
-
78
- results = []
79
- for idx, metadata in enumerate(top_metadatas):
80
- results.append({
81
- 'info': metadata,
82
- 'similarity': similarities[top_indices[idx]]
83
- })
84
- return results
85
-
86
- def detect_clothing_onnx(image):
87
- input_image = np.array(image.resize((640, 640)), dtype=np.float32)
88
- input_image = np.transpose(input_image, [2, 0, 1])
89
- input_image = np.expand_dims(input_image, axis=0)
90
- input_image /= 255.0
91
-
92
- inputs = {onnx_session.get_inputs()[0].name: input_image}
93
- outputs = onnx_session.run(None, inputs)
94
-
95
- detections = outputs[0]
96
- categories = []
97
- for detection in detections:
98
- x1, y1, x2, y2, conf, cls = detection
99
- category = str(int(cls))
100
- if category in ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']:
101
- categories.append({
102
- 'category': category,
103
- 'bbox': [int(x1), int(y1), int(x2), int(y2)],
104
- 'confidence': conf
105
- })
106
- return categories
107
-
108
- def crop_image(image, bbox):
109
- return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
110
-
111
- # 세션 상태 초기화
112
- if 'step' not in st.session_state:
113
- st.session_state.step = 'input'
114
- if 'query_image_url' not in st.session_state:
115
- st.session_state.query_image_url = ''
116
- if 'detections' not in st.session_state:
117
- st.session_state.detections = []
118
- if 'selected_category' not in st.session_state:
119
- st.session_state.selected_category = None
120
-
121
- # Streamlit app
122
- st.title("Advanced Fashion Search App")
123
-
124
- # 단계별 처리
125
- if st.session_state.step == 'input':
126
- st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
127
- if st.button("Detect Clothing"):
128
- if st.session_state.query_image_url:
129
- query_image = load_image_from_url(st.session_state.query_image_url)
130
- if query_image is not None:
131
- st.session_state.query_image = query_image
132
- st.session_state.detections = detect_clothing_onnx(query_image)
133
- if st.session_state.detections:
134
- st.session_state.step = 'select_category'
135
- else:
136
- st.warning("No clothing items detected in the image.")
137
- else:
138
- st.error("Failed to load the image. Please try another URL.")
139
- else:
140
- st.warning("Please enter an image URL.")
141
-
142
- # Update the 'select_category' step
143
- elif st.session_state.step == 'select_category':
144
- st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
145
- st.subheader("Detected Clothing Items:")
146
-
147
- for detection in st.session_state.detections:
148
- col1, col2 = st.columns([1, 3])
149
- with col1:
150
- st.write(f"{detection['category']} (Confidence: {detection['confidence']:.2f})")
151
- with col2:
152
- cropped_image = crop_image(st.session_state.query_image, detection['bbox'])
153
- st.image(cropped_image, caption=detection['category'], use_column_width=True)
154
-
155
- options = [f"{d['category']} (Confidence: {d['confidence']:.2f})" for d in st.session_state.detections]
156
- selected_option = st.selectbox("Select a category to search:", options)
157
-
158
- if st.button("Search Similar Items"):
159
- st.session_state.selected_category = selected_option
160
- st.session_state.step = 'show_results'
161
-
162
- elif st.session_state.step == 'show_results':
163
- st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
164
- selected_detection = next(d for d in st.session_state.detections
165
- if f"{d['category']} (Confidence: {d['confidence']:.2f})" == st.session_state.selected_category)
166
- cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
167
- st.image(cropped_image, caption="Cropped Image", use_column_width=True)
168
- query_embedding = get_image_embedding(cropped_image)
169
- similar_images = find_similar_images(query_embedding, collection)
170
-
171
- st.subheader("Similar Items:")
172
- for img in similar_images:
173
- col1, col2 = st.columns(2)
174
- with col1:
175
- st.image(img['info']['image_url'], use_column_width=True)
176
- with col2:
177
- st.write(f"Name: {img['info']['name']}")
178
- st.write(f"Brand: {img['info']['brand']}")
179
- st.write(f"Category: {img['info']['category']}")
180
- st.write(f"Price: {img['info']['price']}")
181
- st.write(f"Discount: {img['info']['discount']}%")
182
- st.write(f"Similarity: {img['similarity']:.2f}")
183
-
184
- if st.button("Start New Search"):
185
- st.session_state.step = 'input'
186
- st.session_state.query_image_url = ''
187
- st.session_state.detections = []
188
- st.session_state.selected_category = None
189
-
190
- else: # Text search
191
- query_text = st.text_input("Enter search text:")
192
- if st.button("Search by Text"):
193
- if query_text:
194
- text_embedding = get_text_embedding(query_text)
195
- similar_images = find_similar_images(text_embedding, collection)
196
- st.subheader("Similar Items:")
197
- for img in similar_images:
198
- col1, col2 = st.columns(2)
199
- with col1:
200
- st.image(img['info']['image_url'], use_column_width=True)
201
- with col2:
202
- st.write(f"Name: {img['info']['name']}")
203
- st.write(f"Brand: {img['info']['brand']}")
204
- st.write(f"Category: {img['info']['category']}")
205
- st.write(f"Price: {img['info']['price']}")
206
- st.write(f"Discount: {img['info']['discount']}%")
207
- st.write(f"Similarity: {img['similarity']:.2f}")
208
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  st.warning("Please enter a search text.")
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ import cv2
12
+ import chromadb
13
+
14
+ @st.cache_resource
15
+ def load_clip_model():
16
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
17
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
+ return model, preprocess_val, tokenizer, device
21
+
22
+ clip_model, preprocess_val, tokenizer, device = load_clip_model()
23
+
24
+ @st.cache_resource
25
+ def load_onnx_model():
26
+ session = ort.InferenceSession("./accessary_weights.onnx")
27
+ return session
28
+
29
+ onnx_session = load_onnx_model()
30
+
31
+ def load_image_from_url(url, max_retries=3):
32
+ for attempt in range(max_retries):
33
+ try:
34
+ response = requests.get(url, timeout=10)
35
+ response.raise_for_status()
36
+ img = Image.open(BytesIO(response.content)).convert('RGB')
37
+ return img
38
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
39
+ if attempt < max_retries - 1:
40
+ time.sleep(1)
41
+ else:
42
+ return None
43
+
44
+ client = chromadb.PersistentClient(path="./accessaryDB")
45
+ collection = client.get_collection(name="accessary_items_ver2")
46
+
47
+ def get_image_embedding(image):
48
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
49
+ with torch.no_grad():
50
+ image_features = clip_model.encode_image(image_tensor)
51
+ image_features /= image_features.norm(dim=-1, keepdim=True)
52
+ return image_features.cpu().numpy()
53
+
54
+ def get_text_embedding(text):
55
+ text_tokens = tokenizer([text]).to(device)
56
+ with torch.no_grad():
57
+ text_features = clip_model.encode_text(text_tokens)
58
+ text_features /= text_features.norm(dim=-1, keepdim=True)
59
+ return text_features.cpu().numpy()
60
+
61
+ def get_all_embeddings_from_collection(collection):
62
+ all_embeddings = collection.get(include=['embeddings'])['embeddings']
63
+ return np.array(all_embeddings)
64
+
65
+ def get_metadata_from_ids(collection, ids):
66
+ results = collection.get(ids=ids)
67
+ return results['metadatas']
68
+
69
+ def find_similar_images(query_embedding, collection, top_k=5):
70
+ database_embeddings = get_all_embeddings_from_collection(collection)
71
+ similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
72
+ top_indices = np.argsort(similarities)[::-1][:top_k]
73
+
74
+ all_data = collection.get(include=['metadatas'])['metadatas']
75
+
76
+ top_metadatas = [all_data[idx] for idx in top_indices]
77
+
78
+ results = []
79
+ for idx, metadata in enumerate(top_metadatas):
80
+ results.append({
81
+ 'info': metadata,
82
+ 'similarity': similarities[top_indices[idx]]
83
+ })
84
+ return results
85
+
86
+ onnx_model_labels = ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']
87
+
88
+ def preprocess_for_onnx(image, input_size=(640, 640)):
89
+ # 이미지 크기 조정 (ONNX 모델의 입력 크기에 맞춰 리사이즈)
90
+ resized_image = image.resize(input_size)
91
+
92
+ # 이미지를 NumPy 배열로 변환하고, 0~1 사이의 값으로 정규화 (필요한 경우)
93
+ image_np = np.array(resized_image).astype(np.float32) / 255.0
94
+
95
+ # 모델이 기대하는 순서대로 차원 변경 (예: HWC -> CHW)
96
+ image_np = np.transpose(image_np, (2, 0, 1)) # 채널 순서를 변경 (HWC -> CHW)
97
+
98
+ # 배치 차원을 추가 (ONNX 모델은 보통 [batch, channel, height, width] 형식을 요구)
99
+ input_tensor = np.expand_dims(image_np, axis=0)
100
+
101
+ return input_tensor
102
+
103
+ def detect_clothing_onnx(image):
104
+ # ONNX 모델로 이미지에서 객체 탐지 수행
105
+ input_tensor = preprocess_for_onnx(image)
106
+ outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: input_tensor})
107
+
108
+ # 탐지된 객체에 대한 좌표 및 클래스 정보 추출
109
+ detections = outputs[0] # 모델의 출력 형식에 맞게 수정
110
+
111
+ categories = []
112
+ for detection in detections:
113
+ x1, y1, x2, y2, conf, cls = detection
114
+ category = onnx_model_labels[int(cls)] # 클래스 인덱스를 카테고리 이름으로 변환
115
+ if category in ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara'] :
116
+ categories.append({
117
+ 'category': category,
118
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
119
+ 'confidence': conf
120
+ })
121
+ return categories
122
+
123
+ def crop_image(image, bbox):
124
+ return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
125
+
126
+ # 세션 상태 초기화
127
+ if 'step' not in st.session_state:
128
+ st.session_state.step = 'input'
129
+ if 'query_image_url' not in st.session_state:
130
+ st.session_state.query_image_url = ''
131
+ if 'detections' not in st.session_state:
132
+ st.session_state.detections = []
133
+ if 'selected_category' not in st.session_state:
134
+ st.session_state.selected_category = None
135
+
136
+ # Streamlit app
137
+ st.title("Advanced Fashion Search App")
138
+
139
+ # 단계별 처리
140
+ if st.session_state.step == 'input':
141
+ st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
142
+ if st.button("Detect Clothing"):
143
+ if st.session_state.query_image_url:
144
+ query_image = load_image_from_url(st.session_state.query_image_url)
145
+ if query_image is not None:
146
+ st.session_state.query_image = query_image
147
+ st.session_state.detections = detect_clothing_onnx(query_image)
148
+ if st.session_state.detections:
149
+ st.session_state.step = 'select_category'
150
+ else:
151
+ st.warning("No clothing items detected in the image.")
152
+ else:
153
+ st.error("Failed to load the image. Please try another URL.")
154
+ else:
155
+ st.warning("Please enter an image URL.")
156
+
157
+ # Update the 'select_category' step
158
+ elif st.session_state.step == 'select_category':
159
+ st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
160
+ st.subheader("Detected Clothing Items:")
161
+
162
+ for detection in st.session_state.detections:
163
+ col1, col2 = st.columns([1, 3])
164
+ with col1:
165
+ st.write(f"{detection['category']} (Confidence: {detection['confidence']:.2f})")
166
+ with col2:
167
+ cropped_image = crop_image(st.session_state.query_image, detection['bbox'])
168
+ st.image(cropped_image, caption=detection['category'], use_column_width=True)
169
+
170
+ options = [f"{d['category']} (Confidence: {d['confidence']:.2f})" for d in st.session_state.detections]
171
+ selected_option = st.selectbox("Select a category to search:", options)
172
+
173
+ if st.button("Search Similar Items"):
174
+ st.session_state.selected_category = selected_option
175
+ st.session_state.step = 'show_results'
176
+
177
+ elif st.session_state.step == 'show_results':
178
+ st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
179
+ selected_detection = next(d for d in st.session_state.detections
180
+ if f"{d['category']} (Confidence: {d['confidence']:.2f})" == st.session_state.selected_category)
181
+ cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
182
+ st.image(cropped_image, caption="Cropped Image", use_column_width=True)
183
+ query_embedding = get_image_embedding(cropped_image)
184
+ similar_images = find_similar_images(query_embedding, collection)
185
+
186
+ st.subheader("Similar Items:")
187
+ for img in similar_images:
188
+ col1, col2 = st.columns(2)
189
+ with col1:
190
+ st.image(img['info']['image_url'], use_column_width=True)
191
+ with col2:
192
+ st.write(f"Name: {img['info']['name']}")
193
+ st.write(f"Brand: {img['info']['brand']}")
194
+ st.write(f"Category: {img['info']['category']}")
195
+ st.write(f"Price: {img['info']['price']}")
196
+ st.write(f"Discount: {img['info']['discount']}%")
197
+ st.write(f"Similarity: {img['similarity']:.2f}")
198
+
199
+ if st.button("Start New Search"):
200
+ st.session_state.step = 'input'
201
+ st.session_state.query_image_url = ''
202
+ st.session_state.detections = []
203
+ st.session_state.selected_category = None
204
+
205
+ else: # Text search
206
+ query_text = st.text_input("Enter search text:")
207
+ if st.button("Search by Text"):
208
+ if query_text:
209
+ text_embedding = get_text_embedding(query_text)
210
+ similar_images = find_similar_images(text_embedding, collection)
211
+ st.subheader("Similar Items:")
212
+ for img in similar_images:
213
+ col1, col2 = st.columns(2)
214
+ with col1:
215
+ st.image(img['info']['image_url'], use_column_width=True)
216
+ with col2:
217
+ st.write(f"Name: {img['info']['name']}")
218
+ st.write(f"Brand: {img['info']['brand']}")
219
+ st.write(f"Category: {img['info']['category']}")
220
+ st.write(f"Price: {img['info']['price']}")
221
+ st.write(f"Discount: {img['info']['discount']}%")
222
+ st.write(f"Similarity: {img['similarity']:.2f}")
223
+ else:
224
  st.warning("Please enter a search text.")