Shreemit committed on
Commit 6beecf5
1 Parent(s): 26e374b

Uploaded files

Data/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b6ab0cf2f12332d3208f00bc0c3964374e2a3aadb22bf251005f5d0e05674ba
+ size 133
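The three added lines are a Git LFS pointer, not the pickle payload itself: the pointer records only the spec version, the object's sha256, and its size, and the real file is fetched through Git LFS. A small sketch (not part of the commit) of reading those fields:

```python
# Parse the Git LFS pointer fields (version, oid, size) from Data/embeddings.pkl.
# This only inspects the pointer text; fetching the actual object still goes
# through Git LFS (e.g. `git lfs pull`).
def read_lfs_pointer(path="Data/embeddings.pkl"):
    fields = {}
    with open(path, "r") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

pointer = read_lfs_pointer()
print(pointer["oid"], pointer["size"])  # e.g. "sha256:1b6a..." and "133"
```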
embeddings_demo.py ADDED
@@ -0,0 +1,359 @@
+ import streamlit as st
+ import numpy as np
+ import numpy.linalg as la
+ import pickle
+ import os
+ import gdown
+ from sentence_transformers import SentenceTransformer
+ import matplotlib.pyplot as plt
+ import math
+ #import streamlit_analytics
+
+
+ # Compute Cosine Similarity
+ def cosine_similarity(x, y):
+     """
+     Exponentiated cosine similarity: exp(x.y / max(||x||*||y||, 1)).
+     Zero vectors map to exp(-1) so unknown words still get a score.
+     """
+
+     x_arr = np.array(x)
+     y_arr = np.array(y)
+     if la.norm(x_arr) == 0 or la.norm(y_arr) == 0:
+         return math.exp(-1)
+     else:
+         return math.exp(np.dot(x_arr, y_arr) / (max(la.norm(x_arr) * la.norm(y_arr), 1)))
+
+
+ # Function to Load Glove Embeddings
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
+
+     with open(glove_path, "rb") as f:
+         embeddings_dict = pickle.load(f, encoding="latin1")
+
+     return embeddings_dict
+
+
+ def get_model_id_gdrive(model_type):
+
+     if model_type == "25d":
+         word_index_id = "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8"
+         embeddings_id = "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"
+     elif model_type == "50d":
+         embeddings_id = "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"
+         word_index_id = "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9"
+     elif model_type == "100d":
+         word_index_id = "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq"
+         embeddings_id = "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"
+
+
+     return word_index_id, embeddings_id
+
+
+
+ def download_glove_embeddings_gdrive(model_type):
+     # Get glove embeddings from google drive
+
+     word_index_id, embeddings_id = get_model_id_gdrive(model_type)
+
+     # Use gdown to get files from google drive
+     embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
+     word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
+
+     # Download word_index pickle file
+     print("Downloading word index dictionary....\n")
+     gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
+
+     # Download embeddings numpy file
+     print("Downloading embeddings...\n\n")
+     gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
+
+ #@st.cache_data()
+ def load_glove_embeddings_gdrive(model_type):
+
+     word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
+     embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
+
+     # Load word index dictionary
+     word_index_dict = pickle.load(open(word_index_temp, "rb"), encoding="latin")
+
+     # Load embeddings numpy
+     embeddings = np.load(embeddings_temp)
+
+     return word_index_dict, embeddings
+
+ @st.cache_resource()
+ def load_sentence_transformer_model(model_name):
+
+     sentenceTransformer = SentenceTransformer(model_name)
+     return sentenceTransformer
+
+
+ def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
+
+     # 384 dimensional embedding
+     # Default model: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
+
+     sentenceTransformer = load_sentence_transformer_model(model_name)
+
+     try:
+         return sentenceTransformer.encode(sentence)
+     except Exception:
+         # Fall back to a zero vector if encoding fails
+         if model_name == "all-MiniLM-L6-v2":
+             return np.zeros(384)
+         else:
+             return np.zeros(512)
+
+ def get_result_from_gpt(sentence, gpt_model="3.5"):
+
+     ### GPT Authentication ###
+
+     pass
+
+     ###
+
+ def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
+     """
+     Get glove embedding for a single word
+     """
+
+     if word.lower() in word_index_dict:
+         return embeddings[word_index_dict[word.lower()]]
+     else:
+         return np.zeros(int(model_type.split("d")[0]))
+
+
+
+ # Get Averaged Glove Embedding of a sentence
+ def averaged_glove_embeddings(sentence, embeddings_dict):
+     words = sentence.split(" ")
+     glove_embedding = np.zeros(50)
+     count_words = 0
+     for word in words:
+         word = word.lower()
+         if word in embeddings_dict:
+             glove_embedding += embeddings_dict[word]
+             count_words += 1
+
+     return glove_embedding / max(count_words, 1)
+
+
+ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type="50d"):
+     words = sentence.split(" ")
+     embedding = np.zeros(int(model_type.split("d")[0]))
+     count_words = 0
+     for word in words:
+         if word in word_index_dict:
+             embedding += embeddings[word_index_dict[word]]
+             count_words += 1
+
+     return embedding / max(count_words, 1)
+
+ def get_category_embeddings(embeddings_metadata):
+     model_name = embeddings_metadata["model_name"]
+     st.session_state["cat_embed_" + model_name] = {}
+     for category in st.session_state.categories.split(" "):
+         if model_name:
+             if not category in st.session_state["cat_embed_" + model_name]:
+                 st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category, model_name=model_name)
+         else:
+             if not category in st.session_state["cat_embed_" + model_name]:
+                 st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category)
+
+
+ def update_category_embeddings(embeddings_metadata):
+
+     get_category_embeddings(embeddings_metadata)
+
+
+ def get_sorted_cosine_similarity(input_sentence, embeddings_metadata):
+
+     categories = st.session_state.categories.split(" ")
+     cosine_sim = {}
+     if embeddings_metadata["embedding_model"] == "glove":
+         word_index_dict = embeddings_metadata["word_index_dict"]
+         embeddings = embeddings_metadata["embeddings"]
+         model_type = embeddings_metadata["model_type"]
+
+         input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search, word_index_dict, embeddings, model_type)
+
+         for index in range(len(categories)):
+             cosine_sim[index] = cosine_similarity(input_embedding, get_glove_embeddings(categories[index], word_index_dict, embeddings, model_type))
+     else:
+         model_name = embeddings_metadata["model_name"]
+         if not "cat_embed_" + model_name in st.session_state:
+             get_category_embeddings(embeddings_metadata)
+
+         category_embeddings = st.session_state["cat_embed_" + model_name]
+
+         print("text_search = ", st.session_state.text_search)
+         if model_name:
+             input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
+         else:
+             input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
+         for index in range(len(categories)):
+             #cosine_sim[index] = cosine_similarity(input_embedding, get_sentence_transformer_embeddings(categories[index], model_name=model_name))
+
+             # Update category embeddings if category not found
+             if not categories[index] in category_embeddings:
+                 update_category_embeddings(embeddings_metadata)
+                 category_embeddings = st.session_state["cat_embed_" + model_name]
+             cosine_sim[index] = cosine_similarity(input_embedding, category_embeddings[categories[index]])
+
+
+
+
+     sorted_cosine_sim = sorted(cosine_sim.items(), key=lambda x: x[1], reverse=True)
+
+     return sorted_cosine_sim
+
+
+ def plot_piechart(sorted_cosine_scores_items):
+     sorted_cosine_scores = np.array([sorted_cosine_scores_items[index][1] for index in range(len(sorted_cosine_scores_items))])
+     categories = st.session_state.categories.split(" ")
+     categories_sorted = [categories[sorted_cosine_scores_items[index][0]] for index in range(len(sorted_cosine_scores_items))]
+     fig, ax = plt.subplots()
+     ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct='%1.1f%%')
+     st.pyplot(fig)  # Figure
+
+ def plot_piechart_helper(sorted_cosine_scores_items):
+     sorted_cosine_scores = np.array([sorted_cosine_scores_items[index][1] for index in range(len(sorted_cosine_scores_items))])
+     categories = st.session_state.categories.split(" ")
+     categories_sorted = [categories[sorted_cosine_scores_items[index][0]] for index in range(len(sorted_cosine_scores_items))]
+     fig, ax = plt.subplots(figsize=(3, 3))
+     my_explode = np.zeros(len(categories_sorted))
+     my_explode[0] = 0.2
+     if len(categories_sorted) == 3:
+         my_explode[1] = 0.1  # explode the second slice slightly
+     elif len(categories_sorted) > 3:
+         my_explode[2] = 0.05
+     ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct='%1.1f%%', explode=my_explode)
+
+     return fig
+
+ def plot_piecharts(sorted_cosine_scores_models):
+
+     scores_list = []
+     categories = st.session_state.categories.split(" ")
+     index = 0
+     for model in sorted_cosine_scores_models:
+         scores_list.append(sorted_cosine_scores_models[model])
+         #scores_list[index] = np.array([scores_list[index][ind2][1] for ind2 in range(len(scores_list[index]))])
+         index += 1
+
+     if len(sorted_cosine_scores_models) == 2:
+         fig, (ax1, ax2) = plt.subplots(2)
+
+         categories_sorted = [categories[scores_list[0][index][0]] for index in range(len(scores_list[0]))]
+         sorted_scores = np.array([scores_list[0][index][1] for index in range(len(scores_list[0]))])
+         ax1.pie(sorted_scores, labels=categories_sorted, autopct='%1.1f%%')
+
+         categories_sorted = [categories[scores_list[1][index][0]] for index in range(len(scores_list[1]))]
+         sorted_scores = np.array([scores_list[1][index][1] for index in range(len(scores_list[1]))])
+         ax2.pie(sorted_scores, labels=categories_sorted, autopct='%1.1f%%')
+
+     st.pyplot(fig)
+
+ def plot_alatirchart(sorted_cosine_scores_models):
+
+
+     models = list(sorted_cosine_scores_models.keys())
+     tabs = st.tabs(models)
+     figs = {}
+     for model in models:
+         figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
+
+     for index in range(len(tabs)):
+         with tabs[index]:
+             st.pyplot(figs[models[index]])
+
+
+
+ # Text Search
+ #with streamlit_analytics.track():
+
+ # ---------------------
+ # Common part
+ # ---------------------
+ st.sidebar.title('GloVe Twitter')
+ st.sidebar.markdown("""
+ GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
+ 2 billion tweets with a vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
+
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
+ """)
+
+ model_type = st.sidebar.selectbox(
+     'Choose the model',
+     ('25d', '50d'),
+     index=1
+ )
+
+
+
+ st.title("Search Based Retrieval Demo")
+ st.subheader("Pass in space-separated categories you want this search demo to be about.")
+ #st.selectbox(label="Pick the categories you want this search demo to be about...",
+ #             options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
+ #             key="categories"
+ #             )
+ st.text_input(label="Categories", key="categories", value="Flowers Colors Cars Weather Food")
+ print(st.session_state["categories"])
+ print(type(st.session_state["categories"]))
+ #print("Categories = ", categories)
+ #st.session_state.categories = categories
+
+ st.subheader("Pass in an input word or even a sentence")
+ text_search = st.text_input(label="Input your sentence", key="text_search", value="Roses are red, trucks are blue, and Seattle is grey right now")
+ #st.session_state.text_search = text_search
+
+ # Download glove embeddings if they don't already exist locally
+ embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
+ word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
+ if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
+     print("Model type = ", model_type)
+     glove_path = "Data/glove_" + str(model_type) + ".pkl"
+     print("glove_path = ", glove_path)
+
+     # Download embeddings from google drive
+     with st.spinner("Downloading glove embeddings..."):
+         download_glove_embeddings_gdrive(model_type)
+
+
+ # Load glove embeddings
+ word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)
+
+
+
+
+ # Find closest word to an input word
+ if st.session_state.text_search:
+
+     # Glove embeddings
+     print("Glove Embedding")
+     embeddings_metadata = {"embedding_model": "glove", "word_index_dict": word_index_dict, "embeddings": embeddings, "model_type": model_type}
+     with st.spinner("Obtaining Cosine similarity for Glove..."):
+         sorted_cosine_sim_glove = get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata)
+
+
+     # Sentence transformer embeddings
+     print("Sentence Transformer Embedding")
+     embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
+     with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
+         sorted_cosine_sim_transformer = get_sorted_cosine_similarity(st.session_state.text_search, embeddings_metadata)
+
+
+     # Results and Plot Pie Chart for Glove
+     print("Categories are: ", st.session_state.categories)
+     st.subheader("Closest word I have between: " + st.session_state.categories + " as per different Embeddings")
+
+     print(sorted_cosine_sim_glove)
+     print(sorted_cosine_sim_transformer)
+     #print(sorted_distilbert)
+     # Tabbed pie charts, one per embedding model
+     plot_alatirchart({"glove_" + str(model_type): sorted_cosine_sim_glove,
+                       "sentence_transformer_384": sorted_cosine_sim_transformer})
+                       #"distilbert_512": sorted_distilbert})
+
+     st.write("")
+     st.write("Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)")
+
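For readers skimming the diff, here is a minimal, self-contained sketch of the ranking logic embeddings_demo.py implements: average the word vectors of the query, score it against each category with the exponentiated cosine similarity above, and sort. The 2-d vectors below are made up for illustration; the app itself uses the downloaded 25d/50d GloVe vectors and sentence-transformer embeddings.

```python
import math
import numpy as np
import numpy.linalg as la

def cosine_similarity(x, y):
    # Same exponentiated cosine similarity as in embeddings_demo.py
    x_arr, y_arr = np.array(x), np.array(y)
    if la.norm(x_arr) == 0 or la.norm(y_arr) == 0:
        return math.exp(-1)
    return math.exp(np.dot(x_arr, y_arr) / max(la.norm(x_arr) * la.norm(y_arr), 1))

# Toy 2-d "embeddings" standing in for the real GloVe vectors
toy_embeddings = {
    "roses":   np.array([0.9, 0.1]),
    "flowers": np.array([0.8, 0.2]),
    "cars":    np.array([0.1, 0.9]),
}

def averaged_embedding(sentence):
    # Average the vectors of known words, as averaged_glove_embeddings_gdrive does
    known = [w for w in sentence.lower().split(" ") if w in toy_embeddings]
    if not known:
        return np.zeros(2)
    return sum(toy_embeddings[w] for w in known) / len(known)

query = averaged_embedding("roses are red")
categories = ["flowers", "cars"]
scores = {i: cosine_similarity(query, toy_embeddings[c]) for i, c in enumerate(categories)}
ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
print([(categories[i], round(s, 3)) for i, s in ranked])  # "flowers" ranks first
```

The same sort-by-score step is what get_sorted_cosine_similarity returns to the pie-chart plotting code.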
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gdown==4.7.1
+ sentence_transformers
+ matplotlib
+ click<=8.0.4
test_search_bar.py ADDED
@@ -0,0 +1,64 @@
+ import streamlit as st
+ import numpy as np
+ import numpy.linalg as la
+ import pickle
+ #import streamlit_analytics
+
+
+ # Compute Cosine Similarity
+ def cosine_similarity(x, y):
+
+     x_arr = np.array(x)
+     y_arr = np.array(y)
+     return np.dot(x_arr, y_arr) / (la.norm(x_arr) * la.norm(y_arr))
+
+
+ # Function to Load Glove Embeddings
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
+
+     with open(glove_path, "rb") as f:
+         embeddings_dict = pickle.load(f)
+
+     return embeddings_dict
+
+ # Get Averaged Glove Embedding of a sentence
+ def averaged_glove_embeddings(sentence, embeddings_dict):
+     words = sentence.split(" ")
+     glove_embedding = np.zeros(50)
+     count_words = 0
+     for word in words:
+         if word in embeddings_dict:
+             glove_embedding += embeddings_dict[word]
+             count_words += 1
+
+     return glove_embedding / max(count_words, 1)
+
+ # Load glove embeddings
+ glove_embeddings = load_glove_embeddings()
+
+ # Gold standard words to search from
+ gold_words = ["flower", "mountain", "tree", "car", "building"]
+
+ # Text Search
+ #with streamlit_analytics.track():
+ st.title("Search Based Retrieval Demo")
+ st.subheader("Pass in an input word or even a sentence (e.g. jasmine or mount adams)")
+ text_search = st.text_input("", value="")
+
+
+ # Find closest word to an input word
+ if text_search:
+     input_embedding = averaged_glove_embeddings(text_search, glove_embeddings)
+     cosine_sim = {}
+     for index in range(len(gold_words)):
+         cosine_sim[index] = cosine_similarity(input_embedding, glove_embeddings[gold_words[index]])
+
+     print(cosine_sim)
+     sorted_cosine_sim = sorted(cosine_sim.items(), key=lambda x: x[1], reverse=True)
+
+     st.write("(My search uses glove embeddings)")
+     st.write("Closest word I have between flower, mountain, tree, car and building for your input is: ")
+     st.subheader(gold_words[sorted_cosine_sim[0][0]])
+     st.write("")
+     st.write("Demo developed by Dr. Karthik Mohan")
+
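Because Data/embeddings.pkl is committed only as a Git LFS pointer, test_search_bar.py has no usable vectors until the real pickle is fetched. A hedged sketch, not part of the commit, that writes a placeholder file so the script can at least start locally; the dict-of-50-d-vectors layout is inferred from the lookups above, and the random vectors are stand-ins, not real GloVe embeddings:

```python
# Not part of the commit: fabricate a stand-in Data/embeddings.pkl so
# test_search_bar.py runs without the LFS-tracked file. The real pickle is
# assumed to map lowercase words to 50-d vectors, which is what
# averaged_glove_embeddings() and glove_embeddings[gold_words[i]] expect.
import os
import pickle
import numpy as np

words = ["flower", "mountain", "tree", "car", "building", "jasmine", "mount", "adams"]
rng = np.random.default_rng(0)
dummy = {w: rng.normal(size=50) for w in words}  # random stand-ins, not real GloVe vectors

os.makedirs("Data", exist_ok=True)
with open("Data/embeddings.pkl", "wb") as f:
    pickle.dump(dummy, f)
```

With that file in place, `streamlit run test_search_bar.py` launches the app, though the rankings are meaningless until the pointer is resolved to the real GloVe pickle.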