import streamlit as st
import open_clip
import torch
import requests
from PIL import Image
from io import BytesIO
import json
import numpy as np
import cv2
from inference_sdk import InferenceHTTPClient
import base64
import os
import pickle

# Load model and tokenizer
@st.cache_resource
def load_model():
    # create_model_and_transforms returns (model, preprocess_train, preprocess_val);
    # the tokenizer comes from get_tokenizer
    model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
    tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, preprocess_val, tokenizer, device

model, preprocess_val, tokenizer, device = load_model()

# Load and process data
@st.cache_data
def load_data():
    with open('musinsa-final.json', 'r', encoding='utf-8') as f:
        return json.load(f)

data = load_data()

def setup_roboflow_client(api_key):
    return InferenceHTTPClient(
        api_url="https://outline.roboflow.com",
        api_key=api_key
    )
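
# Example usage (sketch only; the API key below is a placeholder, "closet/1" is the
# Roboflow segmentation model this app queries):
#   client = setup_roboflow_client("YOUR_ROBOFLOW_API_KEY")
#   result = client.infer(encoded_image, model_id="closet/1")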

def download_and_process_image(image_url):
    try:
        response = requests.get(image_url)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Convert non-RGB modes (RGBA, palette, grayscale) so the image can be saved as JPEG later
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        st.error(f"Error downloading/processing image: {str(e)}")
        return None

def segment_image_and_get_categories(image_path, client):
    try:
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()
        
        encoded_image = base64.b64encode(image_data).decode('utf-8')
        
        image = cv2.imread(image_path)
        image = cv2.resize(image, (800, 600))
        mask = np.zeros(image.shape, dtype=np.uint8)
        
        results = client.infer(encoded_image, model_id="closet/1")
        
        # Normalize the response to a dict (the SDK may return a JSON string)
        if not isinstance(results, dict):
            results = json.loads(results)
        predictions = results.get('predictions', [])
        
        categories = []
        if predictions:
            for prediction in predictions:
                points = prediction['points']
                pts = np.array([[p['x'], p['y']] for p in points], np.int32)
                scale_x = image.shape[1] / results.get('image', {}).get('width', 1)
                scale_y = image.shape[0] / results.get('image', {}).get('height', 1)
                pts = pts * [scale_x, scale_y]
                pts = pts.astype(np.int32)
                pts = pts.reshape((-1, 1, 2))
                cv2.fillPoly(mask, [pts], color=(255, 255, 255))
                
                category = prediction.get('class', 'Unknown')
                confidence = prediction.get('confidence', 0)
                categories.append(f"{category} ({confidence:.2f})")
            
            segmented_image = cv2.bitwise_and(image, mask)
        else:
            st.warning("No predictions found in the image. Returning original image.")
            segmented_image = image
        
        return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB)), categories
    except Exception as e:
        st.error(f"Error in segmentation: {str(e)}")
        return Image.open(image_path), []

def get_image_embedding(image):
    image_tensor = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
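
# The tokenizer loaded above is otherwise unused in this app. A minimal sketch of how a
# text query could be embedded into the same space, assuming the standard open_clip
# encode_text API (not wired into the UI):
def get_text_embedding(text):
    tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy()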

@st.cache_data
def process_database_cached(data):
    database_info = []
    for item in data:
        image_url = item['์ด๋ฏธ์ง€ ๋งํฌ'][0]
        product_id = item.get('\ufeff์ƒํ’ˆ ID') or item.get('์ƒํ’ˆ ID')
        
        image = download_and_process_image(image_url)
        if image is None:
            continue
        
        temp_path = f"temp_{product_id}.jpg"
        image.save(temp_path, 'JPEG')
        
        database_info.append({
            'id': product_id,
            'category': item['์นดํ…Œ๊ณ ๋ฆฌ'],   # category
            'brand': item['๋ธŒ๋žœ๋“œ๋ช…'],        # brand name
            'name': item['์ œํ’ˆ๋ช…'],          # product name
            'price': item['์ •๊ฐ€'],           # list price
            'discount': item['ํ• ์ธ์œจ'],      # discount rate (%)
            'image_url': image_url,
            'temp_path': temp_path
        })
    
    return database_info

def process_database(client, data):
    database_info = process_database_cached(data)
    cache_dir = "segmentation_cache"
    os.makedirs(cache_dir, exist_ok=True)
    
    database_embeddings = []
    for item in database_info:
        cache_file = os.path.join(cache_dir, f"{item['id']}_segmented.pkl")
        
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                segmented_image, categories = pickle.load(f)
        else:
            segmented_image, categories = segment_image_and_get_categories(item['temp_path'], client)
            with open(cache_file, 'wb') as f:
                pickle.dump((segmented_image, categories), f)
        
        embedding = get_image_embedding(segmented_image)
        database_embeddings.append(embedding)
        item['categories'] = categories
    
    return np.vstack(database_embeddings), database_info

def find_similar_images(query_embedding, database_embeddings, database_info, top_k=5):
    similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
    top_indices = np.argsort(similarities)[::-1][:top_k]
    
    results = []
    for idx in top_indices:
        results.append({
            'info': database_info[idx],
            'similarity': similarities[idx]
        })
    
    return results
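
# Note: get_image_embedding() L2-normalizes every vector, so the dot product in
# find_similar_images() is the cosine similarity and scores fall in [-1, 1].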

# Streamlit app
st.title("Fashion Search App with Segmentation and Category Detection")

# API Key input
api_key = st.text_input("Enter your Roboflow API Key", type="password")

if api_key:
    CLIENT = setup_roboflow_client(api_key)
    
    # Initialize database_embeddings and database_info
    database_embeddings, database_info = process_database(CLIENT, data)

    uploaded_file = st.file_uploader("Choose an image...", type="jpg")
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)
        
        if st.button('Find Similar Items'):
            with st.spinner('Processing...'):
                temp_path = "temp_upload.jpg"
                image.save(temp_path)
                
                segmented_image, input_categories = segment_image_and_get_categories(temp_path, CLIENT)
                st.image(segmented_image, caption='Segmented Image', use_column_width=True)
                
                st.subheader("Detected Categories in Input Image:")
                for category in input_categories:
                    st.write(category)
                
                query_embedding = get_image_embedding(segmented_image)
                similar_images = find_similar_images(query_embedding, database_embeddings, database_info)
                
                st.subheader("Similar Items:")
                for img in similar_images:
                    col1, col2 = st.columns(2)
                    with col1:
                        st.image(img['info']['image_url'], use_column_width=True)
                    with col2:
                        st.write(f"Name: {img['info']['name']}")
                        st.write(f"Brand: {img['info']['brand']}")
                        st.write(f"Category: {img['info']['category']}")
                        st.write(f"Price: {img['info']['price']}")
                        st.write(f"Discount: {img['info']['discount']}%")
                        st.write(f"Similarity: {img['similarity']:.2f}")
                        
                        st.write("Detected Categories:")
                        for category in img['info']['categories']:
                            st.write(category)
else:
    st.warning("Please enter your Roboflow API Key to use the app.")