zmbfeng committed on
Commit
49fa3c4
1 Parent(s): 7cb50e9

get similarity scores refactored

Browse files
Files changed (1) hide show
  1. app.py +48 -43
app.py CHANGED
@@ -199,6 +199,51 @@ if uploaded_pdf_file is not None:
199
  st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
200
  st.rerun()
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  if 'paragraph_sentence_encodings' in st.session_state:
203
  query = st.text_input("Enter your query")
204
 
@@ -209,55 +254,15 @@ if 'paragraph_sentence_encodings' in st.session_state:
209
 
210
  query_encoding = encode_sentence(query)
211
  paragraph_scores = []
212
- sentence_scores = []
213
  total_count = len(st.session_state.paragraph_sentence_encodings)
214
  processing_progress_bar = st.progress(0)
215
 
216
- for index, paragraph_sentence_encoding in enumerate(st.session_state.paragraph_sentence_encodings):
217
- progress_percentage = index / (total_count - 1)
218
- processing_progress_bar.progress(progress_percentage)
219
-
220
- sentence_similarities = []
221
- for sentence_encoding in paragraph_sentence_encoding[1]:
222
- if sentence_encoding:
223
- similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
224
- combined_score, similarity_score, commonality_score = add_commonality_to_similarity_score(similarity,
225
- sentence_encoding[0],
226
- query)
227
- sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
228
- sentence_scores.append((combined_score, sentence_encoding[0]))
229
-
230
- sentence_similarities.sort(reverse=True, key=lambda x: x[0])
231
- # print(sentence_similarities)
232
- if len(sentence_similarities) >= 3:
233
- top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities[:3]])
234
- top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities[:3]])
235
- top_three_sentences = sentence_similarities[:3]
236
- elif sentence_similarities:
237
- top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities])
238
- top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities])
239
- top_three_sentences = sentence_similarities
240
- else:
241
- top_three_avg_similarity = 0
242
- top_three_avg_commonality = 0
243
- top_three_sentences = []
244
- # print(f"top_three_sentences={top_three_sentences}")
245
- # top_three_texts = [s[1] for s in top_three_sentences]
246
- # remaining_texts = [s[0] for s in paragraph_sentence_encoding[1] if s and s[0] not in top_three_texts]
247
- # reordered_paragraph = top_three_texts + remaining_texts
248
- #
249
- # original_paragraph = ' '.join([s[0] for s in paragraph_sentence_encoding[1] if s])
250
- # modified_paragraph = ' '.join(reordered_paragraph)
251
-
252
-
253
 
254
 
255
- paragraph_scores.append(
256
- (top_three_avg_similarity, top_three_avg_commonality,
257
- {'top_three_sentences': top_three_sentences, 'original_text': paragraph_sentence_encoding[0]})
258
- )
259
 
260
- sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
261
  st.session_state.paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
262
 
263
  if 'paragraph_scores' in st.session_state:
 
199
  st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
200
  st.rerun()
201
 
202
def find_sentences_scores(paragraph_sentence_encodings, query_encoding, processing_progress_bar,total_count):
    """Score every sentence of every paragraph against the query encoding.

    For each paragraph the per-sentence combined scores are computed, the
    (up to) three best sentences are averaged, and a summary tuple
    ``(avg_combined_score, avg_commonality, {...})`` is appended to the
    module-level ``paragraph_scores`` list.

    NOTE(review): this function still relies on two module-level names that
    are not parameters: ``query`` (the raw query text handed to
    ``add_commonality_to_similarity_score``) and ``paragraph_scores`` (the
    list the caller sorts afterwards). Both must exist before the call.

    Args:
        paragraph_sentence_encodings: Iterable of paragraph records. Item
            ``[0]`` is stored as ``'original_text'``; item ``[1]`` is an
            iterable of sentence records, each either falsy (skipped) or
            indexable with ``[0]`` = sentence text and ``[1]`` = the value
            passed to ``cosine_similarity``.
        query_encoding: Encoded query passed to ``cosine_similarity``.
        processing_progress_bar: Object exposing ``progress(fraction)``.
        total_count: Number of paragraphs, used to scale the progress bar.

    Returns:
        A list of ``(combined_score, sentence_text)`` tuples for all scored
        sentences, sorted by ``combined_score`` in descending order.
    """
    sentence_scores = []
    for index, paragraph_sentence_encoding in enumerate(paragraph_sentence_encodings):
        # A single paragraph would make the denominator zero; report the bar
        # as complete instead. The clamp keeps the fraction at or below 1.0
        # even if total_count understates the real number of paragraphs.
        if total_count > 1:
            progress_percentage = min(index / (total_count - 1), 1.0)
        else:
            progress_percentage = 1.0
        processing_progress_bar.progress(progress_percentage)

        sentence_similarities = []
        for sentence_encoding in paragraph_sentence_encoding[1]:
            if not sentence_encoding:
                continue
            similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
            combined_score, _, commonality_score = add_commonality_to_similarity_score(
                similarity, sentence_encoding[0], query)
            sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
            sentence_scores.append((combined_score, sentence_encoding[0]))

        # Best sentences first; paragraphs with fewer than three scored
        # sentences are averaged over whatever is available.
        sentence_similarities.sort(reverse=True, key=lambda x: x[0])
        top_three_sentences = sentence_similarities[:3]
        if top_three_sentences:
            top_three_avg_similarity = np.mean([s[0] for s in top_three_sentences])
            top_three_avg_commonality = np.mean([s[2] for s in top_three_sentences])
        else:
            top_three_avg_similarity = 0
            top_three_avg_commonality = 0

        paragraph_scores.append(
            (top_three_avg_similarity, top_three_avg_commonality,
             {'top_three_sentences': top_three_sentences, 'original_text': paragraph_sentence_encoding[0]})
        )

    # The caller assigns the result of this function, so it must be returned.
    return sorted(sentence_scores, key=lambda x: x[0], reverse=True)
246
+
247
  if 'paragraph_sentence_encodings' in st.session_state:
248
  query = st.text_input("Enter your query")
249
 
 
254
 
255
  query_encoding = encode_sentence(query)
256
  paragraph_scores = []
257
+
258
  total_count = len(st.session_state.paragraph_sentence_encodings)
259
  processing_progress_bar = st.progress(0)
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
 
263
+ sentence_scores = find_sentences_scores(
264
+ st.session_state.paragraph_sentence_encodings, query_encoding, processing_progress_bar,total_count)
 
 
265
 
 
266
  st.session_state.paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
267
 
268
  if 'paragraph_scores' in st.session_state: