NewLook / app.py
alexandrecorreia's picture
Update app.py
add8f10
raw
history blame
15.3 kB
import os
import time
import gradio as gr
from gradio.themes import Size, GoogleFont
import sys
import pandas as pd
import webbrowser
from marqo import Client
from PIL import Image
import urllib.request
from PIL import Image
import requests
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import time
import webbrowser
from transformers import CLIPProcessor, CLIPModel
# Fashion-CLIP model and processor (used by get_labels_probs), the directory
# that receives uploaded query images, and the Marqo client used by every search.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)
# The Marqo endpoint can be overridden with the MARQO_URL environment variable;
# without it the app keeps using the EC2 host it was deployed against.
client = Client(
    os.environ.get(
        "MARQO_URL",
        "http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882",
    )
)
# Gradio theme for the app: slate primary hue, rose secondary hue, stone neutral hue,
# medium spacing/radius/text sizes, and Google-hosted fonts.
primary_color, secondary_color, neutral_color = (
    gr.themes.colors.slate,
    gr.themes.colors.rose,
    gr.themes.colors.stone,
)
spacing_size, radius_size, text_size = (
    gr.themes.sizes.spacing_md,
    gr.themes.sizes.radius_md,
    gr.themes.sizes.text_md,
)
font, font_mono = GoogleFont("Source Sans Pro"), GoogleFont("IBM Plex Mono")

theme = gr.themes.Base(
    primary_hue=primary_color,
    secondary_hue=secondary_color,
    neutral_hue=neutral_color,
    spacing_size=spacing_size,
    radius_size=radius_size,
    text_size=text_size,
    font=font,
    font_mono=font_mono,
)
def filter_by_column(dataset, search_term, column_name, *, case=True, regex=True) -> pd.DataFrame:
    """Return the rows of ``dataset`` whose ``column_name`` contains ``search_term``.

    Args:
        dataset: DataFrame to filter.
        search_term: substring (or regular expression when ``regex`` is True) to look for.
        column_name: name of a string column of ``dataset``.
        case: match case-sensitively (default, as before).
        regex: treat ``search_term`` as a regular expression (default, as before);
            pass ``False`` for literal terms containing characters such as ``(``.

    Rows whose value is missing are treated as non-matching instead of making
    the boolean mask contain NaN (which pandas refuses to index with).
    """
    mask = dataset[column_name].str.contains(search_term, case=case, regex=regex, na=False)
    return dataset[mask]
def dedup_by(dataset, column_name) -> pd.DataFrame:
    """Keep only the first row for each distinct value of ``column_name`` (original index kept)."""
    repeated = dataset.duplicated(subset=[column_name], keep='first')
    return dataset[~repeated]
def drop_secondary_images(dataset) -> pd.DataFrame:
    """Collapse hits to one row per ``primary_image``, showing the primary image.

    The ``image`` column of the result is replaced by ``primary_image`` and only
    the first row for each primary image is kept.

    The caller's DataFrame is not modified: ``assign`` works on a copy. This
    replaces ``dataset.image = ...``, which mutated the input in place and,
    if no ``image`` column existed, set a plain attribute instead of a column.
    """
    with_primary = dataset.assign(image=dataset['primary_image'])
    return with_primary.drop_duplicates(subset=['primary_image'])
def dataset_to_gallery(dataset: pd.DataFrame) -> list:
    """Convert result rows into ``(image, caption)`` tuples for a ``gr.Gallery``.

    The caption packs ``name@@colour_code@@image@@_id``. The event handlers
    split it on ``"@@"``: ``return_item`` reads field 1 and
    ``return_primary_item`` reads field 3.

    Args:
        dataset: DataFrame with ``_id``, ``image``, ``name`` and ``colour_code`` columns.

    Returns:
        List of ``(image, caption)`` tuples, in row order.
    """
    subset = dataset[['_id', 'image', 'name', 'colour_code']]
    # A missing name would otherwise turn the whole caption into NaN, and the
    # select handlers would then fail on ``.split("@@")``.
    captions = (
        subset['name'].fillna('').astype(str)
        + '@@' + subset['colour_code'].astype(str)
        + '@@' + subset['image'].astype(str)
        + '@@' + subset['_id'].astype(str)
    )
    return list(zip(subset['image'].tolist(), captions.tolist()))
def get_items_from_dataset(start_index=0, end_index=50, dataset=None) -> pd.DataFrame:
    """Return ranked rows ``start_index:end_index`` of ``dataset``, best sellers first.

    Args:
        start_index: first position (inclusive) in the ranking.
        end_index: last position (exclusive) in the ranking.
        dataset: DataFrame with a ``best_seller_score`` column. ``None`` (the
            default) yields an empty DataFrame.

    The former default, ``pd.read_json('{}')``, was built once at definition
    time, used deprecated literal-JSON parsing, and had no
    ``best_seller_score`` column, so it always raised ``KeyError``.
    """
    if dataset is None:
        return pd.DataFrame()
    ranked = dataset.sort_values(by=['best_seller_score'], ascending=False)
    # Positional slice of the ranked rows (same as ``ranked[a:b]`` for ints, but explicit).
    return ranked.iloc[start_index:end_index]
# def return_page(page, dataset: pd.DataFrame):
# start_index = page * result_per_page
# end_index = (page + 1) * result_per_page
# df = get_items_from_dataset(start_index, end_index, dataset)
# return dataset_to_gallery(dedup_by(df, 'colour_code'))
def start_page(num_results=50):
    """Build the landing-page gallery.

    Searches the ``new_look_expanded_dresses`` index for "Dress" with a
    best-seller score boost (weight 5) and returns up to ``num_results`` hits
    as gallery items.
    """
    best_seller_boost = {
        "add_to_score": [{"field_name": "best_seller_score", "weight": 5}],
    }
    response = client.index("new_look_expanded_dresses").search(
        "Dress",
        score_modifiers=best_seller_boost,
        searchable_attributes=['image'],
        device="cpu",
        limit=num_results,
    )
    hits = list(response["hits"])
    return return_results_page(hits)
def return_results_page(results_list: list):
    """Turn a list of Marqo hits into gallery items, one per primary image.

    Args:
        results_list: hit dictionaries as returned in ``result["hits"]``.

    Returns:
        ``(image, caption)`` tuples for ``gr.Gallery``; an empty list when
        there are no hits.
    """
    if not results_list:
        # A DataFrame built from no hits has no 'primary_image' column, so the
        # helpers below would raise instead of showing an empty gallery.
        return []
    df = pd.DataFrame(results_list)
    return dataset_to_gallery(drop_secondary_images(df))
def return_item(combined) -> tuple:
    """Fetch every image of the product colour encoded in a gallery caption.

    Args:
        combined: caption built by ``dataset_to_gallery``
            (``name@@colour_code@@image@@_id``); field 1 is the colour code.

    Returns:
        ``(gallery_items, description, url)``. ``description`` and ``url`` are
        taken from the first hit. When the colour filter matches nothing, the
        result is ``([], "", "")`` instead of an ``IndexError``.
    """
    colour_code = combined.split("@@")[1]
    result = client.index("new_look_expanded_dresses").search(
        "",
        filter_string="colour_code:" + str(colour_code),
        searchable_attributes=['image'],
        device="cpu",
    )
    hits = list(result["hits"])
    if not hits:
        return [], "", ""
    first = hits[0]
    return (
        dataset_to_gallery(pd.DataFrame(hits)),
        first.get("description_total", ""),
        first.get("url", ""),
    )
def return_primary_item(combined):
    """Return the image URL of the document whose ``_id`` is encoded in a gallery caption.

    Args:
        combined: caption built by ``dataset_to_gallery``
            (``name@@colour_code@@image@@_id``); field 3 is the document ``_id``.

    Returns:
        The ``image`` field of the first matching hit, or ``None`` when the
        ``_id`` filter matches nothing (instead of raising ``IndexError``).
    """
    _id = combined.split("@@")[3]
    result = client.index("new_look_expanded_dresses").search(
        "",
        filter_string="_id:" + str(_id),
        searchable_attributes=['image'],
        device="cpu",
    )
    hits = result["hits"]
    if not hits:
        return None
    # The first gallery tuple's image is exactly the first hit's 'image' field,
    # so there is no need to build a DataFrame and full captions here.
    return hits[0]["image"]
### Load local
def load_image(image_input):
    """Local-development helper: copy the query image into the ``marqo`` Docker container.

    The image is saved to a fixed path relative to the working directory. It
    is then copied with ``docker cp`` into ``/images/images/`` of the
    container named ``marqo``. The command's exit status is ignored.
    """
    local_copy = "../../../Documents/images/img_path.jpg"
    image_input.save(local_copy)
    os.system(f'docker cp "{local_copy}" marqo:"/images/images/"')
### Search local
def search_images(query, best_seller_score_weight):
    """Run ``query`` against the dresses index and return up to 40 raw hits.

    Args:
        query: a single query string, or a ``{query: weight}`` dict for a
            weighted multi-part query.
        best_seller_score_weight: UI slider value. It is divided by 1000 and
            used as the weight of the ``best_seller_score`` score modifier.
    """
    boost = {
        "add_to_score": [
            {"field_name": "best_seller_score", "weight": best_seller_score_weight / 1000},
        ],
    }
    response = client.index("new_look_expanded_dresses").search(
        query,
        score_modifiers=boost,
        searchable_attributes=['image'],
        device="cpu",
        limit=40,
    )
    return list(response["hits"])
### Search AWS
# def search_images(query, best_seller_score_weight):
# client = Client("http://ec2-54-220-125-165.eu-west-1.compute.amazonaws.com:8882")
# result = client.index("new_look_expanded_dresses").search(query, score_modifiers = {
# "add_to_score": [{"field_name": "best_seller_score","weight": best_seller_score_weight/1000}],
# }, searchable_attributes=['primary_image'], device="cpu", limit=40)
# imgs = [r for r in result["hits"]]
# return imgs
def get_labels_probs(labels, image):
    """Return Fashion-CLIP probabilities that ``image`` matches each label in ``labels``.

    Args:
        labels: list of candidate label strings.
        image: the image to classify (whatever ``processor`` accepts, e.g. a PIL image).

    Returns:
        List of floats, one per label, from a softmax over the image-text logits.
    """
    import torch  # local import: the model runs on PyTorch tensors (return_tensors="pt")

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)  # softmax across labels -> probabilities
    return probs.tolist()[0]
def get_bar_plot(labels, probs):
    """Return a matplotlib bar chart of per-label probabilities.

    Args:
        labels: label names for the x axis.
        probs: matching probabilities in ``[0, 1]``.

    Returns:
        A ``matplotlib.figure.Figure`` with each bar annotated to 4 decimals.
    """
    # Build the figure without pyplot: pyplot keeps every figure alive until
    # plt.close(), so a long-running server calling plt.subplots() per request leaks.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    bar_container = ax.bar(labels, probs)
    ax.set(ylabel='probability', title='Labels probabilities\n', ylim=(0, 1))
    ax.bar_label(bar_container, fmt='{:,.4f}')
    return fig
# Custom CSS passed to gr.Blocks below: beige page background, grey gallery
# items, labels (80px wide) and <h1> headings (180px wide).
css = """
.gradio-container {background-color: beige}
button.gallery-item {background-color: grey}
.label {background-color: grey; width: 80px}
h1 {background-color: grey; width: 180px}
"""
# Build the Gradio UI with the custom theme and CSS defined above.
with gr.Blocks(theme=theme, title="New Look", css=css) as demo:
# Header: New Look logo.
gr.Markdown(
"""
<div style="vertical-align: middle">
<div style="float: left">
<img src="https://1000logos.net/wp-content/uploads/2021/05/New-Look-logo.png" alt=""
width="250" height="250">
</div>
</div>
""")
# Search tab: up to two text queries and one image query, each with a relevance
# slider, plus a best-seller boost and the Search button.
with gr.Tab(label="Search for images"):
with gr.Row():
with gr.Column(scale=3):
text_input = gr.Text(label="Search with text:")
text_relevance = gr.Slider(label="Text search relevance", minimum = -5, maximum = 5, value = 1, step = 1)
# Second text query and its slider start hidden; on_new_text_box toggles them.
text_input_1 = gr.Text(label="Search with text:", visible=False)
text_relevance_1 = gr.Slider(label="Text search relevance", minimum = -5, maximum = 5, value = 1, step = 1, visible=False)
more_text_search = gr.Button(value="More text fields")
# NOTE(review): text_expanded is not read or written by any handler in this file.
text_expanded = gr.State(value=False)
with gr.Column(scale=3):
# Slider value is divided by 1000 in search_images before being sent to Marqo.
best_seller_score_weight = gr.Slider(label = "Best seller relevance", minimum=-1, maximum=1, value=0, step=0.01)
search_button = gr.Button(value="Search")
with gr.Column(scale=2):
image_input = gr.Image(type="pil", label="Search with image")
# Filled by `load` with the public URL of the uploaded image; `search` uses it as the image query.
image_path = gr.State(visible=False)
image_relevance = gr.Slider(label="Image search relevance", minimum = -5, maximum = 5, value = 1, step = 1)
# with gr.Row():
# with gr.Column(scale=3):
# ...
# with gr.Column(scale=3):
# search_button = gr.Button(value="Search")
# with gr.Column(scale=2):
# ...
with gr.Row():
with gr.Column(scale=3):
# Results gallery; pre-populated by start_page(), which queries Marqo while the UI is being built.
images_gallery = gr.Gallery(value=start_page(), columns=4,
allow_preview=False, show_label=False, object_fit="contain")
# Detail pane: images of the selected product colour, its description and a product link (filled by on_focus).
with gr.Column():
detail_gallery = gr.Gallery(value=[], columns=2, allow_preview=False, show_label=False, rows=1,
height="400",object_fit="contain")
image_description = gr.Text(label="Description")
# NOTE(review): product_link is written by on_focus but no active handler reads it.
product_link = gr.State()
# button_go_to_page = gr.Button(value="Go to product page")
# HTML anchor to the product page, rendered by on_focus.
page = gr.HTML()
def on_new_text_box(more_text_search):
    """Show or hide the second text query box and its relevance slider.

    Args:
        more_text_search: current label of the toggle button.

    Returns:
        Updates for ``(text_input_1, text_relevance_1, more_text_search)``.
        When hiding, the second text box is also cleared so it no longer
        contributes to searches.
    """
    if more_text_search != "More text fields":
        return (
            gr.update(value="", visible=False, interactive=False),
            gr.update(visible=False, interactive=False),
            gr.update(value="More text fields"),
        )
    return (
        gr.update(visible=True, interactive=True),
        gr.update(visible=True, interactive=True),
        gr.update(value="Hide extra text box"),
    )
def on_focus(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Fill the detail pane for the result the user selected.

    Args:
        evt: Gradio selection event; ``evt.value`` is passed to ``return_item``,
            which expects a ``dataset_to_gallery`` caption.

    Returns:
        ``(detail_gallery, description, product_link, page)`` values. ``page``
        is an HTML link to the product opening in a new tab.
    """
    import html  # local import: escape the product URL before embedding it in markup

    gallery_items, description, url = return_item(evt.value)
    # Quote and escape the href: the URL comes from indexed data, and an unquoted
    # value containing spaces, quotes or '>' would break (or inject into) the markup.
    link = "<a href='" + html.escape(url, quote=True) + "' target='_blank'> Go to product page </a>"
    return gallery_items, description, url, gr.update(value=link)
def on_new_image_to_search(images, evt: gr.SelectData):
    """Use the clicked detail-gallery image as the new image query.

    Args:
        images: current detail-gallery value. It is unused, but received
            because the listener passes ``detail_gallery`` as its input.
        evt: Gradio selection event; ``evt.value`` is expected to be the clicked
            item's ``@@``-separated caption.

    Returns:
        The primary image for the clicked item (sent to ``image_input``).
    """
    clicked_caption = evt.value
    primary_image = return_primary_item(clicked_caption)
    return primary_image
# NOTE(review): disabled handler kept for reference; the product link is now rendered into `page` by on_focus.
# def on_go_to_product_page(product_link):
# # try:
# return gr.update(value='''<button onclick="window.location.href='+ product_link +';">
# Click Here
# </button>''')
# webbrowser.open(product_link)
# except:
# print("Not able to open product page")
# Event wiring for the search tab:
# - the toggle button shows/hides the second text query and its slider;
# - selecting a result fills the detail gallery, description, link state and HTML link;
# - selecting a detail image loads its primary image into the image query input.
more_text_search.click(on_new_text_box, more_text_search, [text_input_1, text_relevance_1, more_text_search])
images_gallery.select(on_focus, None, [detail_gallery, image_description, product_link, page])
detail_gallery.select(on_new_image_to_search, detail_gallery, image_input)
# button_go_to_page.click(on_go_to_product_page, product_link, page)
# Disabled tab: label probabilities for an image (see get_labels_probs / get_bar_plot).
# with gr.Tab(label="Search for images"):
# labels_input = gr.Text(label="List of labels")
# gr.Examples(
# ["shirt, dress, shoe",
# "short_sleeve, long_sleeve, three_quarter_sleeve, sleeveless, bell_sleeve"],
# labels_input)
# with gr.Row():
# image_labels_input = gr.Image(type="pil", label="Image to compute")
# bar_plot = gr.Plot()
# with gr.Row():
# gr.Examples(
# ["https://media2.newlookassets.com/i/newlook/869030934/womens/clothing/dresses/khaki-utility-mini-shirt-dress.jpg?strip=true&qlt=50&w=1400",
# "https://media3.newlookassets.com/i/newlook/872692409/womens/clothing/dresses/black-floral-lace-trim-mini-dress.jpg?strip=true&qlt=50&w=1400"],
# image_labels_input)
# gr.Markdown()
# compute_button = gr.Button(value="Compute")
# response_labels = gr.Text()
# Dataset-selection tab.
# NOTE(review): the dropdown and the "Select" button are not connected to any handler
# in this file; every search still queries the "new_look_expanded_dresses" index.
# The empty gr.Markdown() calls are presumably layout spacers.
with gr.Tab(label="Choose dataset"):
gr.Markdown("# Choose Dataset")
with gr.Row():
gr.Dropdown(["New Look Dresses", "New Look All"], label="Available datasets")
gr.Markdown()
gr.Markdown()
with gr.Row():
gr.Button("Select")
gr.Markdown()
gr.Markdown()
def load(image_input):
    """Save the uploaded query image under ``static/`` and return its public URL.

    Args:
        image_input: PIL image from the "Search with image" input, or ``None``.

    Returns:
        The Space URL serving the saved file, or ``""`` when no image was given.
        ``search`` uses that URL as the image query.
    """
    if image_input is None:
        return ""
    file_name = "image_to_search.jpg"
    file_path = "static/" + file_name
    # JPEG cannot store alpha or palette modes; uploaded PNGs often arrive as
    # RGBA/P images, and saving those directly as .jpg raises OSError.
    image_input.convert("RGB").save(file_path)
    # NOTE(review): one fixed file name is shared by all sessions, so concurrent
    # users can overwrite each other's query image between `load` and `search`.
    return "https://minderalabs-newlook.hf.space/file=" + file_path
def search(text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight):
    """Run the combined text/image search and return gallery items.

    Args:
        text_input, text_input_1: text queries (empty or ``None`` means unused).
        image_input: the uploaded PIL image. It is unused here; the image
            reaches Marqo through ``image_path``.
        image_path: public URL of the uploaded image, set by ``load`` (``""`` when none).
        text_relevance, text_relevance_1, image_relevance: weights for the
            matching queries, used only when more than one query is active.
        best_seller_score_weight: best-seller boost slider value, forwarded to
            ``search_images``.

    Returns:
        ``(image, caption)`` gallery tuples; an empty list when every query is
        empty or nothing matched.
    """
    candidates = [
        (text_input, text_relevance),
        (text_input_1, text_relevance_1),
        (image_path, image_relevance),
    ]
    active = [(q, w) for q, w in candidates if q is not None and q != ""]
    if not active:
        return []

    if len(active) == 1:
        # A single query is sent as a plain string; its relevance is not used.
        query = active[0][0]
    else:
        # Several queries become a weighted {query: weight} dict. Identical query
        # texts accumulate their weights instead of the later one silently
        # overwriting the earlier one.
        query = {}
        for q, w in active:
            query[q] = query.get(q, 0) + w

    response = search_images(query, best_seller_score_weight)
    if not response:
        return []
    return return_results_page(response)
# def get_labels(labels_input, image_labels_input):
# labels_probs = get_labels_probs(labels_input.split(","), image_labels_input)
# bar_plot = get_bar_plot(labels_input.split(","), labels_probs)
# return bar_plot, labels_probs
# search_button.click(
# search, [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight], images_gallery
# )
# Search button: first `load` saves the uploaded image and stores its public URL
# in image_path, then `search` runs the combined query and refreshes the gallery.
search_button.click(
load, image_input, image_path
).then(
search, [text_input, text_input_1, image_input, image_path, text_relevance, text_relevance_1, image_relevance, best_seller_score_weight], [images_gallery]
)
# compute_button.click(
# get_labels, [labels_input, image_labels_input], [bar_plot, response_labels]
# )
# Enable the request queue and start the Gradio server (Blocks.queue returns the Blocks).
demo.queue().launch()