Spaces:
Build error
Build error
File size: 6,091 Bytes
3e601d6 73b733e 2cadbb9 73b733e 3e601d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
### LIBRARIES ###
# # Data
import numpy as np
import pandas as pd
import json
from math import floor
# Robustness Gym and Analysis
import robustnessgym as rg
from gensim.models.doc2vec import Doc2Vec
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import nltk
nltk.download('punkt') #make sure that punkt is downloaded
# App & Visualization
import streamlit as st
import altair as alt
# utils
from interactive_model_cards import utils as ut
from interactive_model_cards import app_layout as al
from random import sample
from PIL import Image
### LOADING DATA ###
# model card data
@st.experimental_memo
def load_model_card():
with open("./assets/data/text_explainer/model_card.json") as f:
mc_text = json.load(f)
return mc_text
# pre-computed robusntess gym dev bench
# @st.experimental_singleton
@st.cache(allow_output_mutation=True)
def load_data():
# load dev bench
devBench = rg.DevBench.load("./assets/data/rg/sst_db.devbench")
return devBench
# load model
@st.experimental_singleton
def load_model():
model = rg.HuggingfaceModel(
"distilbert-base-uncased-finetuned-sst-2-english", is_classifier=True
)
return model
#load pre-computed embedding
def load_embedding():
embedding = pd.read_pickle("./assets/models/sst_vectors.pkl")
return embedding
#load doc2vec model
@st.experimental_singleton
def load_doc2vec():
doc2vec = Doc2Vec.load("./assets/models/sst_train.doc2vec")
return(doc2vec)
# @st.experimental_memo
def load_examples():
with open("./assets/data/user_data/example_sentence.json") as f:
examples = json.load(f)
return examples
# loading the dataset
def load_basic():
# load data
devBench = load_data()
# load model
model = load_model()
#protected_classes
protected_classes = json.load(open("./assets/data/protected_terms.json"))
return devBench, model, protected_classes
@st.experimental_singleton
def load_title():
img = Image.open("./assets/img/title.png")
return(img)
if __name__ == "__main__":
### STREAMLIT APP CONGFIG ###
st.set_page_config(layout="wide", page_title="Interactive Model Card")
# import custom styling
ut.init_style()
### LOAD DATA AND SESSION VARIABLES ###
# ******* loading the mode and the data
with st.spinner():
sst_db, model,protected_classes = load_basic()
embedding = load_embedding()
doc2vec = load_doc2vec()
# load example sentences
sentence_examples = load_examples()
# ******* session state variables
if "user_data" not in st.session_state:
st.session_state["user_data"] = pd.DataFrame()
if "example_sent" not in st.session_state:
st.session_state["example_sent"] = "I like you. I love you"
if "quant_ex" not in st.session_state:
st.session_state["quant_ex"] = {"Overall Performance": sst_db.metrics["model"]}
if "selected_slice" not in st.session_state:
st.session_state["selected_slice"] = None
if "slice_terms" not in st.session_state:
st.session_state["slice_terms"] = {}
if "embedding" not in st.session_state:
st.session_state["embedding"] = embedding
if "protected_class" not in st.session_state:
st.session_state["protected_class"] = protected_classes
### STREAMLIT APP LAYOUT###
# ******* MODEL CARD PANEL *******
#st.sidebar.title("Interactive Model Card")
img = load_title()
st.sidebar.image(img,width=400)
st.sidebar.warning("Data is not permanently collected or stored from your interactions, but is temporarily cached during usage.")
st.markdown('''
<a href="javascript:document.getElementsByClassName('css-1rs6os edgvbvh3')[1].click();">
<img src="./assets/img/info.png" style="width:30px;height:30px;"/>
</a>
''', unsafe_allow_html=True)
# load model card data
errors = st.sidebar.checkbox("Show Warnings", value=True)
model_card = load_model_card()
al.model_card_panel(model_card,errors)
lcol, rcol = st.columns([4, 8])
# ******* USER EXAMPLE DATA PANEL *******
st.markdown("---")
with lcol:
# Choose waht to show for the qunatiative analysis.
st.write("""<h1 style="font-size:20px;padding-top:0px;"> Quantitative Analysis</h1>""",
unsafe_allow_html=True)
st.markdown("View the model's performance or visually explore the model's training and testing dataset")
data_view = st.selectbox("Show:",
["Model Performance Metrics","Data Subpopulation Comparison Visualization"])
st.markdown("Any groups you define via the *analysis actions* will be automatically added to the view")
st.markdown("---")
# Additional Analysis Actions
st.write(
"""<h1 style="font-size:18px;padding-top:5px;"> Analysis Actions</h1>""",
unsafe_allow_html=True,
)
al.example_panel(sentence_examples, model, sst_db,doc2vec)
# ****** GUIDANCE PANEL *****
with st.expander("Guidance"):
st.markdown(
"Need help understanding what you're seeing in this model card?"
)
st.markdown(
" * **[Understanding Metrics](https://stanford.edu/~shervine/teaching/cs-229/cheatsheet-machine-learning-tips-and-tricks)**: A cheatsheet of model metrics"
)
st.markdown(
" * **[Understanding Sentiment Models](https://www.semanticscholar.org/topic/Sentiment-analysis/6011)**: An overview of sentiment analysis"
)
st.markdown(
"* **[Next Steps](https://docs.google.com/document/d/1r9J1NQ7eTibpXkCpcucDEPhASGbOQAMhRTBvosGu4Pk/edit?usp=sharin)**: Suggestions for follow-on actions"
)
st.markdown("Feel free to submit feedback via our [online form](https://sfdc.co/imc_feedback)")
# ******* QUANTITATIVE DATA PANEL *******
al.quant_panel(sst_db, st.session_state["embedding"], rcol,data_view)
|