Spaces:
Runtime error
Runtime error
import functools
import math
import os
import pickle

import streamlit as st
import numpy as np
import numpy.linalg as la
import gdown
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
#import streamlit_analytics
# Compute Cosine Similarity | |
def cosine_similarity(x, y):
    """
    Exponentiated cosine similarity: exp(x . y / (||x|| * ||y||)).

    Args:
        x, y: array-likes of equal length.

    Returns:
        A float in [exp(-1), exp(1)]. If either vector has zero norm the cosine
        is undefined and exp(-1) is returned, the same value as for opposed vectors.
    """
    x_arr = np.array(x)
    y_arr = np.array(y)
    x_norm = la.norm(x_arr)
    y_norm = la.norm(y_arr)
    if x_norm == 0 or y_norm == 0:
        return math.exp(-1)
    # Divide by the true norm product; clamping it to >= 1 (as before) turned the
    # result into a raw dot product whenever the vectors were short.
    return math.exp(np.dot(x_arr, y_arr) / (x_norm * y_norm))
# Function to Load Glove Embeddings | |
def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
    """Unpickle and return the object stored at ``glove_path``.

    The pickle is read with ``encoding="latin1"``, which lets pickles written by
    Python 2 (with numpy arrays inside) load under Python 3.
    """
    with open(glove_path, mode="rb") as pickled_file:
        return pickle.load(pickled_file, encoding="latin1")
# Google Drive file IDs per GloVe model size: (word_index_id, embeddings_id).
_GDRIVE_MODEL_IDS = {
    "25d": ("13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8", "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"),
    "50d": ("1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9", "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"),
    "100d": ("1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq", "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"),
}


def get_model_id_gdrive(model_type):
    """Return the Google Drive file IDs for a GloVe model size.

    Args:
        model_type: one of "25d", "50d", "100d".

    Returns:
        Tuple ``(word_index_id, embeddings_id)``.

    Raises:
        ValueError: if ``model_type`` is not a known model size (previously this
            surfaced as an opaque UnboundLocalError).
    """
    try:
        return _GDRIVE_MODEL_IDS[model_type]
    except KeyError:
        raise ValueError(
            "Unknown model_type %r; expected one of %s"
            % (model_type, ", ".join(_GDRIVE_MODEL_IDS))
        ) from None
def download_glove_embeddings_gdrive(model_type):
    """Download the GloVe word-index pickle and embeddings array for ``model_type``.

    Both files are fetched from Google Drive with gdown and written to the working
    directory as ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy``. These are the same names
    load_glove_embeddings_gdrive reads back.

    Raises:
        RuntimeError: if gdown reports a failed download by returning None. Without
            this check the failure only showed up later as a missing-file error.
    """
    word_index_id, embeddings_id = get_model_id_gdrive(model_type)
    embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
    word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"

    print("Downloading word index dictionary....\n")
    if gdown.download(id=word_index_id, output=word_index_temp, quiet=False) is None:
        raise RuntimeError("Failed to download word index dictionary for model " + str(model_type))

    print("Downloading embeddings...\n\n")
    if gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False) is None:
        raise RuntimeError("Failed to download embeddings for model " + str(model_type))
#@st.cache_data() | |
def load_glove_embeddings_gdrive(model_type):
    """Load the word-index dict and embeddings array stored on disk for ``model_type``.

    Reads ``word_index_dict_<model_type>_temp.pkl`` (pickle) and
    ``embeddings_<model_type>_temp.npy`` (numpy) from the working directory.

    Returns:
        Tuple ``(word_index_dict, embeddings)``.
    """
    word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
    embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
    # Use a context manager so the pickle file handle is closed (it used to leak).
    # NOTE(review): unpickling executes arbitrary code if the file was tampered with;
    # presumably these files only ever come from the app's own Drive downloads.
    with open(word_index_temp, "rb") as word_index_file:
        word_index_dict = pickle.load(word_index_file, encoding="latin")
    embeddings = np.load(embeddings_temp)
    return word_index_dict, embeddings
@functools.lru_cache(maxsize=4)
def load_sentence_transformer_model(model_name):
    """Return a SentenceTransformer for ``model_name``, memoized per model name.

    Callers invoke this once per category and once per query. Without the cache,
    every call constructed a new SentenceTransformer, reloading the model weights.
    NOTE(review): Streamlit re-executes the script on each rerun, which redefines
    this function and empties the cache. ``st.cache_resource`` would persist the model
    across reruns if the deployed Streamlit version supports it.
    """
    return SentenceTransformer(model_name)
def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
    """Encode ``sentence`` with the sentence-transformer ``model_name``.

    The default model (https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
    produces 384-dimensional embeddings.

    Returns:
        The model's encoding of ``sentence``. If encoding raises, the error is logged
        and a zero vector is returned: 384 long for the default model, otherwise 512
        (assumed dimension for other models; TODO confirm per model).
    """
    sentenceTransformer = load_sentence_transformer_model(model_name)
    try:
        return sentenceTransformer.encode(sentence)
    except Exception as err:  # was a bare except, which also swallowed KeyboardInterrupt
        print("Sentence transformer encoding failed:", err)
        # Fall back to a zero vector so downstream cosine similarity still runs.
        fallback_dim = 384 if model_name == "all-MiniLM-L6-v2" else 512
        return np.zeros(fallback_dim)
def get_result_from_gpt(sentence, gpt_model="3.5"):
    """Placeholder for a GPT-backed result; not implemented.

    Both arguments are currently ignored and the function always returns None.
    """
    # GPT authentication / request logic is intended to go here.
    return None
def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
    """Look up the GloVe vector for a single word (case-insensitive).

    Args:
        word: the word; it is lower-cased before lookup.
        word_index_dict: maps a word to a row index of ``embeddings``.
        embeddings: array indexed by the values of ``word_index_dict``.
        model_type: string such as "50d"; the text before the first "d" is the
            vector length used for the out-of-vocabulary fallback.

    Returns:
        ``embeddings[word_index_dict[word.lower()]]``, or a zero vector when the word
        is not in the vocabulary.
    """
    lookup_key = word.lower()
    if lookup_key not in word_index_dict:
        vector_length = int(model_type.partition("d")[0])
        return np.zeros(vector_length)
    return embeddings[word_index_dict[lookup_key]]
# Get Averaged Glove Embedding of a sentence | |
def averaged_glove_embeddings(sentence, embeddings_dict, dim=None):
    """Average the GloVe vectors of the words in ``sentence``.

    Args:
        sentence: text split on single spaces; each token is lower-cased before lookup.
        embeddings_dict: maps a word directly to its vector.
        dim: vector length. When None, it is inferred from the first vector in
            ``embeddings_dict``, falling back to 50 if the dict is empty. The length
            used to be hard-coded to 50, which broke 25d/100d embeddings.

    Returns:
        The mean of the matched vectors, or a zero vector of length ``dim`` if no
        word matched.
    """
    if dim is None:
        first_vector = next(iter(embeddings_dict.values()), None)
        dim = len(first_vector) if first_vector is not None else 50
    glove_embedding = np.zeros(dim)
    count_words = 0
    for word in sentence.split(" "):
        key = word.lower()
        if key in embeddings_dict:
            glove_embedding += embeddings_dict[key]
            count_words += 1
    # max(..., 1) avoids division by zero when nothing matched.
    return glove_embedding / max(count_words, 1)
def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
    """Average the GloVe vectors of the words in ``sentence``.

    Args:
        sentence: text split on single spaces. Each token is lower-cased before
            lookup, matching get_glove_embeddings, which lower-cases category words.
        word_index_dict: maps a word to a row index of ``embeddings``.
        embeddings: array indexed by the values of ``word_index_dict``.
        model_type: "25d"/"50d"/... or a bare int. The leading digits give the vector
            length. The int default previously crashed on ``.split``.

    Returns:
        The mean of the matched vectors, or a zero vector if no word matched.
    """
    embedding = np.zeros(int(str(model_type).split("d")[0]))
    count_words = 0
    for word in sentence.split(" "):
        key = word.lower()
        if key in word_index_dict:
            embedding += embeddings[word_index_dict[key]]
            count_words += 1
    # max(..., 1) avoids division by zero when nothing matched.
    return embedding / max(count_words, 1)
def get_category_embeddings(embeddings_metadata):
    """Ensure every current category has a cached sentence-transformer embedding.

    Embeddings are stored in ``st.session_state["cat_embed_" + model_name]`` as a dict
    mapping category -> embedding. Categories come from the space-separated
    ``st.session_state.categories``. An empty ``model_name`` uses the default model of
    get_sentence_transformer_embeddings.

    Already-cached categories are kept and only missing ones are encoded. Previously
    the dict was reset on every call, so every category was re-encoded.
    """
    model_name = embeddings_metadata["model_name"]
    cache_key = "cat_embed_" + model_name
    if cache_key not in st.session_state:
        st.session_state[cache_key] = {}
    category_cache = st.session_state[cache_key]
    for category in st.session_state.categories.split(" "):
        if category in category_cache:
            continue
        if model_name:
            category_cache[category] = get_sentence_transformer_embeddings(category, model_name=model_name)
        else:
            category_cache[category] = get_sentence_transformer_embeddings(category)
def update_category_embeddings(embedings_metadata):
    """Refresh the cached category embeddings for the given embedding metadata.

    The parameter keeps its original (misspelled) name so keyword callers still work.
    The body previously read ``embeddings_metadata``, a different name that silently
    resolved to whatever module-level global of that name existed.
    """
    get_category_embeddings(embedings_metadata)
def get_sorted_cosine_similarity(input_sentence, embeddings_metadata):
    """Score ``input_sentence`` against each category and sort by score.

    Categories are the space-separated words of ``st.session_state.categories``.

    Args:
        input_sentence: the text to classify. This parameter is now used; previously
            the function read ``st.session_state.text_search`` instead.
        embeddings_metadata: dict. If ``embedding_model == "glove"`` it must provide
            ``word_index_dict``, ``embeddings`` and ``model_type``. Otherwise it must
            provide ``model_name`` for a sentence transformer; an empty string selects
            the default model.

    Returns:
        List of ``(category_index, score)`` pairs sorted by score, highest first, where
        score is cosine_similarity (exponentiated cosine).
    """
    categories = st.session_state.categories.split(" ")
    cosine_sim = {}
    if embeddings_metadata["embedding_model"] == "glove":
        word_index_dict = embeddings_metadata["word_index_dict"]
        embeddings = embeddings_metadata["embeddings"]
        model_type = embeddings_metadata["model_type"]
        input_embedding = averaged_glove_embeddings_gdrive(input_sentence, word_index_dict, embeddings, model_type)
        for index, category in enumerate(categories):
            category_embedding = get_glove_embeddings(category, word_index_dict, embeddings, model_type)
            cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
    else:
        model_name = embeddings_metadata["model_name"]
        cache_key = "cat_embed_" + model_name
        if cache_key not in st.session_state:
            get_category_embeddings(embeddings_metadata)
        category_embeddings = st.session_state[cache_key]
        print("text_search = ", input_sentence)
        if model_name:
            input_embedding = get_sentence_transformer_embeddings(input_sentence, model_name=model_name)
        else:
            input_embedding = get_sentence_transformer_embeddings(input_sentence)
        for index, category in enumerate(categories):
            # The categories text may have changed since the cache was built.
            if category not in category_embeddings:
                update_category_embeddings(embeddings_metadata)
                category_embeddings = st.session_state[cache_key]
            cosine_sim[index] = cosine_similarity(input_embedding, category_embeddings[category])
    return sorted(cosine_sim.items(), key=lambda item: item[1], reverse=True)
def plot_piechart(sorted_cosine_scores_items):
    """Render a pie chart of per-category scores with ``st.pyplot``.

    Args:
        sorted_cosine_scores_items: sequence of ``(category_index, score)`` pairs.
            ``category_index`` indexes the space-separated ``st.session_state.categories``.
    """
    sorted_cosine_scores = np.array([score for _, score in sorted_cosine_scores_items])
    categories = st.session_state.categories.split(" ")
    categories_sorted = [categories[index] for index, _ in sorted_cosine_scores_items]
    fig, ax = plt.subplots()
    ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct='%1.1f%%')
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; Streamlit reruns would pile them up.
    plt.close(fig)
def plot_piechart_helper(sorted_cosine_scores_items):
    """Build and return a 3x3-inch pie chart of per-category scores (not displayed).

    Args:
        sorted_cosine_scores_items: sequence of ``(category_index, score)`` pairs,
            expected highest score first; ``category_index`` indexes the
            space-separated ``st.session_state.categories``.

    Returns:
        The matplotlib Figure. Rendering it is left to the caller.
    """
    scores = np.array([item[1] for item in sorted_cosine_scores_items])
    names = st.session_state.categories.split(" ")
    labels = [names[item[0]] for item in sorted_cosine_scores_items]
    figure, axis = plt.subplots(figsize=(3, 3))
    slice_count = len(labels)
    offsets = np.zeros(slice_count)
    offsets[0] = 0.2  # pull the first (top-scoring) slice out the furthest
    if slice_count == 3:
        offsets[1] = 0.1
    elif slice_count > 3:
        # NOTE(review): with more than three slices only index 2 is offset and
        # index 1 stays flush; confirm this is intended.
        offsets[2] = 0.05
    axis.pie(scores, labels=labels, autopct='%1.1f%%', explode=offsets)
    return figure
def plot_piecharts(sorted_cosine_scores_models):
    """Render one vertically stacked pie chart per model with ``st.pyplot``.

    Args:
        sorted_cosine_scores_models: dict mapping model name to a sequence of
            ``(category_index, score)`` pairs. Charts follow the dict's iteration order.
            Any number of models is supported; previously only exactly two were drawn.
            Nothing is drawn for an empty dict.
    """
    categories = st.session_state.categories.split(" ")
    scores_list = list(sorted_cosine_scores_models.values())
    if not scores_list:
        return
    # squeeze=False always yields a 2-D array of Axes, even for a single model.
    fig, axes = plt.subplots(len(scores_list), squeeze=False)
    for ax, sorted_items in zip(axes[:, 0], scores_list):
        categories_sorted = [categories[index] for index, _ in sorted_items]
        sorted_scores = np.array([score for _, score in sorted_items])
        ax.pie(sorted_scores, labels=categories_sorted, autopct='%1.1f%%')
    st.pyplot(fig)
    plt.close(fig)
def plot_alatirchart(sorted_cosine_scores_models):
    """Show one pie chart per model, each in its own Streamlit tab.

    Args:
        sorted_cosine_scores_models: dict mapping a model label (used as the tab title)
            to its ``(category_index, score)`` pairs.
    """
    models = list(sorted_cosine_scores_models.keys())
    tabs = st.tabs(models)
    figs = {model: plot_piechart_helper(sorted_cosine_scores_models[model]) for model in models}
    for model, tab in zip(models, tabs):
        with tab:
            st.pyplot(figs[model])
        # Release the pyplot figure once rendered so reruns do not accumulate figures.
        plt.close(figs[model])
# ---------------------
# Text search demo: Streamlit script body (re-executed on every interaction)
# ---------------------
st.sidebar.title('GloVe Twitter')
st.sidebar.markdown("""
GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
""")
model_type = st.sidebar.selectbox(
    'Choose the model',
    ('25d', '50d'),
    index=1
)

st.title("Search Based Retrieval Demo")
st.subheader("Pass in space separated categories you want this search demo to be about.")
# Widget values are read back via st.session_state["categories"] / ["text_search"].
st.text_input(label="Categories", key="categories", value="Flowers Colors Cars Weather Food")
st.subheader("Pass in an input word or even a sentence")
st.text_input(label="Input your sentence", key="text_search", value="Roses are red, trucks are blue, and Seattle is grey right now")

# Download GloVe embeddings only if they are not already on local disk.
embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
    print("Model type = ", model_type)
    with st.spinner("Downloading glove embeddings..."):
        download_glove_embeddings_gdrive(model_type)

word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)

# Rank the categories for the input sentence under each embedding model.
if st.session_state.text_search:
    print("Glove Embedding")
    embeddings_metadata = {"embedding_model": "glove", "word_index_dict": word_index_dict, "embeddings": embeddings, "model_type": model_type}
    with st.spinner("Obtaining Cosine similarity for Glove..."):
        sorted_cosine_sim_glove = get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata)

    print("Sentence Transformer Embedding")
    # Empty model_name selects the default (384-d) sentence-transformer model.
    embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
    with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
        sorted_cosine_sim_transformer = get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata)

    print("Categories are: ", st.session_state.categories)
    st.subheader("Closest word I have between: " + st.session_state.categories + " as per different Embeddings")
    print(sorted_cosine_sim_glove)
    print(sorted_cosine_sim_transformer)

    # One pie-chart tab per model.
    plot_alatirchart({"glove_" + str(model_type): sorted_cosine_sim_glove,
                      "sentence_transformer_384": sorted_cosine_sim_transformer})

st.write("")
st.write("Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)")