query encoding refactored
Browse files
app.py
CHANGED
@@ -206,12 +206,8 @@ if 'paragraph_sentence_encodings' in st.session_state:
|
|
206 |
if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
|
207 |
st.session_state.prev_query = query
|
208 |
st.session_state.premise = query
|
209 |
-
query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(
|
210 |
-
'cuda')
|
211 |
-
with torch.no_grad(): # Disable gradient calculation for inference
|
212 |
-
query_encoding = st.session_state.bert_model(**query_tokens).last_hidden_state[:, 0,
|
213 |
-
:].cpu().numpy() # Move the result to CPU and convert to NumPy
|
214 |
|
|
|
215 |
paragraph_scores = []
|
216 |
sentence_scores = []
|
217 |
total_count = len(st.session_state.paragraph_sentence_encodings)
|
|
|
206 |
if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
|
207 |
st.session_state.prev_query = query
|
208 |
st.session_state.premise = query
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
+
query_encoding = encode_sentence(query)
|
211 |
paragraph_scores = []
|
212 |
sentence_scores = []
|
213 |
total_count = len(st.session_state.paragraph_sentence_encodings)
|