Spaces:
Build error
Build error
# --- Streamlit --- | |
import streamlit as st | |
# --- Data --- | |
import robustnessgym as rg | |
import pandas as pd | |
# --- Misc --- | |
from math import floor | |
from random import sample | |
from interactive_model_cards import utils as ut | |
def format_data(user_text, model): | |
""" Helper Function : Formatting and preparing the user's input data""" | |
# adding user data to the data panel | |
dp = rg.DataPanel({"sentence": [user_text], "label": [1]}) | |
# run prediction | |
dp, pred = ut.update_pred(dp, model) | |
# summarizing the prediction | |
idx_max = pred["Probability"].argmax() | |
pred_sum = pred["Label"][idx_max] | |
pred_bin = int(1) if pred["Label"][idx_max] == "Positive Sentiment" else int(0) | |
pred_num = floor(pred["Probability"][idx_max] * 10 ** 3) / 10 ** 3 | |
pred_conf = ut.conf_level(pred["Probability"][idx_max]) | |
new_example = { | |
"sentence": user_text, | |
"model label": pred_sum, | |
"model label binary": pred_bin, | |
"probability": pred_num, | |
"confidence": pred_conf, | |
"user label": None, | |
"user label binary": None, | |
} | |
return new_example | |
def slice_misc(table): | |
""" Helper Function: format new slice""" | |
table = st.session_state["user_data"][ | |
["sentence", "model label binary", "user label binary"] | |
] | |
table.columns = ["sentence", "pred", "label"] | |
dp = rg.DataPanel( | |
{ | |
"sentence": table["sentence"].tolist(), | |
"label": table["label"].tolist(), | |
"pred": table["pred"].tolist(), | |
} | |
) | |
# give the sentence a name | |
dp._identifier = "Your Sentences" | |
# updated the dev bench | |
rg_bench = ut.new_bench() | |
rg_bench.add_slices(dp) | |
return rg_bench | |
# ***** ADDING CUSTOM SENTENCES ******* | |
def examples(): | |
""" DEPRECATED METHOD FOR UI for displaying the custom sentences""" | |
# writing the metrics out to a column | |
st.markdown("** Custom Example Sentences **") | |
if not st.session_state["user_data"].empty: | |
# remove the user data slice | |
# visualize the overall performance | |
st.markdown("*Model Performance*") | |
key = "Your Sentences" | |
all_metrics = {key: {}} | |
all_metrics[key]["metrics"] = st.session_state["quant_ex"][ "User Custom Sentence"][key] | |
all_metrics[key]["source"] = key | |
# chart = ut.visualize_metrics(st.session_state["quant_ex"]["User Custom Sentence"]) | |
chart = ut.visualize_metrics(all_metrics, col_val="#ff7f0e") | |
st.altair_chart(chart) | |
# add to overall model performance | |
# visualize examples | |
st.markdown("*Examples*") | |
st.dataframe( | |
st.session_state["user_data"][ | |
["sentence", "model label", "user label", "probability"] | |
] | |
) | |
else: | |
st.write("No examples added yet") | |
def example_sentence(sentence_examples, model,doc2vec): | |
""" UI for creating a custom sentences""" | |
# **** Entering Text *** | |
placeholder = st.empty() | |
user_text = placeholder.text_input( | |
"Write your own example sentences, or click 'Get Suggest Examples'", | |
st.session_state["example_sent"], | |
) | |
gen_button = st.button("Get Suggested Example", key="user_text") | |
if gen_button: | |
st.session_state["example_sent"] = sample( | |
set(sentence_examples["sentences"]), 1 | |
)[0] | |
user_text = placeholder.text_input( | |
"Write your own example sentences, or click 'Get Suggested Example'", | |
st.session_state["example_sent"], | |
) | |
if user_text != "": | |
new_example = format_data(user_text, model) | |
# **** Prediction Summary *** | |
with st.form(key="my_form"): | |
st.markdown("**Model Prediction Summary**") | |
st.markdown( | |
f"*The sentiment model predicts that this sentence has an overall `{new_example['model label']}` with an `{new_example['confidence']}` (p={new_example['probability']})*" | |
) | |
# prediction agreement solicitation | |
st.markdown("**Do you agree with the prediction?**") | |
agreement = st.radio("Indicate your agreement below", ["Agree", "Disagree"]) | |
# getting the user label | |
user_lab = new_example["model label"] | |
user_lab_bin = ( | |
int(1) if new_example["model label"] == "Positive Sentiment" else int(0) | |
) | |
if agreement != "Agree": | |
user_lab = ( | |
"Negative Sentiment" | |
if new_example["model label"] == "Positive Sentiment" | |
else "Positive Sentiment" | |
) | |
user_lab_bin = int(0) if user_lab_bin == 1 else int(1) | |
# update robustness gym with user_example prediction | |
if st.form_submit_button("Add to exisiting sentences"): | |
# updating the user data frame | |
if user_text != "": | |
new_example["user label"] = user_lab | |
new_example["user label binary"] = user_lab_bin | |
# data frame to append to session info | |
new_example = pd.DataFrame(new_example, index=[0]) | |
# update the session | |
st.session_state["user_data"] = st.session_state[ | |
"user_data" | |
].append(new_example, ignore_index=True) | |
# update the user data dev bench | |
user_bench = slice_misc(st.session_state["user_data"]) | |
# add bench | |
st.session_state["quant_ex"][ | |
"User Custom Sentence" | |
] = user_bench.metrics["model"] | |
#update the selected data | |
st.session_state["selected_slice"] = { | |
'name':'Your Sentences', | |
'source': 'User Custom Sentence', | |
} | |
#update the sentence with an embedding | |
embedding = st.session_state["embedding"] | |
tmp = ut.prep_sentence_embedding(name ='Your Sentences', | |
source = 'User Custom Sentence', | |
sentence = user_text, | |
sentiment= user_lab, | |
sort_order= 100, #always put it on top | |
embed_model = doc2vec, | |
idx = max(embedding.index)+1) | |
st.session_state["embedding"] = embedding.append(tmp) | |
# ***** DEFINTING CUSTOM SUBGROUPS ******* | |
def subpopulation_slice(sst_db,doc2vec): | |
with st.form(key="subpop_form"): | |
st.markdown("Define you subpopulation") | |
user_terms = st.text_input( | |
"Enter a set of comma separated words", "comedy, hilarious, clown" | |
) | |
slice_choice = st.selectbox( | |
"Choose Data Source", ["Training Data", "Evaluation Data"] | |
) | |
slice_name = st.text_input( | |
"Give your subpopulation a name", "subpop_1", key="custom_slice_name" | |
) | |
if st.form_submit_button("Create Subpopulation"): | |
# build a new slice | |
user_terms = [x.strip() for x in user_terms.split(",")] | |
slice_builder = rg.HasAnyPhrase([user_terms], identifiers=[slice_name]) | |
# on test data | |
slice_ids = ut.get_sliceid(list(sst_db.slices)) | |
if slice_choice == "Training Data": | |
#st.write("returning training data") | |
idx = ut.get_sliceidx(slice_ids,"xyz_train") | |
else: | |
#st.write("returning evaluation data") | |
idx = ut.get_sliceidx(slice_ids,"xyz_test") | |
sst_db(slice_builder, list(sst_db.slices)[idx], ["sentence"]) | |
#get store slice name | |
slice_ids = ut.get_sliceid(list(sst_db.slices)) | |
slice_idx= [i for i, elem in enumerate(slice_ids) if slice_name in str(elem)][0] | |
slice_rg_name = [elem for i, elem in enumerate(slice_ids) if slice_name in str(elem)] | |
slice_data = list(sst_db.slices)[slice_idx] | |
# updating the the selected slice | |
st.session_state["selected_slice"] = { | |
'name': slice_rg_name[0], | |
'source': 'Custom Slice', | |
} | |
#storing the slice terms | |
st.session_state["slice_terms"][slice_rg_name[0]] = user_terms | |
#adding slice to embedding | |
#update the sentence with an embedding | |
embedding = st.session_state["embedding"] | |
tmp = ut.prep_sentence_embedding(name = slice_name, | |
source = "Custom Slice", | |
sentence = slice_data['sentence'], | |
sentiment= ["Positive Sentiment" if int(round(x)) == 1 else "Negative Sentiment" for x in slice_data["label"]], | |
sort_order=5, | |
embed_model = doc2vec, | |
idx = max(embedding.index)+1, | |
type="multi") | |
st.session_state["embedding"] = embedding.append(tmp) | |
return slice_name | |
def slice_vis(terms, sst_db, slice_name): | |
''' DEPRECIATED FUNCTION TO VISUALIZE SLICE DATA''' | |
st.write(terms) | |
# TO DO - FORMATTING AND ADD METRICS | |
if len(list(sst_db.slices)) > 2: | |
# write out the dataset for this subset | |
# get selected slice data | |
slice_ids = ut.get_sliceid(list(sst_db.slices)) | |
idx = [i for i, elem in enumerate(slice_ids) if slice_name in str(elem)] | |
if len(idx) > 1: | |
raise ValueError("More than one slice with the same name") | |
else: | |
idx = idx[0] | |
if idx is not None: | |
slice_data = list(sst_db.slices)[idx] | |
slice_id = str(slice_data._identifier) | |
# visualize performance | |
all_metrics = ut.metrics_to_dict(sst_db.metrics["model"], slice_id) | |
chart = ut.visualize_metrics(all_metrics) | |
st.altair_chart(chart) | |
# write slice data to UI | |
st.dataframe(ut.slice_to_df(slice_data)) | |
else: | |
st.write("No slice found") | |
# ***** EXAMPLE PANEL UI ******* | |
def example_panel(sentence_examples, model, sst_db,doc2vec): | |
""" Layout for the custom example panel""" | |
# Data Expander | |
''' | |
st.markdown( | |
"Here's an overview of the ways you can add customized the performance results. Using the drop down menu above, you can choose from one of three options" | |
) | |
st.markdown( | |
"1. **Define a new subpopulation** : Create a new subset from the model's training or testing data" | |
) | |
st.markdown("1. **Add your own sentences** : Add your own sentences as examples") | |
st.markdown( | |
"3. **Add your own dataset** : Upload your own (small) dataset from a csv file" | |
) | |
''' | |
st.markdown("Modify the quantitative analysis results by defining your own subpopulations in the data, including your own data by adding your own sentences or dataset.") | |
with st.expander("Explore new subpopulations in model data"): | |
# create slice | |
slice_terms = subpopulation_slice(sst_db,doc2vec) | |
# visualize slice | |
slice_name = st.session_state["custom_slice_name"] | |
with st.expander("Explore with your own sentences"): | |
# adding a column for user text input | |
example_sentence(sentence_examples, model,doc2vec) | |
# examples() | |
with st.expander("Explore with your own dataset"): | |
st.error("This feature is not enabled for the online deployment") | |
__all__=["example_panel"] | |