JoJosmin committed on
Commit b646871 · verified · 1 Parent(s): 2c7cf10

Update app.py

Files changed (1)
  1. app.py +23 -20
app.py CHANGED
@@ -11,6 +11,7 @@ import onnxruntime as ort
 import cv2
 import chromadb

+# Load the CLIP model
 @st.cache_resource
 def load_clip_model():
     model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
@@ -21,6 +22,7 @@ def load_clip_model():

 clip_model, preprocess_val, tokenizer, device = load_clip_model()

+# Load the ONNX model
 @st.cache_resource
 def load_onnx_model():
     session = ort.InferenceSession("./accessary_weights.onnx")
@@ -28,6 +30,7 @@ def load_onnx_model():

 onnx_session = load_onnx_model()

+# Load an image from a URL
 def load_image_from_url(url, max_retries=3):
     for attempt in range(max_retries):
         try:
@@ -41,9 +44,11 @@ def load_image_from_url(url, max_retries=3):
         else:
             return None

+# Set up the ChromaDB client
 client = chromadb.PersistentClient(path="./accessaryDB")
 collection = client.get_collection(name="accessary_items_ver2")

+# Extract a CLIP image embedding
 def get_image_embedding(image):
     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
     with torch.no_grad():
@@ -51,6 +56,7 @@ def get_image_embedding(image):
         image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()

+# Extract a CLIP text embedding
 def get_text_embedding(text):
     text_tokens = tokenizer([text]).to(device)
     with torch.no_grad():
@@ -58,14 +64,17 @@ def get_text_embedding(text):
         text_features /= text_features.norm(dim=-1, keepdim=True)
     return text_features.cpu().numpy()

+# Fetch all embeddings from the collection
 def get_all_embeddings_from_collection(collection):
     all_embeddings = collection.get(include=['embeddings'])['embeddings']
     return np.array(all_embeddings)

+# Fetch metadata by ID
 def get_metadata_from_ids(collection, ids):
     results = collection.get(ids=ids)
     return results['metadatas']

+# Find similar images
 def find_similar_images(query_embedding, collection, top_k=5):
     database_embeddings = get_all_embeddings_from_collection(collection)
     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
@@ -85,41 +94,35 @@ def find_similar_images(query_embedding, collection, top_k=5):

 onnx_model_labels = ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']

+# Preprocessing matched to the ONNX model's input
 def preprocess_for_onnx(image, input_size=(640, 640)):
-    # Resize the image to the ONNX model's input size
     resized_image = image.resize(input_size)
-
-    # Convert the image to a NumPy array and normalize to the 0-1 range (if needed)
     image_np = np.array(resized_image).astype(np.float32) / 255.0
-
-    # Reorder the dimensions to what the model expects (e.g. HWC -> CHW)
-    image_np = np.transpose(image_np, (2, 0, 1))  # change the channel order (HWC -> CHW)
-
-    # Add a batch dimension (ONNX models usually expect [batch, channel, height, width])
+    image_np = np.transpose(image_np, (2, 0, 1))
     input_tensor = np.expand_dims(image_np, axis=0)
-
     return input_tensor

+# Detect clothing items
 def detect_clothing_onnx(image):
-    # Run object detection on the image with the ONNX model
-    input_tensor = preprocess_for_onnx(image)
+    input_tensor = preprocess_for_onnx(image)  # call the preprocessing function
     outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: input_tensor})

-    # Extract the coordinates and class info for the detected objects
-    detections = outputs[0]  # adjust to match the model's output format
-
+    detections = outputs[0]  # assume the first output holds the detection results
     categories = []
+
     for detection in detections:
-        x1, y1, x2, y2, conf, cls = detection
-        category = onnx_model_labels[int(cls)]  # convert the class index to a category name
-        if category in ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara'] :
-            categories.append({
+        x1, y1, x2, y2, conf, cls = detection[:6]  # take only the values we need
+        if conf > 0.5:  # confidence threshold
+            category = onnx_model_labels[int(cls)]
+            categories.append({
                 'category': category,
-                'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                'bbox': [x1, y1, x2, y2],
                 'confidence': conf
             })
+
     return categories

+# Crop the image
 def crop_image(image, bbox):
     return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))

@@ -154,7 +157,6 @@ if st.session_state.step == 'input':
         else:
             st.warning("Please enter an image URL.")

-# Update the 'select_category' step
 elif st.session_state.step == 'select_category':
     st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
     st.subheader("Detected Clothing Items:")
@@ -202,6 +204,7 @@ elif st.session_state.step == 'show_results':
         st.session_state.detections = []
         st.session_state.selected_category = None

+
     else:  # Text search
         query_text = st.text_input("Enter search text:")
         if st.button("Search by Text"):
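
For context, a minimal sketch of how the updated detection path fits together end to end (URL → detection → crop → embedding → similarity search). It only calls functions that appear in this diff; the wrapper name run_accessory_search, the example URL, and the assumption that app.py's module-level objects (collection, onnx_session, the CLIP model) are already loaded are illustrative, not part of the commit.

# Illustrative sketch only -- assumes the definitions from app.py above are in scope.
def run_accessory_search(image_url, top_k=5):
    # Fetch the query image with the retry helper
    image = load_image_from_url(image_url)
    if image is None:
        return []

    results = []
    # detect_clothing_onnx now keeps only detections with confidence > 0.5
    for det in detect_clothing_onnx(image):
        crop = crop_image(image, det['bbox'])   # crop the detected accessory
        embedding = get_image_embedding(crop)   # CLIP embedding of the crop
        # collect whatever find_similar_images returns for the top_k nearest items
        matches = find_similar_images(embedding, collection, top_k=top_k)
        results.append({'detection': det, 'matches': matches})
    return results

# Example call (hypothetical URL):
# run_accessory_search("https://example.com/accessory.jpg")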