# search-demo / embeddings_demo.py
# (File-viewer page residue kept for reference: "Shreemit's picture", "Uploaded files",
#  6beecf5, raw, history blame, 13.5 kB)
import streamlit as st
import numpy as np
import numpy.linalg as la
import pickle
import os
import gdown
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
import math
#import streamlit_analytics
# Compute Cosine Similarity
def cosine_similarity(x, y):
    """
    Exponentiated cosine similarity.

    Returns ``exp(dot(x, y) / max(||x|| * ||y||, 1))``. The denominator is
    clamped to at least 1, so the exponent is the true cosine only when the
    product of the two norms is >= 1. When either vector has zero norm the
    result is ``exp(-1)``, i.e. the value obtained for a cosine of -1.
    """
    vec_x = np.array(x)
    vec_y = np.array(y)
    norm_x = la.norm(vec_x)
    norm_y = la.norm(vec_y)

    # A zero vector has no direction, so fall back to the fixed score.
    if norm_x == 0 or norm_y == 0:
        return math.exp(-1)

    denominator = max(norm_x * norm_y, 1)
    return math.exp(np.dot(vec_x, vec_y) / denominator)
# Function to Load Glove Embeddings
def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
    """
    Read a pickled embeddings object from ``glove_path`` and return it.

    The file is unpickled with ``encoding="latin1"``.
    NOTE(review): pickle can execute arbitrary code while loading; only point
    this at files from a trusted source.
    """
    with open(glove_path, "rb") as pickle_file:
        return pickle.load(pickle_file, encoding="latin1")
def get_model_id_gdrive(model_type):
    """
    Return the Google Drive file ids for a GloVe model variant.

    Args:
        model_type: One of "25d", "50d" or "100d".

    Returns:
        Tuple ``(word_index_id, embeddings_id)``.

    Raises:
        ValueError: If ``model_type`` is not a supported variant. (Previously
            an unknown value fell through the if/elif chain and failed with
            an UnboundLocalError at the return statement.)
    """
    # model_type -> (word_index_id, embeddings_id)
    drive_ids = {
        "25d": (
            "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8",
            "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2",
        ),
        "50d": (
            "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9",
            "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ",
        ),
        "100d": (
            "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq",
            "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp",
        ),
    }
    if model_type not in drive_ids:
        supported = ", ".join(drive_ids)
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of: {supported}"
        )
    word_index_id, embeddings_id = drive_ids[model_type]
    return word_index_id, embeddings_id
def download_glove_embeddings_gdrive(model_type):
    """
    Download the word-index pickle and the embeddings array for ``model_type``
    from Google Drive into the current working directory.

    The files are written as ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy``.
    """
    # Resolve the Drive ids for the requested variant.
    word_index_id, embeddings_id = get_model_id_gdrive(model_type)

    # Local file names the downloads are saved under.
    word_index_temp = f"word_index_dict_{model_type}_temp.pkl"
    embeddings_temp = f"embeddings_{model_type}_temp.npy"

    # Word index pickle first, then the embeddings matrix.
    print("Downloading word index dictionary....\n")
    gdown.download(id=word_index_id, output=word_index_temp, quiet=False)

    print("Donwloading embedings...\n\n")
    gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
#@st.cache_data()
def load_glove_embeddings_gdrive(model_type):
    """
    Load the previously downloaded GloVe files for ``model_type``.

    Reads ``word_index_dict_<model_type>_temp.pkl`` and
    ``embeddings_<model_type>_temp.npy`` from the current working directory.

    Returns:
        Tuple ``(word_index_dict, embeddings)``: the unpickled word index
        object and the array loaded with ``np.load``.

    NOTE(review): the word index is loaded with pickle, which can execute
    arbitrary code; the file must come from a trusted source.
    """
    word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
    embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"

    # Load word index dictionary; the context manager closes the handle
    # (the previous bare open() left it to the garbage collector).
    with open(word_index_temp, "rb") as word_index_file:
        word_index_dict = pickle.load(word_index_file, encoding="latin")

    # Load embeddings numpy
    embeddings = np.load(embeddings_temp)

    return word_index_dict, embeddings
@st.cache_resource()
def load_sentence_transformer_model(model_name):
    """
    Build a ``SentenceTransformer`` for ``model_name``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    model object instead of constructing it again.
    """
    return SentenceTransformer(model_name)
def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
    """
    Encode ``sentence`` with a sentence-transformer model.

    Args:
        sentence: Text passed to the model's ``encode`` method.
        model_name: Model to load. The default produces a 384 dimensional
            embedding:
            https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

    Returns:
        The embedding returned by ``encode``. If encoding fails, a zero
        vector is returned instead: 384 entries for the default model,
        512 entries for any other model name.

    Errors raised while loading the model are not caught here.
    """
    sentenceTransformer = load_sentence_transformer_model(model_name)

    try:
        return sentenceTransformer.encode(sentence)
    except Exception:
        # Best-effort fallback kept from the original behaviour, but no longer
        # a bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        if model_name == "all-MiniLM-L6-v2":
            return np.zeros(384)
        return np.zeros(512)
def get_result_from_gpt(sentence, gpt_model="3.5"):
    """
    Placeholder for a GPT-backed lookup.

    Not implemented: both arguments are ignored and ``None`` is returned.
    """
    ### GPT Authentication ###
    return None
###
def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
    """
    Get the GloVe embedding for a single word.

    The word is lower-cased before the lookup. ``word_index_dict`` maps a
    word to its row in ``embeddings``. When the word is unknown, a zero
    vector is returned whose length is the number in front of the "d" in
    ``model_type`` (e.g. "50d" -> 50).
    """
    key = word.lower()
    if key not in word_index_dict:
        dimension = int(model_type.split("d")[0])
        return np.zeros(dimension)
    return embeddings[word_index_dict[key]]
# Get Averaged Glove Embedding of a sentence
def averaged_glove_embeddings(sentence, embeddings_dict):
    """
    Average the embeddings of the known words in ``sentence``.

    The sentence is split on single spaces and each token is lower-cased
    before it is looked up in ``embeddings_dict``. The accumulator is fixed
    at 50 entries. Tokens that are not in the dictionary are skipped; when
    no token matches, the zero vector is returned.
    """
    total = np.zeros(50)
    matched = 0
    for token in sentence.split(" "):
        key = token.lower()
        if key not in embeddings_dict:
            continue
        total += embeddings_dict[key]
        matched += 1
    # max(..., 1) avoids dividing by zero when nothing matched.
    return total / max(matched, 1)
def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
    """
    Average the embeddings of the known words in ``sentence``.

    Args:
        sentence: Text that is split on single spaces.
        word_index_dict: Maps a word to its row in ``embeddings``.
        embeddings: Indexable collection of embedding vectors.
        model_type: Embedding size, either as a string such as "50d" or as a
            plain integer such as 50. (The integer default used to fail on
            ``.split``; it is now accepted.)

    Returns:
        The mean vector of all matched words, or the zero vector when no
        word matched.
    """
    dimension = int(str(model_type).split("d")[0])
    total = np.zeros(dimension)
    matched = 0

    for token in sentence.split(" "):
        # Exact matches behave as before; otherwise retry lower-cased so the
        # lookup agrees with get_glove_embeddings / averaged_glove_embeddings,
        # which both lower-case their input.
        if token in word_index_dict:
            key = token
        elif token.lower() in word_index_dict:
            key = token.lower()
        else:
            continue
        total += embeddings[word_index_dict[key]]
        matched += 1

    # max(..., 1) avoids dividing by zero when nothing matched.
    return total / max(matched, 1)
def get_category_embeddings(embeddings_metadata):
    """
    (Re)build the per-model category embedding cache in the session state.

    The cache lives under ``"cat_embed_" + model_name`` and is reset to an
    empty dict on every call. Each space-separated entry of
    ``st.session_state.categories`` is then embedded once with the sentence
    transformer; an empty ``model_name`` means the encoder's default model.
    """
    model_name = embeddings_metadata["model_name"]
    cache_key = "cat_embed_" + model_name
    st.session_state[cache_key] = {}

    for category in st.session_state.categories.split(" "):
        # Repeated category names are embedded only once.
        if category in st.session_state[cache_key]:
            continue
        if model_name:
            vector = get_sentence_transformer_embeddings(category, model_name=model_name)
        else:
            vector = get_sentence_transformer_embeddings(category)
        st.session_state[cache_key][category] = vector
def update_category_embeddings(embedings_metadata):
    """
    Refresh the cached category embeddings for the given model metadata.

    Thin wrapper around ``get_category_embeddings``. The parameter keeps its
    original (misspelled) name so keyword callers continue to work; the body
    now forwards that parameter instead of a same-sounding module-level
    variable it used to pick up by accident.
    """
    get_category_embeddings(embedings_metadata)
def get_sorted_cosine_similarity(input_sentence, embeddings_metadata):
    """
    Score every category against ``input_sentence`` and sort by similarity.

    Args:
        input_sentence: Text to compare with the categories. (It used to be
            ignored in favour of ``st.session_state.text_search``; it is now
            the value that is embedded.)
        embeddings_metadata: Dict describing the embedding backend. With
            ``"embedding_model" == "glove"`` it must also carry
            ``"word_index_dict"``, ``"embeddings"`` and ``"model_type"``;
            otherwise it must carry ``"model_name"`` (empty string selects
            the sentence transformer's default model).

    Returns:
        List of ``(category_index, score)`` pairs sorted by descending
        score, where ``category_index`` is the position in
        ``st.session_state.categories.split(" ")``.
    """
    categories = st.session_state.categories.split(" ")
    cosine_sim = {}

    if embeddings_metadata["embedding_model"] == "glove":
        word_index_dict = embeddings_metadata["word_index_dict"]
        embeddings = embeddings_metadata["embeddings"]
        model_type = embeddings_metadata["model_type"]

        input_embedding = averaged_glove_embeddings_gdrive(
            input_sentence, word_index_dict, embeddings, model_type
        )

        for index, category in enumerate(categories):
            cosine_sim[index] = cosine_similarity(
                input_embedding,
                get_glove_embeddings(category, word_index_dict, embeddings, model_type),
            )
    else:
        model_name = embeddings_metadata["model_name"]
        cache_key = "cat_embed_" + model_name

        # Build the category cache the first time this model is used.
        if cache_key not in st.session_state:
            get_category_embeddings(embeddings_metadata)
        category_embeddings = st.session_state[cache_key]

        print("text_search = ", input_sentence)
        if model_name:
            input_embedding = get_sentence_transformer_embeddings(
                input_sentence, model_name=model_name
            )
        else:
            input_embedding = get_sentence_transformer_embeddings(input_sentence)

        for index, category in enumerate(categories):
            # Update category embeddings if category not found
            # (e.g. the category list changed since the cache was built).
            if category not in category_embeddings:
                update_category_embeddings(embeddings_metadata)
                category_embeddings = st.session_state[cache_key]
            cosine_sim[index] = cosine_similarity(
                input_embedding, category_embeddings[category]
            )

    sorted_cosine_sim = sorted(cosine_sim.items(), key=lambda x: x[1], reverse=True)
    return sorted_cosine_sim
def plot_piechart(sorted_cosine_scores_items):
    """
    Render one pie chart of the category scores in the Streamlit app.

    Args:
        sorted_cosine_scores_items: Sequence of ``(category_index, score)``
            pairs; ``category_index`` refers to the position in
            ``st.session_state.categories.split(" ")``.
    """
    categories = st.session_state.categories.split(" ")
    sorted_cosine_scores = np.array([score for _, score in sorted_cosine_scores_items])
    categories_sorted = [categories[cat_index] for cat_index, _ in sorted_cosine_scores_items]

    fig, ax = plt.subplots()
    ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct='%1.1f%%')

    st.pyplot(fig)  # Figure
    # pyplot keeps every figure it creates alive until it is closed; release
    # it so figures do not accumulate across Streamlit reruns.
    plt.close(fig)
def plot_piechart_helper(sorted_cosine_scores_items):
    """
    Build (but do not render) a 3x3 inch pie chart of the category scores.

    ``sorted_cosine_scores_items`` is a sequence of ``(category_index, score)``
    pairs. The first wedge is pulled out by 0.2. With exactly three wedges
    the second one is pulled out by 0.1; with more than three the third one
    is pulled out by 0.05. Returns the matplotlib figure.
    """
    scores = np.array([score for _, score in sorted_cosine_scores_items])
    categories = st.session_state.categories.split(" ")
    labels = [categories[cat_index] for cat_index, _ in sorted_cosine_scores_items]

    fig, ax = plt.subplots(figsize=(3, 3))

    wedge_count = len(labels)
    explode = np.zeros(wedge_count)
    explode[0] = 0.2
    if wedge_count == 3:
        explode[1] = 0.1
    elif wedge_count > 3:
        explode[2] = 0.05

    ax.pie(scores, labels=labels, autopct='%1.1f%%', explode=explode)
    return fig
def plot_piecharts(sorted_cosine_scores_models):
    """
    Render one pie chart per model, stacked vertically in a single figure.

    Args:
        sorted_cosine_scores_models: Mapping of model name to a sequence of
            ``(category_index, score)`` pairs; ``category_index`` refers to
            the position in ``st.session_state.categories.split(" ")``.

    Any number of models is supported (previously the figure only existed
    when exactly two models were passed). Nothing is rendered for an empty
    mapping.
    """
    scores_list = [sorted_cosine_scores_models[model] for model in sorted_cosine_scores_models]
    if not scores_list:
        return

    categories = st.session_state.categories.split(" ")

    # squeeze=False always yields a 2-D axes array, so one model works too.
    fig, axes = plt.subplots(len(scores_list), squeeze=False)
    for ax, model_scores in zip(axes[:, 0], scores_list):
        categories_sorted = [categories[cat_index] for cat_index, _ in model_scores]
        sorted_scores = np.array([score for _, score in model_scores])
        ax.pie(sorted_scores, labels=categories_sorted, autopct='%1.1f%%')

    st.pyplot(fig)
    # Release the figure so it does not accumulate across Streamlit reruns.
    plt.close(fig)
def plot_alatirchart(sorted_cosine_scores_models):
    """
    Show one pie chart per model, each in its own Streamlit tab.

    Args:
        sorted_cosine_scores_models: Mapping of model name (used as the tab
            label) to the ``(category_index, score)`` pairs handed to
            ``plot_piechart_helper``.
    """
    models = list(sorted_cosine_scores_models.keys())
    tabs = st.tabs(models)

    for tab, model in zip(tabs, models):
        fig = plot_piechart_helper(sorted_cosine_scores_models[model])
        with tab:
            st.pyplot(fig)
        # pyplot keeps figures alive until closed; release each one after it
        # has been rendered so reruns do not pile them up.
        plt.close(fig)
# Text Search
# (optional usage tracking: wrap the page in `with streamlit_analytics.track():`)

# ---------------------
# Common part
# ---------------------
# Sidebar: background on the GloVe Twitter vectors and the model picker.
st.sidebar.title('GloVe Twitter')
st.sidebar.markdown("""
GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
""")

model_type = st.sidebar.selectbox(
    'Choose the model',
    ('25d', '50d'),
    index=1,
)

# Main page: categories and the search text are stored in the session state
# under the widget keys "categories" and "text_search".
st.title("Search Based Retrieval Demo")
st.subheader("Pass in space separated categories you want this search demo to be about.")
# Alternative input kept for reference:
# st.selectbox(label="Pick the categories you want this search demo to be about...",
#     options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
#     key="categories"
# )
st.text_input(label="Categories", key="categories", value="Flowers Colors Cars Weather Food")
print(st.session_state["categories"])
print(type(st.session_state["categories"]))

st.subheader("Pass in an input word or even a sentence")
text_search = st.text_input(
    label="Input your sentence",
    key="text_search",
    value="Roses are red, trucks are blue, and Seattle is grey right now",
)

# Download glove embeddings if either local file is missing.
embeddings_path = f"embeddings_{model_type}_temp.npy"
word_index_dict_path = f"word_index_dict_{model_type}_temp.pkl"
files_present = os.path.isfile(embeddings_path) and os.path.isfile(word_index_dict_path)
if not files_present:
    print("Model type = ", model_type)
    glove_path = f"Data/glove_{model_type}.pkl"
    print("glove_path = ", glove_path)

    # Download embeddings from google drive
    with st.spinner("Downloading glove embeddings..."):
        download_glove_embeddings_gdrive(model_type)

# Load glove embeddings
word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)

# Find the closest category to the input text.
# NOTE(review): the source had lost its indentation; the trailing st.write
# calls are assumed to belong to this `if` — confirm against the original.
if st.session_state.text_search:
    # Glove embeddings
    print("Glove Embedding")
    embeddings_metadata = {
        "embedding_model": "glove",
        "word_index_dict": word_index_dict,
        "embeddings": embeddings,
        "model_type": model_type,
    }
    with st.spinner("Obtaining Cosine similarity for Glove..."):
        sorted_cosine_sim_glove = get_sorted_cosine_similarity(
            st.session_state.text_search, embeddings_metadata
        )

    # Sentence transformer embeddings
    print("Sentence Transformer Embedding")
    embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
    with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
        sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
            st.session_state.text_search, embeddings_metadata
        )

    # Results header and debug output
    print("Categories are: ", st.session_state.categories)
    st.subheader(
        f"Closest word I have between: {st.session_state.categories} as per different Embeddings"
    )
    print(sorted_cosine_sim_glove)
    print(sorted_cosine_sim_transformer)

    # One tab with a pie chart per model
    # (a "distilbert_512": sorted_distilbert entry was planned but is disabled).
    plot_alatirchart(
        {
            f"glove_{model_type}": sorted_cosine_sim_glove,
            "sentence_transformer_384": sorted_cosine_sim_transformer,
        }
    )

    st.write("")
    st.write("Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)")