# Spaces: Runtime error
# (Page-status text captured alongside the file; it is not part of the program.)
from datasets import load_dataset | |
import streamlit as st | |
from data_utils import get_embedding | |
from bokeh.plotting import figure,show | |
from bokeh.io import push_notebook, output_notebook | |
# output_notebook() | |
from bokeh.palettes import d3 | |
from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter | |
from bokeh.transform import factor_cmap, factor_mark | |
import base64 | |
from io import BytesIO | |
# Dataset columns the scatter plot can be coloured by (shown in the
# "Color by" selector).
label_columns = ["gender", "subCategory", "masterCategory"]

# Model identifiers shown in the sidebar "Model" selector.
model_interest = [
    "facebook/deit-tiny-patch16-224",  # very small model 5M param model
    "microsoft/beit-base-patch16-224",  # big model
    "facebook/dino-vits8",
    "facebook/levit-128S",
]
# Image modes the JPEG writer can store directly; any other mode (RGBA, P,
# LA, ...) has to be converted before saving.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def convert_base64(img):
    """Return *img* encoded as a JPEG ``data:`` URI string.

    Args:
        img: An image object exposing ``save(fp, format=...)`` and, when its
            ``mode`` is not one JPEG can store, ``convert("RGB")``.

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG has no alpha channel or palette support, so saving such an image
    # would raise. Convert those to RGB first. Objects without a ``mode``
    # attribute are saved as-is.
    mode = getattr(img, "mode", None)
    if mode is not None and mode not in _JPEG_WRITABLE_MODES:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + img_str
def cache_embedding(model_name):
    """Load a sample of the dataset and compute its embeddings and labels.

    The "ceyda/fashion-products-small" train split is shuffled with a fixed
    seed, and the "test" part of a 0.1 split is used as the visualization
    sample.

    NOTE(review): despite the name, no caching decorator is visible on this
    function, so every call repeats the dataset load and the embedding work.

    Args:
        model_name: Model identifier forwarded to ``get_embedding``.

    Returns:
        A ``(embedding, labels)`` tuple. ``embedding`` is whatever
        ``get_embedding`` returns (presumably a pandas DataFrame; confirm in
        ``data_utils``), with its "image" column replaced by JPEG data URIs.
        ``labels`` maps each name in ``label_columns`` to that column's
        unique values in the sample.
    """
    dataset = load_dataset("ceyda/fashion-products-small", split="train")
    dataset = dataset.shuffle(seed=100)  # fixed seed keeps the sample stable

    # Set aside a subset just for visualization. The data is already shuffled
    # above, so the split itself does not shuffle again.
    viz_dat = dataset.train_test_split(0.1, shuffle=False)["test"]

    embedding = get_embedding(model_name, viz_dat)
    # Store each image as a data URI so the hover tooltip can reference it.
    embedding["image"] = embedding["image"].apply(convert_base64)

    labels = {label: viz_dat.unique(label) for label in label_columns}
    return embedding, labels
def cache_graph(model_name, color_column):
    """Build the Bokeh scatter plot of embeddings coloured by one label column.

    NOTE(review): despite the name, no caching decorator is visible on this
    function, so the figure is rebuilt on every call.

    Args:
        model_name: Model identifier forwarded to ``cache_embedding``.
        color_column: Name of the label column used to colour the points;
            must be a key of the ``labels`` mapping from ``cache_embedding``.

    Returns:
        The Bokeh figure. It plots the "x" and "y" fields of the embedding
        data and shows the "image" field in the hover tooltip.
    """
    embedding, labels = cache_embedding(model_name)
    factors = labels[color_column]

    # Three 20-colour d3 palettes are chained (60 colours in total) and then
    # trimmed to one colour per distinct label value.
    palette_pool = (
        d3["Category20"][20] + d3["Category20b"][20] + d3["Category20c"][20]
    )
    color_palette = palette_pool[: len(factors)]

    source = ColumnDataSource(data=embedding)

    TOOLS = "hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select,"
    # Hover template: "@image" is replaced with the point's "image" field.
    # Both <div> elements are closed; the previous template left the outer
    # one open.
    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@image" height="42" alt="@image" width="42"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
    </div>
    """

    p = figure(tools=TOOLS, tooltips=TOOLTIPS)
    p.scatter(
        x="x",
        y="y",
        source=source,
        color=factor_cmap(color_column, color_palette, factors),
    )
    return p
# Page body. The statement order fixes the on-page order: the notice first,
# then the selectors, then the chart.
st.write("It takes some time for the graph to load...wait please")

# The model selector goes in the sidebar; the colour selector in the main area.
model_name = st.sidebar.selectbox("Model", model_interest)
color_column = st.selectbox("Color by", label_columns)

# Build the figure for the current selections and render it at its own width.
p = cache_graph(model_name, color_column)
st.bokeh_chart(p, use_container_width=False)