zmbfeng committed on
Commit
c7fa677
1 Parent(s): 0612a6a

original paragraph info hidden until button pressed

Browse files
Files changed (1) hide show
  1. app.py +79 -69
app.py CHANGED
@@ -157,73 +157,83 @@ if 'paragraph_sentence_encodings' in st.session_state:
157
  query = st.text_input("Enter your query")
158
 
159
  if query:
160
- query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
161
- 'cuda')
162
- with torch.no_grad(): # Disable gradient calculation for inference
163
- query_encoding = st.session_state.bert_model(**query_tokens).last_hidden_state[:, 0,
164
- :].cpu().numpy() # Move the result to CPU and convert to NumPy
165
-
166
- paragraph_scores = []
167
- sentence_scores = []
168
- total_count = len(st.session_state.paragraph_sentence_encodings)
169
- processing_progress_bar = st.progress(0)
170
-
171
- for index, paragraph_sentence_encoding in enumerate(st.session_state.paragraph_sentence_encodings):
172
- progress_percentage = index / (total_count - 1)
173
- processing_progress_bar.progress(progress_percentage)
174
-
175
- sentence_similarities = []
176
- for sentence_encoding in paragraph_sentence_encoding[1]:
177
- if sentence_encoding:
178
- similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
179
- combined_score, similarity_score, commonality_score = combined_similarity(similarity,
180
- sentence_encoding[0],
181
- query)
182
- sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
183
- sentence_scores.append((combined_score, sentence_encoding[0]))
184
-
185
- sentence_similarities.sort(reverse=True, key=lambda x: x[0])
186
-
187
- if len(sentence_similarities) >= 3:
188
- top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities[:3]])
189
- top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities[:3]])
190
- top_three_sentences = sentence_similarities[:3]
191
- elif sentence_similarities:
192
- top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities])
193
- top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities])
194
- top_three_sentences = sentence_similarities
195
- else:
196
- top_three_avg_similarity = 0
197
- top_three_avg_commonality = 0
198
- top_three_sentences = []
199
-
200
- top_three_texts = [s[1] for s in top_three_sentences]
201
- remaining_texts = [s[0] for s in paragraph_sentence_encoding[1] if s and s[0] not in top_three_texts]
202
- reordered_paragraph = top_three_texts + remaining_texts
203
-
204
- original_paragraph = ' '.join([s[0] for s in paragraph_sentence_encoding[1] if s])
205
- modified_paragraph = ' '.join(reordered_paragraph)
206
-
207
-
208
-
209
-
210
- paragraph_scores.append(
211
- (top_three_avg_similarity, top_three_avg_commonality,
212
- {'modified_text': modified_paragraph, 'original_text': paragraph_sentence_encoding[0]})
213
- )
214
-
215
- sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
216
- paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
217
-
218
- st.write("Top scored paragraphs and their scores:")
219
- for similarity_score, commonality_score, paragraph in paragraph_scores[:5]:
220
- st.write(f"Similarity Score: {similarity_score}, Commonality Score: {commonality_score}")
221
-
222
- output_1 = paraphrase(paragraph['modified_text'])
223
- print(output_1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
- output_2 = paraphrase(output_1)
226
- print(output_2)
227
- st.write("Paraphrased Paragraph: ", output_2)
228
- st.write("Modified Paragraph: ", paragraph['modified_text'])
229
- st.write("Original Paragraph: ", paragraph['original_text'])
 
157
  query = st.text_input("Enter your query")
158
 
159
  if query:
160
+ if 'paragraph_scores' not in st.session_state:
161
+ query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
162
+ 'cuda')
163
+ with torch.no_grad(): # Disable gradient calculation for inference
164
+ query_encoding = st.session_state.bert_model(**query_tokens).last_hidden_state[:, 0,
165
+ :].cpu().numpy() # Move the result to CPU and convert to NumPy
166
+
167
+ paragraph_scores = []
168
+ sentence_scores = []
169
+ total_count = len(st.session_state.paragraph_sentence_encodings)
170
+ processing_progress_bar = st.progress(0)
171
+
172
+ for index, paragraph_sentence_encoding in enumerate(st.session_state.paragraph_sentence_encodings):
173
+ progress_percentage = index / (total_count - 1)
174
+ processing_progress_bar.progress(progress_percentage)
175
+
176
+ sentence_similarities = []
177
+ for sentence_encoding in paragraph_sentence_encoding[1]:
178
+ if sentence_encoding:
179
+ similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
180
+ combined_score, similarity_score, commonality_score = combined_similarity(similarity,
181
+ sentence_encoding[0],
182
+ query)
183
+ sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
184
+ sentence_scores.append((combined_score, sentence_encoding[0]))
185
+
186
+ sentence_similarities.sort(reverse=True, key=lambda x: x[0])
187
+
188
+ if len(sentence_similarities) >= 3:
189
+ top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities[:3]])
190
+ top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities[:3]])
191
+ top_three_sentences = sentence_similarities[:3]
192
+ elif sentence_similarities:
193
+ top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities])
194
+ top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities])
195
+ top_three_sentences = sentence_similarities
196
+ else:
197
+ top_three_avg_similarity = 0
198
+ top_three_avg_commonality = 0
199
+ top_three_sentences = []
200
+
201
+ top_three_texts = [s[1] for s in top_three_sentences]
202
+ remaining_texts = [s[0] for s in paragraph_sentence_encoding[1] if s and s[0] not in top_three_texts]
203
+ reordered_paragraph = top_three_texts + remaining_texts
204
+
205
+ original_paragraph = ' '.join([s[0] for s in paragraph_sentence_encoding[1] if s])
206
+ modified_paragraph = ' '.join(reordered_paragraph)
207
+
208
+
209
+
210
+
211
+ paragraph_scores.append(
212
+ (top_three_avg_similarity, top_three_avg_commonality,
213
+ {'modified_text': modified_paragraph, 'original_text': paragraph_sentence_encoding[0]})
214
+ )
215
+
216
+ sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
217
+ st.session_state.paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
218
+
219
+ if 'paragraph_scores' in st.session_state:
220
+
221
+ if "paraphrased_paragrpahs" not in st.session_state:
222
+ st.session_state.paraphrased_paragrpahs = []
223
+ for i, (similarity_score, commonality_score, paragraph) in enumerate(st.session_state.paragraph_scores[:5]):
224
+
225
+
226
+ output_1 = paraphrase(paragraph['modified_text'])
227
+ # print(output_1)
228
+ output_2 = paraphrase(output_1)
229
+ # print(output_2)
230
+ st.session_state.paraphrased_paragrpahs.append(output_2)
231
+ st.write("Top scored paragraphs and their scores:")
232
+ for i, (similarity_score, commonality_score, paragraph) in enumerate(
233
+ st.session_state.paragraph_scores[:5]):
234
+ st.write("Paraphrased Paragraph: ", st.session_state.paraphrased_paragrpahs[i])
235
+ if st.button(f"Show Original Paragraph {i + 1}", key=f"button_{i}"):
236
+ st.write(f"Similarity Score: {similarity_score}, Commonality Score: {commonality_score}")
237
+ st.write("Original Paragraph: ", paragraph['original_text'])
238
+ # st.write("Modified Paragraph: ", paragraph['modified_text'])
239