# stl / app.py
# aparnak1's picture
# initial
# 718b08a
# raw
# history blame
# 6.68 kB
# helper functions to get segmented mask
import gradio as gr
import os
from PIL import Image
# Fetch the TorchScript model at start-up and store it next to the app;
# get_predictions() later loads it from ./scripted.model.
# These two modules are imported here because nothing earlier in the file
# brings them into scope.
import shutil
import urllib.request

# NOTE(review): the model is fetched over plain HTTP with no integrity check
# and is later handed to torch.jit.load — confirm the host is trusted.
url = "http://static.okkular.io/scripted.model"
output_file = "./scripted.model"
with urllib.request.urlopen(url) as response, open(output_file, 'wb') as out_file:
    # Stream the body in chunks so the whole model never sits in memory.
    shutil.copyfileobj(response, out_file)
def get_stl(input_sku):
    """Return the dress image and its suggested bag and shoe images.

    Args:
        input_sku: SKU suffix used to build ``./data/dress_<sku>.jpg``.

    Returns:
        A 3-tuple of PIL images: the dress, the bag match and the shoe match.
    """
    dress_path = f'./data/dress_{input_sku}.jpg'
    preds = shop_the_look(dress_path)
    # Entry 1 of each neighbour list is taken; entry 0 is presumably the
    # query segment itself, since the index is queried with its own data.
    bag_path = preds['./segs/bag.jpg'][1]
    shoe_path = preds['./segs/shoe.jpg'][1]
    return Image.open(dress_path), Image.open(bag_path), Image.open(shoe_path)
# Dropdown of dress SKUs; the selected value is passed to get_stl().
# (No trailing comma after the call: that would turn `sku` into a 1-tuple.)
sku = gr.Dropdown(
    ["1", "2", "3", "4", "5"], label="Dress Sku",
)
# One input (the SKU) and three image outputs: dress, bag match, shoe match.
demo = gr.Interface(get_stl, sku, ["image", "image", "image"])
# NOTE(review): if the TOKEN environment variable is unset, root_path becomes
# the literal "/None" — confirm the deployment always sets it.
# NOTE(review): when run as a plain script launch() is expected to block, so
# helpers defined after this line (e.g. shop_the_look) would not exist yet
# when get_stl runs — confirm the intended statement order of this file.
demo.launch(root_path=f"/{os.getenv('TOKEN')}")
from PIL import Image, ImageChops
import numpy as np
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import os
import nmslib
from fastai.vision.all import *
import functools


@functools.lru_cache(maxsize=1)
def _load_segformer():
    """Load the clothes-segmentation extractor and model once and reuse them."""
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    return extractor, model


def get_segment(image, num, ret=False):
    """Cut one clothing class out of *image* onto a white background.

    Args:
        image: PIL image to segment.
        num: Class id produced by the segmentation model that should be kept
            (callers in this file pass 16 for the bag, 9 and 10 for the
            left and right shoe).
        ret: When true, return the cropped segment; otherwise display it
            with matplotlib and return ``None``.

    Returns:
        The trimmed PIL image of the segment when *ret* is true, else ``None``.
        If the class is absent the untrimmed all-white image is returned.
    """
    extractor, model = _load_segformer()

    # Work on an RGB copy so the (H, W, 3) pixel indexing below holds for
    # every input mode (RGBA, L, P, ...).
    rgb = image.convert('RGB')
    inputs = extractor(images=rgb, return_tensors="pt")
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # The model predicts at reduced resolution; scale back to the image size.
    # PIL size is (width, height) while interpolate expects (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=rgb.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    keep = (pred_seg == num).numpy()

    # Paint everything outside the requested class white. Using the mask
    # directly (instead of blacking out and then recolouring every black
    # pixel) leaves genuinely black pixels inside the segment untouched.
    np_im = np.array(rgb)
    np_im[~keep] = 255

    im = Image.fromarray(np_im)
    # trim extra whitespace; keep the full image if trim finds nothing to crop
    trimmed = trim(im)
    if trimmed is not None:
        im = trimmed

    if ret:
        print('here and returning', im)
        return im
    plt.imshow(im)
    plt.show()
def trim(im):
    """Crop the uniform border off *im*.

    The colour of the top-left pixel is treated as the background colour.

    Args:
        im: PIL image to crop.

    Returns:
        The image cropped to the bounding box of everything that differs
        noticeably from the background, or *im* unchanged when nothing does.
    """
    bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, bg)
    # add(a, b, scale, offset) computes (a + b) / scale + offset, i.e.
    # diff - 100 here: differences of 100 or less clip to zero and are
    # ignored as noise.
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    # Nothing stands out from the background: return the image as-is rather
    # than an implicit None that callers would trip over.
    return im
def get_pred_seg(image_url):
    """Run clothes segmentation on an image file.

    Args:
        image_url: Location of the image, passed straight to ``Image.open``.

    Returns:
        A tuple ``(upsampled_logits, pred_seg)``: the logits resized to the
        image's (height, width) and the per-pixel argmax class map.
    """
    extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

    # Close the file handle as soon as the pixels have been preprocessed.
    with Image.open(image_url) as image:
        inputs = extractor(images=image, return_tensors="pt")
        # PIL size is (width, height); interpolate expects (height, width).
        target_size = image.size[::-1]

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return upsampled_logits, pred_seg
#### get predictions and neighbours
def get_predictions(feed):
    """Embed every image in *feed* and list its nearest neighbours per category.

    Args:
        feed: Iterable of dicts with a ``"sku"`` entry (used as the image
            file path) and a ``"category"`` entry.

    Returns:
        A dict mapping each sku to the skus of its nearest neighbours, where
        neighbours are searched only among items of the same category.
        An empty *feed* yields an empty dict.
    """
    if not feed:
        return {}

    pairs = [[x["sku"], x["category"]] for x in feed]
    skus, labels = zip(*pairs)
    labels = np.array(labels)
    skus = np.array(skus)
    categories = list(set(labels))

    def get_image_fpath(x):
        return x[0]

    # The scripted model is loaded onto the CPU below, so the batches must be
    # produced on the CPU as well or inference fails on CUDA machines.
    device = torch.device('cpu')

    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=get_image_fpath,
        get_y=ItemGetter(1),
        item_tfms=[Resize(256)],
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
        # Empty validation split: every item ends up in dls.train.
        splitter=IndexSplitter([]),
    )
    dls = data.dataloaders(
        pairs,
        device=device,
        # Identity shuffle and no dropped batch keep the predictions in the
        # same order as `pairs`, which the boolean indexing below relies on.
        shuffle_fn=lambda x: x,
        drop_last=False,
    )

    # torch.jit.load accepts a file name directly; no in-memory copy needed.
    model = torch.jit.load('./scripted.model', map_location=device)

    preds_list = []
    with torch.no_grad():
        for x, _ in progress_bar(iter(dls.train), total=len(dls.train)):
            preds_list.append(model(x))
    preds = to_np(torch.cat(preds_list))

    predictions_json = {}
    for cat in categories:
        in_category = labels == cat
        filtered_preds = preds[in_category]
        filtered_skus = skus[in_category]
        neighbours, _ = get_neighbours(filtered_preds)
        # Neighbour ids index into the per-category arrays, not the full feed.
        for i, sku in enumerate(filtered_skus):
            predictions_json[sku] = [filtered_skus[j] for j in neighbours[i]]
    return predictions_json
# Settings passed to nmslib when the HNSW index is built.
INDEX_TIME_PARAMS = {
    'M': 100,
    'indexThreadQty': 8,
    'efConstruction': 2000,
    'post': 0,
}
# Settings passed to nmslib before querying.
QUERY_TIME_PARAMS = {"efSearch": 2000}
# Neighbours wanted per item; one extra is requested in get_neighbours.
N_NEIGHBOURS = 4


def get_neighbours(embeddings):
    """Find the nearest neighbours of every row of *embeddings* among themselves.

    An HNSW index with L2 distance is built over the rows and then queried
    with those same rows, asking for ``N_NEIGHBOURS + 1`` hits (capped at the
    number of rows).

    Returns:
        ``(neighbours, dists)``: int32 neighbour row ids and float32
        distances, one row per embedding.
    """
    hnsw = nmslib.init(method='hnsw', space='l2')
    hnsw.addDataPointBatch(embeddings)
    hnsw.createIndex(INDEX_TIME_PARAMS)
    hnsw.setQueryTimeParams(QUERY_TIME_PARAMS)

    n_hits = min(N_NEIGHBOURS + 1, embeddings.shape[0])
    results = hnsw.knnQueryBatch(embeddings, k=n_hits, num_threads=8)

    # Each result is an (ids, distances) pair; stack them row by row.
    id_rows = [ids for ids, _ in results]
    dist_rows = [distances for _, distances in results]
    neighbours = np.vstack(id_rows).astype(np.int32)
    dists = np.array(dist_rows).astype(np.float32)
    return neighbours, dists
def shop_the_look(prod):
    """Segment the bag and shoes out of a product image and find look-alikes.

    The bag and shoe crops are written to ``./segs/bag.jpg`` and
    ``./segs/shoe.jpg`` and then ranked, together with every image found in
    ``./data``, by get_predictions().

    Args:
        prod: Path of the product image to analyse.

    Returns:
        The dict produced by get_predictions(): each image path mapped to
        the paths of its nearest neighbours within the same category.
    """
    # Make sure the output folder exists before saving into it.
    os.makedirs('./segs', exist_ok=True)

    # Open the source once and share it; the handle is closed afterwards.
    with Image.open(prod) as source:
        source.load()
        bag_segment = get_segment(source, 16, ret=True)
        shoe_l = get_segment(source, 9, True)  # left shoe
        shoe_r = get_segment(source, 10, True)  # right shoe

    bag_segment.save('./segs/bag.jpg')
    shoe_segment = concat_h(shoe_l, shoe_r)
    shoe_segment.save('./segs/shoe.jpg')

    # The category is the file-name prefix before the first underscore;
    # entries containing 'checkpoint' are skipped.
    feed = [
        {'sku': f'./data/{name}', 'category': name.split('_')[0]}
        for name in os.listdir('./data')
        if 'checkpoint' not in name
    ]
    feed.extend([{'sku': './segs/shoe.jpg',
                  'category': 'shoes'},
                 {'sku': './segs/bag.jpg',
                  'category': 'bag'}])
    return get_predictions(feed)
def concat_h(image1, image2, size=(426, 240)):
    """Place two images side by side on one canvas.

    Both images are resized to *size* so each fills exactly one half of the
    canvas; leaving the second image at its own size would clip it or leave
    a gap, because the canvas is exactly two tiles wide.

    Args:
        image1: PIL image for the left half.
        image2: PIL image for the right half.
        size: ``(width, height)`` of each half; defaults to ``(426, 240)``.

    Returns:
        A new RGB image of size ``(2 * width, height)``.
    """
    width, height = size
    left = image1.resize((width, height))
    right = image2.resize((width, height))
    new_image = Image.new('RGB', (2 * width, height), (250, 250, 250))
    new_image.paste(left, (0, 0))
    new_image.paste(right, (width, 0))
    return new_image