Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -103,6 +103,11 @@ def faiss_search(index, question_vector, k=1):
|
|
103 |
distances, indices = index.search(question_vector, k)
|
104 |
return distances,indices
|
105 |
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def predict_faiss(model, tokenizer, embedding_model, df, question, index):
|
108 |
t = time.time()
|
@@ -156,12 +161,18 @@ def predict_test(model, tokenizer, embedding_model, df, question, index): # sen
|
|
156 |
mostSimContext = re.sub(r'\s+', ' ', mostSimContext)
|
157 |
|
158 |
segments = sent_tokenize(mostSimContext, engine="crfcut")
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
segments_index = set_index(get_embeddings(embedding_model,segments))
|
161 |
_distances,_indices = faiss_search(segments_index, question_vector)
|
162 |
mostSimSegment = segments[_indices[0][0]]
|
163 |
|
164 |
Answer = model_pipeline(model, tokenizer,question,mostSimSegment)
|
|
|
|
|
|
|
165 |
|
166 |
# Find the start and end indices of mostSimSegment within mostSimContext
|
167 |
start_index = mostSimContext.find(Answer)
|
|
|
103 |
distances, indices = index.search(question_vector, k)
|
104 |
return distances,indices
|
105 |
|
106 |
+
def create_segment_index(vector):
|
107 |
+
segment_index = faiss.IndexFlatL2(vector.shape[1])
|
108 |
+
segment_index.add(vector)
|
109 |
+
return segment_index
|
110 |
+
|
111 |
|
112 |
def predict_faiss(model, tokenizer, embedding_model, df, question, index):
|
113 |
t = time.time()
|
|
|
161 |
mostSimContext = re.sub(r'\s+', ' ', mostSimContext)
|
162 |
|
163 |
segments = sent_tokenize(mostSimContext, engine="crfcut")
|
164 |
+
|
165 |
+
segment_embeddings = get_embeddings(segments)
|
166 |
+
segment_embeddings = prepare_sentences_vector(segment_embeddings)
|
167 |
+
segment_index = create_segment_index(segment_embeddings)
|
168 |
|
|
|
169 |
_distances,_indices = faiss_search(segments_index, question_vector)
|
170 |
mostSimSegment = segments[_indices[0][0]]
|
171 |
|
172 |
Answer = model_pipeline(model, tokenizer,question,mostSimSegment)
|
173 |
+
|
174 |
+
if len(answer) <= 2:
|
175 |
+
answer = mostSimSegment
|
176 |
|
177 |
# Find the start and end indices of mostSimSegment within mostSimContext
|
178 |
start_index = mostSimContext.find(Answer)
|