initial
Browse files- app.py +232 -0
- data/bag_1.jpg +0 -0
- data/bag_2.jpg +0 -0
- data/bag_3.jpg +0 -0
- data/bag_4.jpg +0 -0
- data/bag_5.jpg +0 -0
- data/dress_1.jpg +0 -0
- data/dress_2.jpg +0 -0
- data/dress_3.jpg +0 -0
- data/dress_4.jpg +0 -0
- data/dress_5.jpg +0 -0
- data/shoes_1.jpg +0 -0
- data/shoes_2.jpg +0 -0
- data/shoes_3.jpg +0 -0
- data/shoes_4.jpg +0 -0
- data/shoes_5.jpg +0 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# helper functions to get segmented mask
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
url = "http://static.okkular.io/scripted.model"
|
8 |
+
|
9 |
+
output_file = "./scripted.model"
|
10 |
+
with urllib.request.urlopen(url) as response, open(output_file, 'wb') as out_file:
|
11 |
+
shutil.copyfileobj(response, out_file)
|
12 |
+
|
13 |
+
def get_stl(input_sku):
|
14 |
+
preds=shop_the_look(f'./data/dress_{input_sku}.jpg')
|
15 |
+
ret_bag = preds['./segs/bag.jpg'][1]
|
16 |
+
ret_shoes = preds['./segs/shoe.jpg'][1]
|
17 |
+
return Image.open(f'./data/dress_{input_sku}.jpg'), Image.open(ret_bag), Image.open(ret_shoes)
|
18 |
+
|
19 |
+
sku = gr.Dropdown(
|
20 |
+
["1", "2", "3", '4', '5'], label="Dress Sku",
|
21 |
+
),
|
22 |
+
|
23 |
+
|
24 |
+
demo = gr.Interface(get_stl, gr.Dropdown(
|
25 |
+
["1", "2", "3", '4', '5'], label="Dress Sku"), ["image", "image", "image"])
|
26 |
+
|
27 |
+
|
28 |
+
demo.launch(root_path=f"/{os.getenv('TOKEN')}")
|
29 |
+
|
30 |
+
from PIL import Image, ImageChops
|
31 |
+
import numpy as np
|
32 |
+
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
|
33 |
+
from PIL import Image
|
34 |
+
import requests
|
35 |
+
import matplotlib.pyplot as plt
|
36 |
+
import torch
|
37 |
+
import torch.nn as nn
|
38 |
+
import os
|
39 |
+
import nmslib
|
40 |
+
from fastai.vision.all import *
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
def get_segment(image, num,ret=False):
|
45 |
+
|
46 |
+
extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
|
47 |
+
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
|
48 |
+
|
49 |
+
inputs = extractor(images=image, return_tensors="pt")
|
50 |
+
|
51 |
+
outputs = model(**inputs)
|
52 |
+
logits = outputs.logits.cpu()
|
53 |
+
|
54 |
+
upsampled_logits = nn.functional.interpolate(
|
55 |
+
logits,
|
56 |
+
size=image.size[::-1],
|
57 |
+
mode="bilinear",
|
58 |
+
align_corners=False,
|
59 |
+
)
|
60 |
+
|
61 |
+
#pred_seg = upsampled_logits.argmax(dim=1)[0]
|
62 |
+
|
63 |
+
pred_seg = upsampled_logits.argmax(dim=1)[0]
|
64 |
+
np_im = np.array(image)
|
65 |
+
pred_seg[pred_seg != num] = 0
|
66 |
+
mask = pred_seg.detach().cpu().numpy()
|
67 |
+
|
68 |
+
# masked region
|
69 |
+
np_im[mask.squeeze()==0] = 0
|
70 |
+
# white bg
|
71 |
+
np_im[np.where((np_im==[0,0,0]).all(axis=2))] = [255,255,255]
|
72 |
+
|
73 |
+
# trim extra whitespace
|
74 |
+
im = Image.fromarray(np.uint8(np_im)).convert('RGB')
|
75 |
+
im = trim(im)
|
76 |
+
|
77 |
+
if ret==False:
|
78 |
+
plt.imshow(im)
|
79 |
+
plt.show()
|
80 |
+
elif ret==True:
|
81 |
+
print('here and returning', im)
|
82 |
+
return im
|
83 |
+
|
84 |
+
|
85 |
+
def trim(im):
|
86 |
+
bg = Image.new(im.mode, im.size, im.getpixel((0,0)))
|
87 |
+
diff = ImageChops.difference(im, bg)
|
88 |
+
diff = ImageChops.add(diff, diff, 2.0, -100)
|
89 |
+
bbox = diff.getbbox()
|
90 |
+
if bbox:
|
91 |
+
return im.crop(bbox)
|
92 |
+
|
93 |
+
|
94 |
+
def get_pred_seg(image_url):
|
95 |
+
extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
|
96 |
+
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
|
97 |
+
|
98 |
+
|
99 |
+
image = Image.open(image_url)
|
100 |
+
inputs = extractor(images=image, return_tensors="pt")
|
101 |
+
|
102 |
+
outputs = model(**inputs)
|
103 |
+
logits = outputs.logits.cpu()
|
104 |
+
|
105 |
+
upsampled_logits = nn.functional.interpolate(
|
106 |
+
logits,
|
107 |
+
size=image.size[::-1],
|
108 |
+
mode="bilinear",
|
109 |
+
align_corners=False,
|
110 |
+
)
|
111 |
+
|
112 |
+
pred_seg = upsampled_logits.argmax(dim=1)[0]
|
113 |
+
#plt.imshow(pred_seg)
|
114 |
+
return upsampled_logits,pred_seg
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
#### get predictions and neighbours
|
120 |
+
|
121 |
+
|
122 |
+
def get_predictions(feed):
|
123 |
+
pairs = [[x["sku"], x["category"]] for x in feed]
|
124 |
+
skus, labels = zip(*pairs)
|
125 |
+
labels = np.array(labels)
|
126 |
+
skus = np.array(skus)
|
127 |
+
categories = list(set(labels))
|
128 |
+
|
129 |
+
def get_image_fpath(x):
|
130 |
+
return x[0]
|
131 |
+
|
132 |
+
data = DataBlock(
|
133 |
+
blocks=(ImageBlock, CategoryBlock),
|
134 |
+
get_x = get_image_fpath,
|
135 |
+
get_y = ItemGetter(1),
|
136 |
+
item_tfms=[Resize(256)],
|
137 |
+
batch_tfms=[Normalize.from_stats(*imagenet_stats)],
|
138 |
+
splitter=IndexSplitter([])
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
dls = data.dataloaders(
|
143 |
+
pairs,
|
144 |
+
device=default_device(),
|
145 |
+
shuffle_fn=lambda x:x,
|
146 |
+
drop_last=False
|
147 |
+
)
|
148 |
+
|
149 |
+
#model = torch.jit.load("../inference_script/scripted.model").cpu()
|
150 |
+
#model = torch.jit.load("scripted.model").cuda()
|
151 |
+
with open('./scripted.model', 'rb') as f:
|
152 |
+
buffer = io.BytesIO(f.read())
|
153 |
+
|
154 |
+
# Load all tensors to the original device
|
155 |
+
model = torch.jit.load(buffer, map_location=torch.device('cpu'))
|
156 |
+
|
157 |
+
preds_list = []
|
158 |
+
with torch.no_grad():
|
159 |
+
for x,y in progress_bar(iter(dls.train), total=len(dls.train)):
|
160 |
+
pred = model(x)
|
161 |
+
preds_list.append(pred)
|
162 |
+
preds = torch.cat(preds_list)
|
163 |
+
preds = to_np(preds)
|
164 |
+
|
165 |
+
predictions_json = {}
|
166 |
+
for cat in categories:
|
167 |
+
filtered_preds = preds[labels == cat]
|
168 |
+
filtered_skus = skus[labels==cat]
|
169 |
+
neighbours,dists = get_neighbours(filtered_preds)
|
170 |
+
#neighbours = neighbours[:,1:]
|
171 |
+
|
172 |
+
for i, sku in enumerate(filtered_skus):
|
173 |
+
predictions_json[sku] = [filtered_skus[j] for j in neighbours[i]]
|
174 |
+
|
175 |
+
return predictions_json
|
176 |
+
|
177 |
+
INDEX_TIME_PARAMS = {'M': 100, 'indexThreadQty': 8,
|
178 |
+
'efConstruction': 2000, 'post': 0}
|
179 |
+
QUERY_TIME_PARAMS = {"efSearch": 2000}
|
180 |
+
N_NEIGHBOURS = 4
|
181 |
+
def get_neighbours(embeddings):
|
182 |
+
index = nmslib.init(method='hnsw', space='l2')
|
183 |
+
index.addDataPointBatch(embeddings)
|
184 |
+
index.createIndex(INDEX_TIME_PARAMS)
|
185 |
+
index.setQueryTimeParams(QUERY_TIME_PARAMS)
|
186 |
+
res = index.knnQueryBatch(
|
187 |
+
embeddings, k=min(N_NEIGHBOURS+1, embeddings.shape[0]), num_threads=8)
|
188 |
+
proc_res = [l[None] for (l, d) in res]
|
189 |
+
neighbours = np.concatenate(proc_res).astype(np.int32)
|
190 |
+
dists = np.array([d for (_, d) in res]).astype(np.float32)
|
191 |
+
return neighbours , dists
|
192 |
+
|
193 |
+
|
194 |
+
def shop_the_look(prod):
|
195 |
+
|
196 |
+
#upsampled_logits, all_segs = get_pred_seg(prod)
|
197 |
+
bag_segment=get_segment(Image.open(prod), 16, ret=True)
|
198 |
+
bag_segment.save('./segs/bag.jpg')
|
199 |
+
shoe_l = get_segment(Image.open(prod), 9, True)#left shoe
|
200 |
+
shoe_r = get_segment(Image.open(prod), 10, True) #right shoe
|
201 |
+
shoe_segment = concat_h(shoe_l, shoe_r)
|
202 |
+
shoe_segment.save('./segs/shoe.jpg')
|
203 |
+
|
204 |
+
|
205 |
+
feed= []
|
206 |
+
main_prods=os.listdir('./data')
|
207 |
+
for sku in main_prods:
|
208 |
+
if 'checkpoint' not in sku:
|
209 |
+
cat = sku.split('_')[0]
|
210 |
+
x={'sku':f'./data/{sku}', 'category':cat}
|
211 |
+
feed.append(x)
|
212 |
+
|
213 |
+
feed.extend([{'sku':'./segs/shoe.jpg',
|
214 |
+
'category':'shoes'},
|
215 |
+
{'sku':'./segs/bag.jpg',
|
216 |
+
'category':'bag'}])
|
217 |
+
|
218 |
+
preds=get_predictions(feed)
|
219 |
+
return preds
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
def concat_h(image1,image2):
|
224 |
+
|
225 |
+
#resize, first image
|
226 |
+
image1 = image1.resize((426, 240))
|
227 |
+
image1_size = image1.size
|
228 |
+
image2_size = image2.size
|
229 |
+
new_image = Image.new('RGB',(2*image1_size[0], image1_size[1]), (250,250,250))
|
230 |
+
new_image.paste(image1,(0,0))
|
231 |
+
new_image.paste(image2,(image1_size[0],0))
|
232 |
+
return new_image
|
data/bag_1.jpg
ADDED
data/bag_2.jpg
ADDED
data/bag_3.jpg
ADDED
data/bag_4.jpg
ADDED
data/bag_5.jpg
ADDED
data/dress_1.jpg
ADDED
data/dress_2.jpg
ADDED
data/dress_3.jpg
ADDED
data/dress_4.jpg
ADDED
data/dress_5.jpg
ADDED
data/shoes_1.jpg
ADDED
data/shoes_2.jpg
ADDED
data/shoes_3.jpg
ADDED
data/shoes_4.jpg
ADDED
data/shoes_5.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
fastai
|
2 |
+
transformers
|
3 |
+
nmslib
|