# fashion_classification / data_analysis_app.py
# ceyda's picture
# Update data_analysis_app.py
# 1f42cb1
from datasets import load_dataset
import streamlit as st
from data_utils import get_embedding
from bokeh.plotting import figure,show
from bokeh.io import push_notebook, output_notebook
# output_notebook()
from bokeh.palettes import d3
from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter
from bokeh.transform import factor_cmap, factor_mark
import base64
from io import BytesIO
# Dataset columns whose values can be used to colour the scatter plot.
label_columns = [
    "gender",
    "subCategory",
    "masterCategory",
]

# Model checkpoints offered in the "Model" selector.
model_interest = [
    "facebook/deit-tiny-patch16-224",  # very small model (5M parameters)
    "microsoft/beit-base-patch16-224",  # big model
    "facebook/dino-vits8",
    "facebook/levit-128S",
]
def convert_base64(img):
    """Encode an image as a JPEG ``data:`` URI string.

    Args:
        img: Image object exposing ``mode``, ``convert(mode)`` and
            ``save(fp, format=...)`` (presumably a ``PIL.Image.Image``;
            confirm against ``get_embedding``).

    Returns:
        ``"data:image/jpeg;base64,"`` followed by the base64-encoded JPEG
        bytes, suitable for use as an ``<img src=...>`` value.
    """
    # Modes the JPEG encoder can write directly. Anything else (RGBA, LA, P,
    # ...) makes ``save`` raise ``OSError``, so normalise those to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + img_str
@st.experimental_singleton
def cache_embedding(model_name):
    """Build the embedding table and label values for one model.

    Loads the ``ceyda/fashion-products-small`` train split, shuffles it with
    a fixed seed, keeps the ``"test"`` part of a 0.1 split as the
    visualization subset, and embeds that subset with ``get_embedding``.

    Args:
        model_name: Model identifier forwarded to ``get_embedding``.

    Returns:
        A ``(embedding, labels)`` tuple. ``embedding`` is whatever
        ``get_embedding`` returns, with its ``"image"`` column replaced by
        JPEG data URIs; ``labels`` maps each name in ``label_columns`` to the
        unique values of that column in the visualization subset.
    """
    full_split = load_dataset("ceyda/fashion-products-small", split="train")
    # Fixed seed so the same subset is selected on every run.
    shuffled = full_split.shuffle(seed=100)
    # Hold out a portion of the data for visualization.
    viz_dat = shuffled.train_test_split(0.1, shuffle=False)["test"]

    embedding = get_embedding(model_name, viz_dat)
    # Images are inlined as data URIs so the hover tooltip can display them.
    embedding["image"] = embedding["image"].apply(convert_base64)

    labels = {}
    for label in label_columns:
        labels[label] = viz_dat.unique(label)
    return embedding, labels
@st.experimental_singleton
def cache_graph(model_name,color_column):
    """Build the Bokeh scatter plot of embeddings for one model and colouring.

    Args:
        model_name: Model identifier forwarded to ``cache_embedding``.
        color_column: Name of the label column used to colour the points;
            must be a key of the ``labels`` dict from ``cache_embedding``.

    Returns:
        The Bokeh figure with one scatter glyph, reading its ``x``/``y``
        coordinates from the embedding table and showing the sample image
        on hover.
    """
    embedding, labels = cache_embedding(model_name)
    factors = labels[color_column]

    # Three 20-colour D3 palettes are concatenated (60 colours in total) and
    # trimmed to one colour per distinct value of the selected column.
    color_palette = (
        d3['Category20'][20] + d3['Category20b'][20] + d3['Category20c'][20]
    )[:len(factors)]

    source = ColumnDataSource(data=embedding)

    TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select,"
    # Hover template: "@image" is substituted with the row's "image" column
    # (a data URI). The outer <div> is now closed so the markup is balanced.
    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@image" height="42" alt="@image" width="42"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
    </div>
    """

    p = figure(tools=TOOLS, tooltips=TOOLTIPS)
    # NOTE: a per-gender marker shape (factor_mark) was tried earlier; only
    # colour is used to distinguish categories now.
    p.scatter(
        x="x",
        y="y",
        source=source,
        color=factor_cmap(color_column, color_palette, factors),
    )
    return p
# Notice shown before the (potentially slow) embedding and plot are built.
st.write("It takes some time for the graph to load...wait please")

# The model is chosen in the sidebar; the colouring column on the main page.
model_name = st.sidebar.selectbox(label="Model", options=model_interest)
color_column = st.selectbox(label="Color by", options=label_columns)

# Build (or fetch from cache) the figure for this selection and render it.
p = cache_graph(model_name, color_column)
st.bokeh_chart(p, use_container_width=False)