aparnak1 committed on
Commit
718b08a
1 Parent(s): 2faaa97
app.py ADDED
@@ -0,0 +1,232 @@
+ # app.py — Gradio "shop the look" demo: segment a dress photo into bag and
+ # shoe crops with SegFormer, then retrieve the closest catalogue items.
+
+ import io
+ import os
+ import shutil
+ import urllib.request
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image, ImageChops
+ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
+ from fastai.vision.all import *
+
+ # Fetch the TorchScript embedding model once at startup.
+ url = "http://static.okkular.io/scripted.model"
+ output_file = "./scripted.model"
+ with urllib.request.urlopen(url) as response, open(output_file, 'wb') as out_file:
+     shutil.copyfileobj(response, out_file)
+
+ def get_stl(input_sku):
+     # Run shop-the-look on the chosen dress; preds maps each query segment to
+     # its nearest catalogue skus, with index 0 being the segment itself.
+     preds = shop_the_look(f'./data/dress_{input_sku}.jpg')
+     ret_bag = preds['./segs/bag.jpg'][1]
+     ret_shoes = preds['./segs/shoe.jpg'][1]
+     return Image.open(f'./data/dress_{input_sku}.jpg'), Image.open(ret_bag), Image.open(ret_shoes)
+
+ # The Gradio interface is wired up at the bottom of the file, after every
+ # helper it calls has been defined, since launch() blocks the main thread.
+
+ def get_segment(image, num, ret=False):
+     # Segment `image` with the clothes-parsing SegFormer and keep only the
+     # pixels of class `num`, composited onto a white background.
+     extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+     model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
+
+     inputs = extractor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+     logits = outputs.logits.cpu()
+
+     # Upsample the low-resolution logits back to the input image size.
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     pred_seg = upsampled_logits.argmax(dim=1)[0]
+     np_im = np.array(image)
+     pred_seg[pred_seg != num] = 0
+     mask = pred_seg.detach().cpu().numpy()
+
+     # Black out everything outside the mask, then turn the black fill white.
+     np_im[mask.squeeze() == 0] = 0
+     np_im[np.where((np_im == [0, 0, 0]).all(axis=2))] = [255, 255, 255]
+
+     # Trim the surrounding whitespace.
+     im = Image.fromarray(np.uint8(np_im)).convert('RGB')
+     im = trim(im)
+
+     if ret:
+         return im
+     plt.imshow(im)
+     plt.show()
+
+ def trim(im):
+     # Crop away the uniform border, sampling the background colour from the
+     # top-left pixel.
+     bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
+     diff = ImageChops.difference(im, bg)
+     diff = ImageChops.add(diff, diff, 2.0, -100)
+     bbox = diff.getbbox()
+     if bbox:
+         return im.crop(bbox)
+     return im  # nothing to trim; avoid returning None
+
+ def get_pred_seg(image_url):
+     # Full-image segmentation: returns both the upsampled logits and the
+     # per-pixel argmax map.
+     extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+     model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
+
+     image = Image.open(image_url)
+     inputs = extractor(images=image, return_tensors="pt")
+     outputs = model(**inputs)
+     logits = outputs.logits.cpu()
+
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     pred_seg = upsampled_logits.argmax(dim=1)[0]
+     return upsampled_logits, pred_seg
+
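+ # get_pred_seg is kept as an inspection helper; the app's flow calls
+ # get_segment directly with a specific class index.
+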
+ #### get predictions and neighbours
+
+ def get_predictions(feed):
+     # feed is a list of {'sku': image_path, 'category': label} dicts.
+     pairs = [[x["sku"], x["category"]] for x in feed]
+     skus, labels = zip(*pairs)
+     labels = np.array(labels)
+     skus = np.array(skus)
+     categories = list(set(labels))
+
+     def get_image_fpath(x):
+         return x[0]
+
+     # An inference-only DataBlock: no validation split, no shuffling.
+     data = DataBlock(
+         blocks=(ImageBlock, CategoryBlock),
+         get_x=get_image_fpath,
+         get_y=ItemGetter(1),
+         item_tfms=[Resize(256)],
+         batch_tfms=[Normalize.from_stats(*imagenet_stats)],
+         splitter=IndexSplitter([])
+     )
+
+     dls = data.dataloaders(
+         pairs,
+         device=default_device(),
+         shuffle_fn=lambda x: x,
+         drop_last=False
+     )
+
+     # Load the scripted embedding model on CPU.
+     with open('./scripted.model', 'rb') as f:
+         buffer = io.BytesIO(f.read())
+     model = torch.jit.load(buffer, map_location=torch.device('cpu'))
+
+     preds_list = []
+     with torch.no_grad():
+         for x, y in progress_bar(iter(dls.train), total=len(dls.train)):
+             pred = model(x)
+             preds_list.append(pred)
+     preds = to_np(torch.cat(preds_list))
+
+     # Nearest neighbours are computed per category, so a bag is only ever
+     # matched against other bags.
+     predictions_json = {}
+     for cat in categories:
+         filtered_preds = preds[labels == cat]
+         filtered_skus = skus[labels == cat]
+         neighbours, dists = get_neighbours(filtered_preds)
+         for i, sku in enumerate(filtered_skus):
+             predictions_json[sku] = [filtered_skus[j] for j in neighbours[i]]
+
+     return predictions_json
+
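+ # Because each query embedding is itself in the index, neighbours[i][0] is the
+ # query sku; callers read predictions_json[sku][1] for the first real match
+ # (see get_stl above).
+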
+ # HNSW index settings: M and efConstruction trade build time for recall,
+ # efSearch trades query time for recall.
+ INDEX_TIME_PARAMS = {'M': 100, 'indexThreadQty': 8,
+                      'efConstruction': 2000, 'post': 0}
+ QUERY_TIME_PARAMS = {"efSearch": 2000}
+ N_NEIGHBOURS = 4
+
+ def get_neighbours(embeddings):
+     # Approximate nearest neighbours over L2 distance; k is N_NEIGHBOURS + 1
+     # because the closest hit is always the query point itself.
+     index = nmslib.init(method='hnsw', space='l2')
+     index.addDataPointBatch(embeddings)
+     index.createIndex(INDEX_TIME_PARAMS)
+     index.setQueryTimeParams(QUERY_TIME_PARAMS)
+     res = index.knnQueryBatch(
+         embeddings, k=min(N_NEIGHBOURS + 1, embeddings.shape[0]), num_threads=8)
+     proc_res = [l[None] for (l, d) in res]
+     neighbours = np.concatenate(proc_res).astype(np.int32)
+     dists = np.array([d for (_, d) in res]).astype(np.float32)
+     return neighbours, dists
+
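+ # Class indices below come from the checkpoint's label map (printable via
+ # model.config.id2label): 9 = left-shoe, 10 = right-shoe, 16 = bag.
+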
+ def shop_the_look(prod):
+     # Cut the bag and shoes out of the dress photo and save them as query
+     # images; make sure the output directory exists first.
+     os.makedirs('./segs', exist_ok=True)
+     bag_segment = get_segment(Image.open(prod), 16, ret=True)
+     bag_segment.save('./segs/bag.jpg')
+     shoe_l = get_segment(Image.open(prod), 9, True)   # left shoe
+     shoe_r = get_segment(Image.open(prod), 10, True)  # right shoe
+     shoe_segment = concat_h(shoe_l, shoe_r)
+     shoe_segment.save('./segs/shoe.jpg')
+
+     # Build the retrieval feed from the catalogue images, skipping any
+     # notebook checkpoint files; the filename prefix is the category.
+     feed = []
+     main_prods = os.listdir('./data')
+     for sku in main_prods:
+         if 'checkpoint' not in sku:
+             cat = sku.split('_')[0]
+             feed.append({'sku': f'./data/{sku}', 'category': cat})
+
+     # Add the query segments so they get embedded alongside the catalogue.
+     feed.extend([{'sku': './segs/shoe.jpg', 'category': 'shoes'},
+                  {'sku': './segs/bag.jpg', 'category': 'bag'}])
+
+     return get_predictions(feed)
+
+ def concat_h(image1, image2):
+     # Paste the two shoe crops side by side on a light background. Both crops
+     # are resized to the same dimensions first so the heights line up.
+     image1 = image1.resize((426, 240))
+     image2 = image2.resize((426, 240))
+     image1_size = image1.size
+     new_image = Image.new('RGB', (2 * image1_size[0], image1_size[1]), (250, 250, 250))
+     new_image.paste(image1, (0, 0))
+     new_image.paste(image2, (image1_size[0], 0))
+     return new_image
+
+ # UI wiring, moved below the function definitions so that the blocking
+ # launch() call runs last.
+ sku = gr.Dropdown(["1", "2", "3", "4", "5"], label="Dress Sku")
+ demo = gr.Interface(get_stl, sku, ["image", "image", "image"])
+
+ # root_path serves the app under a token-prefixed path (HF Space setup).
+ demo.launch(root_path=f"/{os.getenv('TOKEN')}")
data/bag_1.jpg ADDED
data/bag_2.jpg ADDED
data/bag_3.jpg ADDED
data/bag_4.jpg ADDED
data/bag_5.jpg ADDED
data/dress_1.jpg ADDED
data/dress_2.jpg ADDED
data/dress_3.jpg ADDED
data/dress_4.jpg ADDED
data/dress_5.jpg ADDED
data/shoes_1.jpg ADDED
data/shoes_2.jpg ADDED
data/shoes_3.jpg ADDED
data/shoes_4.jpg ADDED
data/shoes_5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ fastai
+ transformers
+ nmslib