Update app.py
app.py
CHANGED
@@ -11,6 +11,7 @@ import onnxruntime as ort
 import cv2
 import chromadb
 
+# Load the CLIP model
 @st.cache_resource
 def load_clip_model():
     model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
@@ -21,6 +22,7 @@ def load_clip_model():
 
 clip_model, preprocess_val, tokenizer, device = load_clip_model()
 
+# Load the ONNX model
 @st.cache_resource
 def load_onnx_model():
     session = ort.InferenceSession("./accessary_weights.onnx")
@@ -28,6 +30,7 @@ def load_onnx_model():
 
 onnx_session = load_onnx_model()
 
+# Load an image from a URL
 def load_image_from_url(url, max_retries=3):
     for attempt in range(max_retries):
         try:
@@ -41,9 +44,11 @@ def load_image_from_url(url, max_retries=3):
         else:
             return None
 
+# Set up the ChromaDB client
 client = chromadb.PersistentClient(path="./accessaryDB")
 collection = client.get_collection(name="accessary_items_ver2")
 
+# Extract the CLIP image embedding
 def get_image_embedding(image):
     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
     with torch.no_grad():
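Only the first lines of load_image_from_url are visible in these hunks. For context, a minimal standalone sketch of the retry pattern it appears to implement is shown below; the requests call, timeout, backoff, and exception handling here are assumptions for illustration, not the app's exact code.

import time
from io import BytesIO

import requests
from PIL import Image


def load_image_from_url_sketch(url, max_retries=3, timeout=10):
    """Fetch an image from a URL, retrying on transient failures (illustrative sketch)."""
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except (requests.RequestException, OSError):
            if attempt < max_retries - 1:
                time.sleep(1)  # brief pause before the next attempt
            else:
                return None  # give up after the last attempt, as app.py does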
@@ -51,6 +56,7 @@ def get_image_embedding(image):
     image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()
 
+# Extract the CLIP text embedding
 def get_text_embedding(text):
     text_tokens = tokenizer([text]).to(device)
     with torch.no_grad():
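The image_features /= image_features.norm(dim=-1, keepdim=True) line L2-normalizes each embedding, so the plain dot product used later behaves like cosine similarity. A tiny standalone illustration with dummy tensors (not the CLIP model):

import torch

# Dummy "embeddings" standing in for CLIP outputs (batch of 1, dim 4).
image_features = torch.tensor([[3.0, 4.0, 0.0, 0.0]])
text_features = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

# Same normalization as in get_image_embedding / get_text_embedding.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# After normalization, the dot product equals the cosine similarity.
print(image_features @ text_features.T)  # tensor([[0.6000]])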
@@ -58,14 +64,17 @@ def get_text_embedding(text):
     text_features /= text_features.norm(dim=-1, keepdim=True)
     return text_features.cpu().numpy()
 
+# Get all embeddings from the collection
 def get_all_embeddings_from_collection(collection):
     all_embeddings = collection.get(include=['embeddings'])['embeddings']
     return np.array(all_embeddings)
 
+# Get metadata by ID
 def get_metadata_from_ids(collection, ids):
     results = collection.get(ids=ids)
     return results['metadatas']
 
+# Find similar images
 def find_similar_images(query_embedding, collection, top_k=5):
     database_embeddings = get_all_embeddings_from_collection(collection)
     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
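Only the opening lines of find_similar_images appear in this hunk; presumably the rest sorts the similarity scores and returns the top_k matches. A self-contained NumPy sketch of that kind of ranking on dummy data (the real function would also map indices back to ChromaDB IDs and metadata):

import numpy as np

# Dummy database of four L2-normalized embeddings and one query embedding (dim 3).
database_embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.6, 0.8, 0.0],
    [0.0, 0.0, 1.0],
])
query_embedding = np.array([[1.0, 0.0, 0.0]])

# Same scoring as in find_similar_images: dot product against every stored embedding.
similarities = np.dot(database_embeddings, query_embedding.T).squeeze()

# Rank the database items by similarity and keep the best top_k.
top_k = 2
top_indices = np.argsort(similarities)[::-1][:top_k]
print(top_indices, similarities[top_indices])  # [0 2] [1.  0.6]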
@@ -85,41 +94,35 @@ def find_similar_images(query_embedding, collection, top_k=5):
 
 onnx_model_labels = ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']
 
+# Preprocessing function tailored to the ONNX model
 def preprocess_for_onnx(image, input_size=(640, 640)):
-    # Resize the image to match the ONNX model's input size
     resized_image = image.resize(input_size)
-
-    # Convert the image to a NumPy array and normalize it to the 0-1 range (if needed)
     image_np = np.array(resized_image).astype(np.float32) / 255.0
-
-    # Reorder the dimensions as the model expects (e.g. HWC -> CHW)
-    image_np = np.transpose(image_np, (2, 0, 1))  # change the channel order (HWC -> CHW)
-
-    # Add a batch dimension (ONNX models usually expect [batch, channel, height, width])
+    image_np = np.transpose(image_np, (2, 0, 1))
     input_tensor = np.expand_dims(image_np, axis=0)
-
     return input_tensor
 
+# Detect clothing items
 def detect_clothing_onnx(image):
-
-    input_tensor = preprocess_for_onnx(image)
+    input_tensor = preprocess_for_onnx(image)  # call the preprocessing function
     outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: input_tensor})
 
-
-    detections = outputs[0]  # adjust to match the model's output format
-
+    detections = outputs[0]  # assume the first output holds the detection results
     categories = []
+
     for detection in detections:
-        x1, y1, x2, y2, conf, cls = detection
-
-
-
+        x1, y1, x2, y2, conf, cls = detection[:6]  # extract only the values we need
+        if conf > 0.5:  # confidence threshold
+            category = onnx_model_labels[int(cls)]
+            categories.append({
                 'category': category,
-                'bbox': [
+                'bbox': [x1, y1, x2, y2],
                 'confidence': conf
             })
+
     return categories
 
+# Crop the image
 def crop_image(image, bbox):
     return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
 
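As a sanity check on the shapes produced by preprocess_for_onnx and on the new confidence filter in detect_clothing_onnx, here is a standalone sketch with a dummy image and dummy detections. The [x1, y1, x2, y2, conf, cls] row layout is carried over from the code above and may not match the actual output of accessary_weights.onnx:

import numpy as np
from PIL import Image

onnx_model_labels = ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika',
                     'necklace', 'nose ring', 'ring', 'tiara']

# Same preprocessing steps as preprocess_for_onnx, applied to a dummy RGB image.
image = Image.new("RGB", (1024, 768))
image_np = np.array(image.resize((640, 640))).astype(np.float32) / 255.0  # HWC, 0-1
image_np = np.transpose(image_np, (2, 0, 1))                              # HWC -> CHW
input_tensor = np.expand_dims(image_np, axis=0)                           # add batch dim
print(input_tensor.shape)  # (1, 3, 640, 640)

# Dummy detections in the assumed [x1, y1, x2, y2, conf, cls] row format.
detections = np.array([
    [10, 20, 110, 220, 0.92, 3],   # confident earring
    [50, 60, 150, 260, 0.30, 5],   # low-confidence necklace, filtered out
])

# Same filtering logic as the updated detect_clothing_onnx.
categories = []
for detection in detections:
    x1, y1, x2, y2, conf, cls = detection[:6]
    if conf > 0.5:
        categories.append({'category': onnx_model_labels[int(cls)],
                           'bbox': [x1, y1, x2, y2],
                           'confidence': conf})
print(categories)  # one 'earring' entry with bbox [10.0, 20.0, 110.0, 220.0]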
@@ -154,7 +157,6 @@ if st.session_state.step == 'input':
     else:
         st.warning("Please enter an image URL.")
 
-# Update the 'select_category' step
 elif st.session_state.step == 'select_category':
     st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
     st.subheader("Detected Clothing Items:")
@@ -202,6 +204,7 @@ elif st.session_state.step == 'show_results':
         st.session_state.detections = []
         st.session_state.selected_category = None
 
+
     else: # Text search
         query_text = st.text_input("Enter search text:")
         if st.button("Search by Text"):
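The last two hunks only adjust the Streamlit flow driven by st.session_state.step ('input' -> 'select_category' -> 'show_results') and the text-search branch. A stripped-down sketch of that step pattern as a hypothetical minimal app (not the full UI in app.py):

import streamlit as st

# Initialize the step the first time the script runs.
if 'step' not in st.session_state:
    st.session_state.step = 'input'

if st.session_state.step == 'input':
    url = st.text_input("Enter an image URL:")
    if st.button("Next"):
        if url:
            st.session_state.query_url = url
            st.session_state.step = 'select_category'
            st.rerun()  # render the next step immediately
        else:
            st.warning("Please enter an image URL.")

elif st.session_state.step == 'select_category':
    st.write(f"Selected URL: {st.session_state.query_url}")
    if st.button("Show results"):
        st.session_state.step = 'show_results'
        st.rerun()

elif st.session_state.step == 'show_results':
    st.write("Results would be rendered here.")
    if st.button("Start over"):
        st.session_state.step = 'input'
        st.rerun()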