Sirinoot commited on
Commit
31bc28f
1 Parent(s): 1f0e176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -103,6 +103,11 @@ def faiss_search(index, question_vector, k=1):
103
  distances, indices = index.search(question_vector, k)
104
  return distances,indices
105
 
 
 
 
 
 
106
 
107
  def predict_faiss(model, tokenizer, embedding_model, df, question, index):
108
  t = time.time()
@@ -156,12 +161,18 @@ def predict_test(model, tokenizer, embedding_model, df, question, index): # sen
156
  mostSimContext = re.sub(r'\s+', ' ', mostSimContext)
157
 
158
  segments = sent_tokenize(mostSimContext, engine="crfcut")
 
 
 
 
159
 
160
- segments_index = set_index(get_embeddings(embedding_model,segments))
161
  _distances,_indices = faiss_search(segments_index, question_vector)
162
  mostSimSegment = segments[_indices[0][0]]
163
 
164
  Answer = model_pipeline(model, tokenizer,question,mostSimSegment)
 
 
 
165
 
166
  # Find the start and end indices of mostSimSegment within mostSimContext
167
  start_index = mostSimContext.find(Answer)
 
103
  distances, indices = index.search(question_vector, k)
104
  return distances,indices
105
 
106
+ def create_segment_index(vector):
107
+ segment_index = faiss.IndexFlatL2(vector.shape[1])
108
+ segment_index.add(vector)
109
+ return segment_index
110
+
111
 
112
  def predict_faiss(model, tokenizer, embedding_model, df, question, index):
113
  t = time.time()
 
161
  mostSimContext = re.sub(r'\s+', ' ', mostSimContext)
162
 
163
  segments = sent_tokenize(mostSimContext, engine="crfcut")
164
+
165
+ segment_embeddings = get_embeddings(segments)
166
+ segment_embeddings = prepare_sentences_vector(segment_embeddings)
167
+ segment_index = create_segment_index(segment_embeddings)
168
 
 
169
  _distances,_indices = faiss_search(segments_index, question_vector)
170
  mostSimSegment = segments[_indices[0][0]]
171
 
172
  Answer = model_pipeline(model, tokenizer,question,mostSimSegment)
173
+
174
+ if len(answer) <= 2:
175
+ answer = mostSimSegment
176
 
177
  # Find the start and end indices of mostSimSegment within mostSimContext
178
  start_index = mostSimContext.find(Answer)