original paragraph info hidden until button pressed
app.py
CHANGED
@@ -157,73 +157,83 @@ if 'paragraph_sentence_encodings' in st.session_state:
     query = st.text_input("Enter your query")

     if query:
-        [old lines 160-224: removed code was collapsed in this view and is not shown]
-            output_2 = paraphrase(output_1)
-            print(output_2)
-            st.write("Paraphrased Paragraph: ", output_2)
-            st.write("Modified Paragraph: ", paragraph['modified_text'])
-            st.write("Original Paragraph: ", paragraph['original_text'])
+        if 'paragraph_scores' not in st.session_state:
+            query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
+                'cuda')
+            with torch.no_grad():  # Disable gradient calculation for inference
+                query_encoding = st.session_state.bert_model(**query_tokens).last_hidden_state[:, 0,
+                                 :].cpu().numpy()  # Move the result to CPU and convert to NumPy
+
+            paragraph_scores = []
+            sentence_scores = []
+            total_count = len(st.session_state.paragraph_sentence_encodings)
+            processing_progress_bar = st.progress(0)
+
+            for index, paragraph_sentence_encoding in enumerate(st.session_state.paragraph_sentence_encodings):
+                progress_percentage = index / (total_count - 1)
+                processing_progress_bar.progress(progress_percentage)
+
+                sentence_similarities = []
+                for sentence_encoding in paragraph_sentence_encoding[1]:
+                    if sentence_encoding:
+                        similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
+                        combined_score, similarity_score, commonality_score = combined_similarity(similarity,
+                                                                                                  sentence_encoding[0],
+                                                                                                  query)
+                        sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
+                        sentence_scores.append((combined_score, sentence_encoding[0]))
+
+                sentence_similarities.sort(reverse=True, key=lambda x: x[0])
+
+                if len(sentence_similarities) >= 3:
+                    top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities[:3]])
+                    top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities[:3]])
+                    top_three_sentences = sentence_similarities[:3]
+                elif sentence_similarities:
+                    top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities])
+                    top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities])
+                    top_three_sentences = sentence_similarities
+                else:
+                    top_three_avg_similarity = 0
+                    top_three_avg_commonality = 0
+                    top_three_sentences = []
+
+                top_three_texts = [s[1] for s in top_three_sentences]
+                remaining_texts = [s[0] for s in paragraph_sentence_encoding[1] if s and s[0] not in top_three_texts]
+                reordered_paragraph = top_three_texts + remaining_texts
+
+                original_paragraph = ' '.join([s[0] for s in paragraph_sentence_encoding[1] if s])
+                modified_paragraph = ' '.join(reordered_paragraph)
+
+                paragraph_scores.append(
+                    (top_three_avg_similarity, top_three_avg_commonality,
+                     {'modified_text': modified_paragraph, 'original_text': paragraph_sentence_encoding[0]})
+                )
+
+            sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
+            st.session_state.paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
+
+        if 'paragraph_scores' in st.session_state:
+
+            if "paraphrased_paragrpahs" not in st.session_state:
+                st.session_state.paraphrased_paragrpahs = []
+                for i, (similarity_score, commonality_score, paragraph) in enumerate(st.session_state.paragraph_scores[:5]):
+                    output_1 = paraphrase(paragraph['modified_text'])
+                    # print(output_1)
+                    output_2 = paraphrase(output_1)
+                    # print(output_2)
+                    st.session_state.paraphrased_paragrpahs.append(output_2)
+            st.write("Top scored paragraphs and their scores:")
+            for i, (similarity_score, commonality_score, paragraph) in enumerate(
+                    st.session_state.paragraph_scores[:5]):
+                st.write("Paraphrased Paragraph: ", st.session_state.paraphrased_paragrpahs[i])
+                if st.button(f"Show Original Paragraph {i + 1}", key=f"button_{i}"):
+                    st.write(f"Similarity Score: {similarity_score}, Commonality Score: {commonality_score}")
+                    st.write("Original Paragraph: ", paragraph['original_text'])
+                    # st.write("Modified Paragraph: ", paragraph['modified_text'])
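
A few notes on the pieces this hunk leans on. The query encoding takes the [CLS] vector of a BERT model cached in st.session_state. A rough standalone sketch of that step is below; the bert-base-uncased checkpoint and the CPU fallback are illustrative assumptions, not necessarily what app.py loads:

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Illustrative checkpoint; app.py keeps its own tokenizer/model in st.session_state.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased').to(device).eval()


    def encode_text(text):
        tokens = tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
        with torch.no_grad():                 # inference only, no gradients
            hidden = model(**tokens).last_hidden_state
        return hidden[:, 0, :].cpu().numpy()  # [CLS] vector, shape (1, hidden_size)

Checking torch.cuda.is_available() is a deviation from the hunk, which hard-codes 'cuda'; the guard only keeps the sketch runnable on CPU-only machines.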
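
The scoring loop also depends on two things the hunk does not show: the layout of st.session_state.paragraph_sentence_encodings, which appears to be a list of (paragraph_text, [(sentence_text, encoding), ...]) pairs, and the combined_similarity helper defined elsewhere in app.py. The sketch below reproduces the top-three ranking and sentence reordering with a stand-in combined_similarity based on word overlap; that formula is an assumption for illustration, not the app's actual scoring:

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity


    def combined_similarity(similarity, sentence, query):
        # Stand-in for the app's helper: blend embedding similarity with simple
        # word overlap ("commonality") between the sentence and the query.
        query_words = set(query.lower().split())
        sentence_words = set(sentence.lower().split())
        commonality = len(query_words & sentence_words) / max(len(query_words), 1)
        combined = 0.7 * similarity + 0.3 * commonality
        return combined, similarity, commonality


    def rank_paragraphs(query_encoding, paragraph_sentence_encodings, query):
        # query_encoding and each sentence encoding are (1, hidden_size) arrays.
        paragraph_scores = []
        for paragraph_text, sentence_entries in paragraph_sentence_encodings:
            scored = []
            for sentence_text, encoding in sentence_entries:
                sim = cosine_similarity(query_encoding, encoding)[0][0]
                combined, _, commonality = combined_similarity(sim, sentence_text, query)
                scored.append((combined, sentence_text, commonality))
            scored.sort(reverse=True, key=lambda x: x[0])
            top = scored[:3]
            avg_combined = np.mean([s[0] for s in top]) if top else 0.0
            avg_commonality = np.mean([s[2] for s in top]) if top else 0.0
            # Move the best-matching sentences to the front of the paragraph.
            top_texts = [s[1] for s in top]
            rest = [text for text, _ in sentence_entries if text not in top_texts]
            modified = ' '.join(top_texts + rest)
            paragraph_scores.append((avg_combined, avg_commonality,
                                     {'modified_text': modified, 'original_text': paragraph_text}))
        # Highest combined score first, mirroring the sort at the end of the hunk.
        return sorted(paragraph_scores, key=lambda x: x[0], reverse=True)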
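
Finally, the behaviour the commit title describes, compute once and keep the original paragraph hidden until its button is pressed, comes from the interplay of st.session_state caching and st.button. A minimal sketch of that pattern, with placeholder data and a stub paraphrase():

    import streamlit as st


    def paraphrase(text):
        # Stub standing in for the app's paraphrase() helper.
        return text


    # Placeholder data; in app.py this comes from st.session_state.paragraph_scores.
    ranked_paragraphs = [
        {'original_text': 'First original paragraph.', 'modified_text': 'First reordered paragraph.'},
        {'original_text': 'Second original paragraph.', 'modified_text': 'Second reordered paragraph.'},
    ]

    # The expensive work runs only on the first pass; later reruns (triggered by
    # any button click) reuse the cached results instead of recomputing them.
    if 'paraphrased' not in st.session_state:
        st.session_state.paraphrased = [paraphrase(p['modified_text']) for p in ranked_paragraphs]

    for i, paragraph in enumerate(ranked_paragraphs):
        st.write("Paraphrased Paragraph: ", st.session_state.paraphrased[i])
        # st.button returns True only on the rerun caused by its own click, so the
        # original text stays hidden until the user asks for it.
        if st.button(f"Show Original Paragraph {i + 1}", key=f"button_{i}"):
            st.write("Original Paragraph: ", paragraph['original_text'])

Because the encoding, scoring, and paraphrasing results are cached in st.session_state, the rerun caused by pressing a button only redraws the page and reveals the requested original paragraph; none of the heavy work above is repeated.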