leedoming committed
Commit 44014d4 · verified · 1 Parent(s): e17e81a

Update app.py

Files changed (1)
  1. app.py +107 -117
app.py CHANGED
@@ -7,68 +7,90 @@ from io import BytesIO
  import time
  import json
  import numpy as np
+ import cv2
+ from inference_sdk import InferenceHTTPClient
+ import matplotlib.pyplot as plt

  # Load model and tokenizer
  @st.cache_resource
  def load_model():
-     model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
-     tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
+     model, preprocess_val, tokenizer = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      model.to(device)
      return model, preprocess_val, tokenizer, device

  model, preprocess_val, tokenizer, device = load_model()

- # Load and process data
- @st.cache_data
- def load_data():
-     with open('./musinsa-final.json', 'r', encoding='utf-8') as f:
-         return json.load(f)
+ # Roboflow client setup function
+ def setup_roboflow_client(api_key):
+     return InferenceHTTPClient(
+         api_url="https://outline.roboflow.com",
+         api_key=api_key
+     )

- data = load_data()
-
- # Helper functions
- def load_image_from_url(url, max_retries=3):
-     for attempt in range(max_retries):
-         try:
-             response = requests.get(url, timeout=10)
-             response.raise_for_status()
-             img = Image.open(BytesIO(response.content)).convert('RGB')
-             return img
-         except (requests.RequestException, Image.UnidentifiedImageError) as e:
-             #st.warning(f"Attempt {attempt + 1} failed: {str(e)}")
-             if attempt < max_retries - 1:
-                 time.sleep(1)
-             else:
-                 #st.error(f"Failed to load image from {url} after {max_retries} attempts")
-                 return None
-
- def get_image_embedding_from_url(image_url):
-     image = load_image_from_url(image_url)
-     if image is None:
-         return None
-
-     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
-
-     with torch.no_grad():
-         image_features = model.encode_image(image_tensor)
-         image_features /= image_features.norm(dim=-1, keepdim=True)
-
-     return image_features.cpu().numpy()
-
- @st.cache_data
- def process_database():
-     database_embeddings = []
-     database_info = []
-
-     for item in data:
-         image_url = item['이미지 링크'][0]
-         embedding = get_image_embedding_from_url(image_url)
-
-         if embedding is not None:
+ # Streamlit app
+ st.title("Fashion Search App with Segmentation")
+
+ # API Key input
+ api_key = st.text_input("Enter your Roboflow API Key", type="password")
+
+ if api_key:
+     CLIENT = setup_roboflow_client(api_key)
+
+     def segment_image(image_path):
+         results = CLIENT.infer(image_path, model_id="closet/1")
+         results = json.loads(results)
+
+         image = cv2.imread(image_path)
+         image = cv2.resize(image, (800, 600))
+         mask = np.zeros(image.shape, dtype=np.uint8)
+
+         for prediction in results['predictions']:
+             points = prediction['points']
+             pts = np.array([[p['x'], p['y']] for p in points], np.int32)
+             scale_x = image.shape[1] / results['image']['width']
+             scale_y = image.shape[0] / results['image']['height']
+             pts = pts * [scale_x, scale_y]
+             pts = pts.astype(np.int32)
+             pts = pts.reshape((-1, 1, 2))
+             cv2.fillPoly(mask, [pts], color=(255, 255, 255))  # White mask
+
+         segmented_image = cv2.bitwise_and(image, mask)
+         return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
+
+     def get_image_embedding(image):
+         image_tensor = preprocess_val(image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             image_features = model.encode_image(image_tensor)
+             image_features /= image_features.norm(dim=-1, keepdim=True)
+         return image_features.cpu().numpy()
+
+     # Load and process data
+     @st.cache_data
+     def load_data():
+         with open('musinsa-final.json', 'r', encoding='utf-8') as f:
+             return json.load(f)
+
+     data = load_data()
+
+     # Process database with segmentation
+     @st.cache_data
+     def process_database():
+         database_embeddings = []
+         database_info = []
+         for item in data:
+             image_url = item['이미지 링크'][0]
+             image_path = f"temp_{item['상품 ID']}.jpg"
+             response = requests.get(image_url)
+             with open(image_path, 'wb') as f:
+                 f.write(response.content)
+
+             segmented_image = segment_image(image_path)
+             embedding = get_image_embedding(segmented_image)
+
              database_embeddings.append(embedding)
              database_info.append({
-                 'id': item['\ufeff상품 ID'],
+                 'id': item['상품 ID'],
                  'category': item['카테고리'],
                  'brand': item['브랜드명'],
                  'name': item['제품명'],
@@ -76,52 +98,43 @@ def process_database():
                  'discount': item['할인율'],
                  'image_url': image_url
              })
-         else:
-             st.warning(f"Skipping item {item['상품 ID']} due to image loading failure")
-
-     if database_embeddings:
+
          return np.vstack(database_embeddings), database_info
-     else:
-         st.error("No valid embeddings were generated.")
-         return None, None
-
- database_embeddings, database_info = process_database()

- def get_text_embedding(text):
-     text_tokens = tokenizer([text]).to(device)
-
-     with torch.no_grad():
-         text_features = model.encode_text(text_tokens)
-         text_features /= text_features.norm(dim=-1, keepdim=True)
-
-     return text_features.cpu().numpy()
-
- def find_similar_images(query_embedding, top_k=5):
-     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
-     top_indices = np.argsort(similarities)[::-1][:top_k]
-
-     results = []
-     for idx in top_indices:
-         results.append({
-             'info': database_info[idx],
-             'similarity': similarities[idx]
-         })
-
-     return results
-
- # Streamlit app
- st.title("Fashion Search App")
-
- search_type = st.radio("Search by:", ("Image URL", "Text"))
-
- if search_type == "Image URL":
-     query_image_url = st.text_input("Enter image URL:")
-     if st.button("Search by Image"):
-         if query_image_url:
-             query_embedding = get_image_embedding_from_url(query_image_url)
-             if query_embedding is not None:
+     database_embeddings, database_info = process_database()
+
+     def find_similar_images(query_embedding, top_k=5):
+         similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
+         top_indices = np.argsort(similarities)[::-1][:top_k]
+
+         results = []
+         for idx in top_indices:
+             results.append({
+                 'info': database_info[idx],
+                 'similarity': similarities[idx]
+             })
+
+         return results
+
+     uploaded_file = st.file_uploader("Choose an image...", type="jpg")
+     if uploaded_file is not None:
+         image = Image.open(uploaded_file)
+         st.image(image, caption='Uploaded Image', use_column_width=True)
+
+         if st.button('Find Similar Items'):
+             with st.spinner('Processing...'):
+                 # Save uploaded image temporarily
+                 temp_path = "temp_upload.jpg"
+                 image.save(temp_path)
+
+                 # Segment the uploaded image
+                 segmented_image = segment_image(temp_path)
+                 st.image(segmented_image, caption='Segmented Image', use_column_width=True)
+
+                 # Get embedding for segmented image
+                 query_embedding = get_image_embedding(segmented_image)
                  similar_images = find_similar_images(query_embedding)
-                 st.image(query_image_url, caption="Query Image", use_column_width=True)
+
                  st.subheader("Similar Items:")
                  for img in similar_images:
                      col1, col2 = st.columns(2)
@@ -134,28 +147,5 @@ if search_type == "Image URL":
                          st.write(f"Price: {img['info']['price']}")
                          st.write(f"Discount: {img['info']['discount']}%")
                          st.write(f"Similarity: {img['similarity']:.2f}")
-             else:
-                 st.error("Failed to process the image. Please try another URL.")
-         else:
-             st.warning("Please enter an image URL.")
-
- else: # Text search
-     query_text = st.text_input("Enter search text:")
-     if st.button("Search by Text"):
-         if query_text:
-             text_embedding = get_text_embedding(query_text)
-             similar_images = find_similar_images(text_embedding)
-             st.subheader("Similar Items:")
-             for img in similar_images:
-                 col1, col2 = st.columns(2)
-                 with col1:
-                     st.image(img['info']['image_url'], use_column_width=True)
-                 with col2:
-                     st.write(f"Name: {img['info']['name']}")
-                     st.write(f"Brand: {img['info']['brand']}")
-                     st.write(f"Category: {img['info']['category']}")
-                     st.write(f"Price: {img['info']['price']}")
-                     st.write(f"Discount: {img['info']['discount']}%")
-                     st.write(f"Similarity: {img['similarity']:.2f}")
-         else:
-             st.warning("Please enter a search text.")
+ else:
+     st.warning("Please enter your Roboflow API Key to use the app.")