|
import io |
|
import ast |
|
import json |
|
import base64 |
|
import spaces |
|
import requests |
|
import numpy as np |
|
import gradio as gr |
|
from PIL import Image |
|
from io import BytesIO |
|
import face_recognition |
|
from turtle import title |
|
from openai import OpenAI |
|
from collections import Counter |
|
from transformers import pipeline |
|
|
|
import urllib.request |
|
from transformers import YolosImageProcessor, YolosForObjectDetection |
|
import torch |
|
import matplotlib.pyplot as plt |
|
from torchvision.transforms import ToTensor, ToPILImage |
|
|
|
|
|
# Shared OpenAI client used by get_openAI_tags; no credentials are passed explicitly here.
client = OpenAI()

# Zero-shot image classifier shared by get_colour and get_predicted_attributes.
pipe = pipeline("zero-shot-image-classification", model="patrickjohncyh/fashion-clip")

# Configuration files are resolved relative to the current working directory.
color_file_path = 'color_config.json'
attributes_file_path = 'attributes_config.json'
import os
# NOTE(review): read from the environment but not referenced by the visible code.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Decode the JSON configuration explicitly as UTF-8 instead of relying on the
# platform's default encoding, which differs between systems.
with open(color_file_path, 'r', encoding='utf-8') as file:
    color_data = json.load(file)

with open(attributes_file_path, 'r', encoding='utf-8') as file:
    attributes_data = json.load(file)

# Main colour -> list of sub-colours (see get_colour).
COLOURS_DICT = color_data['color_mapping']
# Category -> attribute -> list of candidate values (see get_predicted_attributes).
ATTRIBUTES_DICT = attributes_data['attribute_mapping']
|
|
|
|
|
def _parse_image_urls(raw):
    """Convert the raw ``input`` of :func:`shot` into a collection of image URLs.

    Accepted forms:

    * a Python list/tuple literal (or an actual list, which is stringified
      first), e.g. ``"['https://a.jpg', 'https://b.jpg']"``;
    * a single quoted URL, which is wrapped in a one-element list;
    * plain comma separated text, e.g. ``"https://a.jpg, https://b.jpg"``.
      URLs that themselves contain commas are not supported in this form.
    """
    text = str(raw)
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal: fall back to the comma separated form.
        return [part.strip() for part in text.split(',') if part.strip()]
    if isinstance(parsed, str):
        return [parsed]
    return parsed


def shot(input, category, level):
    """Run the tagging flow for one product and return the result as a JSON string.

    :param input: Image URLs; see ``_parse_image_urls`` for the accepted forms.
    :param category: Product category forwarded to the colour, attribute and
        cropping helpers.
    :param level: ``'variant'`` fills ``colors``, ``image_mapping``,
        ``face_embeddings`` and ``cropped_images``; ``'product'`` fills
        ``attributes`` and ``subcategory``. Any other value yields ``"{}"``.
    :return: ``json.dumps`` of the collected result dictionary.
    """
    output_dict = {}
    if level not in ('variant', 'product'):
        return json.dumps(output_dict)

    # Parse once and share the result instead of re-evaluating the same
    # literal for every helper call.
    image_urls = _parse_image_urls(input)

    if level == 'variant':
        subColour, mainColour, score = get_colour(image_urls, category)
        openai_parsed_response = get_openAI_tags(image_urls)
        face_embeddings = get_face_embeddings(image_urls)
        cropped_images = get_cropped_images(image_urls, category)

        output_dict['colors'] = {
            "main": mainColour,
            "sub": subColour,
            "score": score
        }
        output_dict['image_mapping'] = openai_parsed_response
        output_dict['face_embeddings'] = face_embeddings
        output_dict['cropped_images'] = cropped_images
    else:
        common_result = get_predicted_attributes(image_urls, category)
        output_dict['attributes'] = common_result
        output_dict['subcategory'] = category

    return json.dumps(output_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def get_colour(image_urls, category):
    """Predict the main colour and sub-colour of a garment.

    The classifier is queried twice: first with every main colour of
    ``COLOURS_DICT`` as candidate, then with the sub-colours listed under the
    winning main colour. Only the top prediction for the first image
    (``responses[0][0]``) is used in both rounds.

    :param image_urls: Images passed to the zero-shot classification pipeline.
    :param category: Product category appended to every candidate label.
    :return: ``(subColour, mainColour, score)`` where ``score`` belongs to the
        sub-colour prediction, or ``(None, None, None)`` when the predicted
        main colour is not a key of ``COLOURS_DICT``.
    """
    suffix = " clothing: " + category

    colourLabels = [colour + suffix for colour in COLOURS_DICT]

    print("Colour Labels:", colourLabels)
    print("Image URLs:", image_urls)

    responses = pipe(image_urls, candidate_labels=colourLabels)
    mainColour = responses[0][0]['label'].split(" clothing:")[0]

    if mainColour not in COLOURS_DICT:
        return None, None, None

    # Build a new list: decorating COLOURS_DICT[mainColour] in place would
    # permanently alter the shared configuration, so every later call would
    # stack another " clothing: ..." suffix onto the same entries.
    labels = [shade + suffix for shade in COLOURS_DICT[mainColour]]

    print("Labels for pipe:", labels)
    responses = pipe(image_urls, candidate_labels=labels)
    subColour = responses[0][0]['label'].split(" clothing:")[0]

    return subColour, mainColour, responses[0][0]['score']
|
|
|
|
|
|
|
@spaces.GPU
def get_predicted_attributes(image_urls, category):
    """Predict one value per configured attribute of ``category``.

    For every attribute listed under ``ATTRIBUTES_DICT[category]`` the
    zero-shot classifier scores the candidate values on each image; the top
    label of each image counts as one vote and the most voted value wins.
    For the ``details`` attribute the second-ranked label of each image also
    votes and the two most voted values are reported.

    :param image_urls: Images passed to the zero-shot classification pipeline.
    :param category: Key into ``ATTRIBUTES_DICT``; an unknown category yields
        an empty result.
    :return: dict mapping the formatted attribute name to its predicted value.
        ``details`` is reported as ``details_1`` and, when a second value
        exists, ``details_2``.
    """
    result = {}

    for attribute, values in ATTRIBUTES_DICT.get(category, {}).items():
        if not values:
            continue

        # Turn the configuration keys into wording the classifier can match.
        attribute_formatted = (
            attribute.replace("colartype", "collar")
            .replace("sleevelength", "sleeve length")
            .replace("fabricstyle", "fabric")
        )
        values_formatted = [
            f"{attribute_formatted}: {value}, clothing: {category}" for value in values
        ]

        responses = pipe(image_urls, candidate_labels=values_formatted)
        votes = [response[0]['label'].split(", clothing:")[0] for response in responses]

        is_details = attribute_formatted == "details"
        if is_details:
            # A second-ranked label only exists when there were at least two candidates.
            votes += [
                response[1]['label'].split(", clothing:")[0]
                for response in responses
                if len(response) > 1
            ]

        # Every vote looks like "<attribute>: <value>"; keep only the value.
        winners = [
            label.partition(': ')[2]
            for label, _count in Counter(votes).most_common(2 if is_details else 1)
        ]
        if not winners:
            # No image produced a prediction; nothing to report for this attribute.
            continue

        if is_details:
            result["details_1"] = winners[0]
            if len(winners) > 1:
                result["details_2"] = winners[1]
        else:
            result[attribute_formatted] = winners[0]

    return result
|
|
|
def get_openAI_tags(image_urls):
    """Ask the chat model which camera angle each product image shows.

    :param image_urls: Iterable of image URLs; each one is sent as an
        ``image_url`` content part of the user message.
    :return: The model reply decoded with ``json.loads``.
    :raises json.JSONDecodeError: if the reply is not valid JSON.
    """
    system_prompt = "You're a tagging assistant, you will help label and tag product pictures for my online e-commerce platform. Your tasks will be to return which angle the product images were taken from. You will have to choose from 'full-body', 'half-body', 'side', 'back', or 'zoomed' angles. You should label each of the images with one of these labels depending on which you think fits best (ideally, every label should be used at least once, but only if there are 5 or more images), and should respond with an unformatted dictionary where the key is a string representation of the url index of the url and the value is the assigned label."

    user_parts = [
        {"type": "image_url", "image_url": {"url": url}} for url in image_urls
    ]

    conversation = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_parts},
    ]

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=conversation,
        temperature=1,
        max_tokens=500,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    reply_text = completion.choices[0].message.content
    return json.loads(reply_text)
|
|
|
|
|
@spaces.GPU
def get_face_embeddings(image_urls, timeout=30):
    """Download each image and compute the encoding of the first detected face.

    Failures are handled per image so one bad URL does not abort the batch.

    :param image_urls: Iterable of image URLs.
    :param timeout: Seconds passed to ``requests.get`` so an unresponsive host
        cannot block the request indefinitely.
    :return: dict keyed by the stringified position of each URL. The value is
        the first face encoding as a list, ``[]`` when no face was found, or
        an ``"Error processing image: ..."`` string when anything failed.
    """
    results = {}

    for index, url in enumerate(image_urls):
        key = str(index)
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()

            image = face_recognition.load_image_file(BytesIO(response.content))
            face_encodings = face_recognition.face_encodings(image)

            # Only the first detected face is reported.
            results[key] = face_encodings[0].tolist() if face_encodings else []
        except Exception as e:
            # Deliberate best effort: record the failure and keep going.
            results[key] = f"Error processing image: {str(e)}"

    return results
|
|
|
|
|
# NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
ACCURACY_THRESHOLD = 0.86
|
|
|
def open_image_from_url(url, timeout=30):
    """Download ``url`` and open the payload as a PIL image.

    The whole response body is read into memory before it is handed to
    Pillow, so the returned image does not depend on the HTTP connection.

    :param url: Address of the image.
    :param timeout: Seconds passed to ``requests.get`` so an unresponsive host
        cannot block the caller indefinitely.
    :return: The image returned by ``Image.open``.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
|
|
|
|
|
# NOTE(review): looks like a spreadsheet header row; not referenced by the visible code.
main = [['Product Id', 'Sku', 'Color', 'Images', 'Status', 'Category', 'Text']]

# Class names of the detector; idx_to_text() indexes into this list with a class index.
cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']

# Class names that visualize_predictions() discards from the detections.
# NOTE(review): this name shadows the built-in ``filter`` for the whole module.
filter = ['dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']

# Detector class name -> group index, looked up by get_yolo_index().
# NOTE(review): the values 0-4 presumably index into label_mapping below — confirm.
yolo_mapping = {
    'shirt, blouse': 3,
    'top, t-shirt, sweatshirt' : 1,
    'sweater': 1,
    'cardigan': 1,
    'jacket': 3,
    'vest': 1,
    'pants': 2,
    'shorts': 2,
    'skirt': 2,
    'coat': 3,
    'dress': 0,
    'jumpsuit': 0,
    'bag, wallet': 4
}

# Groups of product categories; get_category_index() returns the position of
# the group that contains a given category.
# NOTE(review): 'women-outwear-coatjacket' appears twice in the fourth group.
label_mapping = [
    ['women-dress-mini', 'women-dress-dress', 'women-dress-maxi', 'women-dress-midi', 'women-playsuitsjumpsuits-playsuit', 'women-playsuitsjumpsuits-jumpsuit', 'women-coords-coords', 'women-swimwear-onepieces', 'women-swimwear-bikinisets'],
    ['women-sweatersknits-cardigan', 'women-top-waistcoat', 'women-top-blouse', 'women-sweatersknits-blouse', 'women-sweatersknits-sweater', 'women-top-top', 'women-loungewear-hoodie', 'women-top-camistanks', 'women-top-tshirt', 'women-top-croptop', 'women-loungewear-sweatshirt', 'women-top-body'],
    ['women-loungewear-joggers', 'women-bottom-trousers', 'women-bottom-leggings', 'women-bottom-jeans', 'women-bottom-shorts', 'women-bottom-skirt', 'women-loungewear-activewear', 'women-bottom-joggers'],
    ['women-top-shirt', 'women-outwear-coatjacket', 'women-outwear-blazer', 'women-outwear-coatjacket', 'women-outwear-kimonos'],
    ['women-accessories-bags']
]

# Checkpoint of the object detection model loaded below.
MODEL_NAME = "valentinafeve/yolos-fashionpedia"

# The image processor comes from 'hustvl/yolos-small' while the weights come
# from MODEL_NAME; both are loaded once at import time and used by get_objects().
feature_extractor = YolosImageProcessor.from_pretrained('hustvl/yolos-small')
model = YolosForObjectDetection.from_pretrained(MODEL_NAME)
|
|
|
def get_category_index(category, mapping=None):
    """Return the position of the label group that contains ``category``.

    :param category: Category name to look up.
    :param mapping: Sequence of label groups to search; defaults to the
        module-level ``label_mapping``.
    :return: Index of the first group containing ``category``.
    :raises ValueError: if no group contains ``category``. Falling through the
        loop used to hand back the index of the last group, silently filing
        unknown categories under it.
    """
    groups = label_mapping if mapping is None else mapping
    for index, labels in enumerate(groups):
        if category in labels:
            return index
    raise ValueError(f"Unknown category: {category!r}")
|
|
|
def get_yolo_index(category):
    """Look up the group index registered for a detector class name.

    :param category: Key of the module-level ``yolo_mapping``.
    :return: The value stored under ``category``.
    :raises KeyError: if ``category`` is not a key of ``yolo_mapping``.
    """
    group_index = yolo_mapping[category]
    return group_index
|
|
|
def fix_channels(t):
    """
    Convert an image tensor to a PIL image with three channels.

    Transparent images (four channels) lose their fourth channel, while
    black and white images (a single channel, or no channel axis at all)
    get that channel repeated three times.

    :param t: Tensor-like image
    :return: PIL image built from the three-channel data
    """
    if len(t.shape) == 2:
        # No channel axis: repeat the plane three times.
        three_channel = torch.stack([t, t, t])
    elif t.shape[0] == 4:
        # Drop the fourth channel.
        three_channel = t[:3]
    elif t.shape[0] == 1:
        # Single channel: repeat it three times.
        plane = t[0]
        three_channel = torch.stack([plane, plane, plane])
    else:
        three_channel = t
    return ToPILImage()(three_channel)
|
|
|
def idx_to_text(i):
    """Return the class name stored at position ``i`` of the module-level ``cats`` list."""
    class_name = cats[i]
    return class_name
|
|
|
|
|
# Palette of three-component colours; plot_results() repeats it and passes one
# entry per detection to matplotlib as the box and label colour.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
|
|
|
|
|
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (centre x, centre y, width, height) to corner form.

    :param x: Tensor whose dimension 1 holds the four values of each box.
    :return: Tensor of the same layout holding (x_min, y_min, x_max, y_max).
    """
    centre_x, centre_y, width, height = x.unbind(1)
    half_width = 0.5 * width
    half_height = 0.5 * height
    corners = (
        centre_x - half_width,
        centre_y - half_height,
        centre_x + half_width,
        centre_y + half_height,
    )
    return torch.stack(corners, dim=1)
|
|
|
def rescale_bboxes(out_bbox, size):
    """Convert centre-format boxes to corner format and scale them to an image size.

    :param out_bbox: Boxes in (centre x, centre y, width, height) form.
    :param size: ``(width, height)`` pair used as the scale factors.
    :return: Corner-format boxes multiplied by (width, height, width, height).
    """
    width, height = size
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
|
|
|
def plot_results(pil_img, prob, boxes):
    """Draw the detections on the image and cut out one crop per box.

    :param pil_img: PIL image the boxes refer to.
    :param prob: Per-detection class score tensors; the arg-max picks the class.
    :param boxes: Tensor of (xmin, ymin, xmax, ymax) boxes, one per detection.
    :return: ``(figure, crops, crop_classes)`` — the matplotlib figure (already
        closed), the cropped PIL images and the class name of each crop.
    """
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    axes = plt.gca()
    # Repeat the palette so that zip() does not run out of colours early.
    palette = COLORS * 100

    crops = []
    crop_classes = []
    for scores, corners, colour in zip(prob, boxes.tolist(), palette):
        xmin, ymin, xmax, ymax = corners
        class_name = idx_to_text(scores.argmax())

        crops.append(pil_img.crop((xmin, ymin, xmax, ymax)))
        crop_classes.append(class_name)

        outline = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                fill=False, color=colour, linewidth=3)
        axes.add_patch(outline)
        axes.text(xmin, ymin, class_name, fontsize=10,
                  bbox=dict(facecolor=colour, alpha=0.8))

    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    # Grab the figure before closing it so it can still be returned.
    output_img = plt.gcf()
    plt.close()

    return output_img, crops, crop_classes
|
|
|
|
|
def visualize_predictions(image, outputs, threshold=0.8):
    """Select the detections worth keeping and render them.

    Steps: keep queries whose best class score exceeds ``threshold``, drop
    classes listed in the module-level ``filter``, drop jumpsuit/dress when
    other garments were found as well, and keep only the first detection of
    each class.

    :param image: PIL image the model was run on; its ``size`` scales the boxes.
    :param outputs: Model output exposing ``logits`` and ``pred_boxes``.
    :param threshold: Minimum best-class probability for a detection.
    :return: ``(class_indices, figure, crops, crop_classes)``; the last three
        are ``None`` when nothing survives the filtering.
    """
    # Last logit column is dropped before picking the best class per query.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)

    excluded = set(filter)
    candidates = []
    for proba, box in zip(probas[keep], bboxes_scaled):
        name = idx_to_text(proba.argmax())
        if name not in excluded:
            candidates.append((name, proba, box))

    # NOTE(review): with the ``filter`` list defined in this file both
    # 'jumpsuit' and 'dress' are already excluded above, so this branch looks
    # unreachable — confirm which of the two is intended.
    whole_body = ("jumpsuit", "dress")
    has_whole_body = any(name in whole_body for name, _, _ in candidates)
    if has_whole_body and len(candidates) > 1:
        candidates = [entry for entry in candidates if entry[0] not in whole_body]

    # Keep the first detection of every class, preserving order.
    seen = set()
    unique = []
    for name, proba, box in candidates:
        if name in seen:
            continue
        seen.add(name)
        unique.append((proba, box))

    output_img = crops = crop_classes = None
    if unique:
        final_probas, final_boxes = zip(*unique)
        output_img, crops, crop_classes = plot_results(
            image, list(final_probas), torch.stack(final_boxes)
        )

    class_indices = [proba.argmax().item() for proba, _ in unique]
    return class_indices, output_img, crops, crop_classes
|
|
|
@spaces.GPU
def get_objects(image, threshold=0.8):
    """Detect garments in an image and return counts, a plot and the crops.

    :param image: PIL image; it is converted to three channels and resized to
        600x800 before detection.
    :param threshold: Minimum class probability forwarded to
        ``visualize_predictions``.
    :return: ``(class_counts, output_img, crops, crop_classes)`` where
        ``class_counts`` maps class name to number of detections and
        ``crop_classes`` holds the ``get_yolo_index`` value of each crop
        (``None`` when nothing was detected).
    """
    image = fix_channels(ToTensor()(image))
    image = image.resize((600, 800))

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        outputs = model(**inputs)

    detected_classes, output_img, crops, crop_classes = visualize_predictions(image, outputs, threshold=threshold)

    class_counts = {}
    for cl in detected_classes:
        class_name = idx_to_text(cl)
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    if crop_classes is not None:
        crop_classes = [get_yolo_index(c) for c in crop_classes]

    return class_counts, output_img, crops, crop_classes
|
|
|
def encode_images_to_base64(cropped_list):
    """Encode every image as a base64 string of its JPEG representation.

    :param cropped_list: Iterable of images offering ``convert`` and ``save``.
    :return: List of base64 text, one entry per input image, in input order.
    """
    def _as_base64_jpeg(picture):
        # JPEG has no alpha channel, hence the conversion to RGB first.
        with io.BytesIO() as stream:
            picture.convert('RGB').save(stream, format='JPEG')
            return base64.b64encode(stream.getvalue()).decode('utf-8')

    return [_as_base64_jpeg(picture) for picture in cropped_list]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cropped_images(images, category):
    """Detect garments in every image and return all crops as base64 JPEGs.

    :param images: Iterable of image URLs.
    :param category: Accepted for interface compatibility; not used here.
    :return: List of base64 strings, one per crop, in detection order.
    """
    collected = []
    for url in images:
        picture = open_image_from_url(url)
        class_counts, _figure, crops, _crop_classes = get_objects(picture, 0.37)

        # Only images with at least one detection contribute crops.
        if class_counts:
            collected.extend(crops)

    return encode_images_to_base64(collected)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio front end: three text inputs are forwarded to shot() and its JSON
# string result is shown as text.
iface = gr.Interface(
    fn=shot,
    inputs=[
        # shot() evaluates this text as a Python literal, so describe that format.
        gr.Textbox(label="Image URLs (starting with http/https) as a Python-style list, e.g. ['https://a.jpg', 'https://b.jpg']"),
        gr.Textbox(label="Category"),
        gr.Textbox(label="Level; accepted 'variant' or 'product'")
    ],
    outputs="text",
    examples=[
        [['https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiAiaGljY3VwLWltYWdlLWhvc3RpbmciLCAia2V5IjogIlc4MDAwMDAwMTM0LU9SL1c4MDAwMDAwMTM0LU9SLTEuanBnIiwgImVkaXRzIjogeyJyZXNpemUiOiB7IndpZHRoIjogODAwLCAiaGVpZ2h0IjogMTIwMC4wLCAiZml0IjogIm91dHNpZGUifX19',
          'https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiAiaGljY3VwLWltYWdlLWhvc3RpbmciLCAia2V5IjogIlc4MDAwMDAwMTM0LU9SL1c4MDAwMDAwMTM0LU9SLTIuanBnIiwgImVkaXRzIjogeyJyZXNpemUiOiB7IndpZHRoIjogODAwLCAiaGVpZ2h0IjogMTIwMC4wLCAiZml0IjogIm91dHNpZGUifX19',
          'https://d2q1sfov6ca7my.cloudfront.net/eyJidWNrZXQiOiAiaGljY3VwLWltYWdlLWhvc3RpbmciLCAia2V5IjogIlc4MDAwMDAwMTM0LU9SL1c4MDAwMDAwMTM0LU9SLTMuanBnIiwgImVkaXRzIjogeyJyZXNpemUiOiB7IndpZHRoIjogODAwLCAiaGVpZ2h0IjogMTIwMC4wLCAiZml0IjogIm91dHNpZGUifX19'], "women-top-shirt","variant"]],
    description="Provide the product image URLs (starting with http/https), the product category and the level ('variant' or 'product'). The result is returned as a JSON string.",
    title="Full product flow"
)

iface.launch()
|
|