# Helper functions to get a segmented mask and serve a "shop the look" demo.
import io
import os
import shutil
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import torch
import torch.nn as nn
from fastai.vision.all import *
from PIL import Image, ImageChops
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation

# Download the TorchScript embedding model used for nearest-neighbour search.
url = "http://static.okkular.io/scripted.model"
output_file = "./scripted.model"
with urllib.request.urlopen(url) as response, open(output_file, 'wb') as out_file:
    shutil.copyfileobj(response, out_file)


def get_segment(image, num, ret=False):
    """Segment one clothing class (`num`) out of `image` and return it on a
    white background, or just plot it when `ret` is False."""
    extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    inputs = extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits.cpu()
    # Upsample the low-resolution logits back to the input image size.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    np_im = np.array(image)
    # Zero out every class except the one we want.
    pred_seg[pred_seg != num] = 0
    mask = pred_seg.detach().cpu().numpy()
    # Black out the masked-away region, then turn the black background white.
    np_im[mask.squeeze() == 0] = 0
    np_im[np.where((np_im == [0, 0, 0]).all(axis=2))] = [255, 255, 255]
    # Trim extra whitespace around the crop.
    im = Image.fromarray(np.uint8(np_im)).convert('RGB')
    im = trim(im)
    if ret:
        return im
    plt.imshow(im)
    plt.show()


def trim(im):
    """Crop away the uniform border around the segmented item."""
    bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, bg)
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    return im  # nothing to trim


def get_pred_seg(image_url):
    """Return the upsampled logits and the full segmentation map for an image."""
    extractor = AutoFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    image = Image.open(image_url)
    inputs = extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits.cpu()
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return upsampled_logits, pred_seg


#### get predictions and neighbours
def get_predictions(feed):
    pairs = [[x["sku"], x["category"]] for x in feed]
    skus, labels = zip(*pairs)
    labels = np.array(labels)
    skus = np.array(skus)
    categories = list(set(labels))

    def get_image_fpath(x):
        return x[0]

    data = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=get_image_fpath,
        get_y=ItemGetter(1),
        item_tfms=[Resize(256)],
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
        splitter=IndexSplitter([]),
    )
    dls = data.dataloaders(
        pairs, device=default_device(), shuffle_fn=lambda x: x, drop_last=False
    )
    # Load all tensors to the original device.
    with open('./scripted.model', 'rb') as f:
        buffer = io.BytesIO(f.read())
    model = torch.jit.load(buffer, map_location=torch.device('cpu'))

    # Embed every image in the feed.
    preds_list = []
    with torch.no_grad():
        for x, y in progress_bar(iter(dls.train), total=len(dls.train)):
            pred = model(x)
            preds_list.append(pred)
    preds = torch.cat(preds_list)
    preds = to_np(preds)

    # Nearest-neighbour lookup within each category; index 0 of each
    # neighbour list is the query image itself.
    predictions_json = {}
    for cat in categories:
        filtered_preds = preds[labels == cat]
        filtered_skus = skus[labels == cat]
        neighbours, dists = get_neighbours(filtered_preds)
        for i, sku in enumerate(filtered_skus):
            predictions_json[sku] = [filtered_skus[j] for j in neighbours[i]]
    return predictions_json


INDEX_TIME_PARAMS = {'M': 100, 'indexThreadQty': 8, 'efConstruction': 2000, 'post': 0}
QUERY_TIME_PARAMS = {"efSearch": 2000}
N_NEIGHBOURS = 4


def get_neighbours(embeddings):
    """Build an HNSW index over the embeddings and query it with itself."""
    index = nmslib.init(method='hnsw', space='l2')
    index.addDataPointBatch(embeddings)
    index.createIndex(INDEX_TIME_PARAMS)
    index.setQueryTimeParams(QUERY_TIME_PARAMS)
    res = index.knnQueryBatch(
        embeddings, k=min(N_NEIGHBOURS + 1, embeddings.shape[0]), num_threads=8
    )
    proc_res = [l[None] for (l, d) in res]
    neighbours = np.concatenate(proc_res).astype(np.int32)
    dists = np.array([d for (_, d) in res]).astype(np.float32)
    return neighbours, dists


def shop_the_look(prod):
    # Segment the bag and shoes out of the product shot (class ids used by
    # the segmentation model: 9 = left shoe, 10 = right shoe, 16 = bag).
    os.makedirs('./segs', exist_ok=True)
    bag_segment = get_segment(Image.open(prod), 16, ret=True)
    bag_segment.save('./segs/bag.jpg')
    shoe_l = get_segment(Image.open(prod), 9, True)   # left shoe
    shoe_r = get_segment(Image.open(prod), 10, True)  # right shoe
    shoe_segment = concat_h(shoe_l, shoe_r)
    shoe_segment.save('./segs/shoe.jpg')

    # Build the feed from the catalogue images plus the two crops.
    feed = []
    for sku in os.listdir('./data'):
        if 'checkpoint' not in sku:
            cat = sku.split('_')[0]
            feed.append({'sku': f'./data/{sku}', 'category': cat})
    feed.extend([{'sku': './segs/shoe.jpg', 'category': 'shoes'},
                 {'sku': './segs/bag.jpg', 'category': 'bag'}])
    return get_predictions(feed)


def concat_h(image1, image2):
    """Paste the two shoe crops side by side on one canvas."""
    image1 = image1.resize((426, 240))
    image1_size = image1.size
    new_image = Image.new('RGB', (2 * image1_size[0], image1_size[1]), (250, 250, 250))
    new_image.paste(image1, (0, 0))
    new_image.paste(image2, (image1_size[0], 0))
    return new_image


def get_stl(input_sku):
    preds = shop_the_look(f'./data/dress_{input_sku}.jpg')
    # Index 0 of each neighbour list is the query crop itself, so index 1
    # is the closest catalogue match.
    ret_bag = preds['./segs/bag.jpg'][1]
    ret_shoes = preds['./segs/shoe.jpg'][1]
    return (Image.open(f'./data/dress_{input_sku}.jpg'),
            Image.open(ret_bag),
            Image.open(ret_shoes))


sku = gr.Dropdown(["1", "2", "3", "4", "5"], label="Dress Sku")
demo = gr.Interface(get_stl, sku, ["image", "image", "image"])
demo.launch(root_path=f"/{os.getenv('TOKEN')}")
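
# A minimal usage sketch for running the pipeline without the Gradio UI. It
# assumes the same layout as above (catalogue images named
# '<category>_<sku>.jpg' under ./data and the downloaded ./scripted.model);
# the helper name and the example file 'dress_1.jpg' are illustrative, not
# part of the original app.
def demo_shop_the_look(prod='./data/dress_1.jpg'):
    preds = shop_the_look(prod)
    # Each per-category neighbour list starts with the query crop itself,
    # so index 1 is the closest catalogue match.
    print('closest bag:  ', preds['./segs/bag.jpg'][1])
    print('closest shoes:', preds['./segs/shoe.jpg'][1])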