prashant committed on
Commit
b941115
·
1 Parent(s): efb11f2

fix for update embedding cache

Browse files
Files changed (1) hide show
  1. utils/semantic_search.py +40 -14
utils/semantic_search.py CHANGED
@@ -100,24 +100,50 @@ def semanticSearchPipeline(documents:List[Document]):
100
  list of document returned by preprocessing pipeline.
101
 
102
  """
103
-
104
- document_store = InMemoryDocumentStore()
105
- document_store.write_documents(documents)
106
-
107
- embedding_model = config.get('semantic_search','RETRIEVER')
108
- embedding_model_format = config.get('semantic_search','RETRIEVER_FORMAT')
109
- embedding_layer = int(config.get('semantic_search','RETRIEVER_EMB_LAYER'))
110
- retriever_top_k = int(config.get('semantic_search','RETRIEVER_TOP_K'))
111
-
 
 
 
 
 
 
 
 
 
 
112
 
113
-
114
- querycheck = QueryCheck()
115
- retriever = EmbeddingRetriever(
116
  document_store=document_store,
117
  embedding_model=embedding_model,top_k = retriever_top_k,
118
  emb_extraction_layer=embedding_layer, scale_score =True,
119
  model_format=embedding_model_format, use_gpu = True)
120
- document_store.update_embeddings(retriever)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  reader_model = config.get('semantic_search','READER')
122
  reader_top_k = retriever_top_k
123
  reader = FARMReader(model_name_or_path=reader_model,
@@ -150,7 +176,7 @@ def semanticsearchAnnotator(matches: List[List[int]], document):
150
  end_idx = match[1]
151
  annotated_text = (annotated_text + document[start:start_idx]
152
  + str(annotation(body=document[start_idx:end_idx],
153
- label="ANSWER", background="#964448", color='#ffffff')))
154
  start = end_idx
155
 
156
  annotated_text = annotated_text + document[end_idx:]
 
100
  list of document returned by preprocessing pipeline.
101
 
102
  """
103
+ if 'document_store' in st.session_state:
104
+ document_store = st.session_state['document_store']
105
+ temp = document_store.get_all_documents()
106
+ if st.session_state('filename') != temp[0].meta['name']:
107
+
108
+ document_store = InMemoryDocumentStore()
109
+ document_store.write_documents(documents)
110
+
111
+ embedding_model = config.get('semantic_search','RETRIEVER')
112
+ embedding_model_format = config.get('semantic_search','RETRIEVER_FORMAT')
113
+ embedding_layer = int(config.get('semantic_search','RETRIEVER_EMB_LAYER'))
114
+ retriever_top_k = int(config.get('semantic_search','RETRIEVER_TOP_K'))
115
+ retriever = EmbeddingRetriever(
116
+ document_store=document_store,
117
+ embedding_model=embedding_model,top_k = retriever_top_k,
118
+ emb_extraction_layer=embedding_layer, scale_score =True,
119
+ model_format=embedding_model_format, use_gpu = True)
120
+ document_store.update_embeddings(retriever)
121
+ else:
122
 
123
+ retriever = EmbeddingRetriever(
 
 
124
  document_store=document_store,
125
  embedding_model=embedding_model,top_k = retriever_top_k,
126
  emb_extraction_layer=embedding_layer, scale_score =True,
127
  model_format=embedding_model_format, use_gpu = True)
128
+
129
+ else:
130
+ document_store = InMemoryDocumentStore()
131
+ document_store.write_documents(documents)
132
+
133
+ embedding_model = config.get('semantic_search','RETRIEVER')
134
+ embedding_model_format = config.get('semantic_search','RETRIEVER_FORMAT')
135
+ embedding_layer = int(config.get('semantic_search','RETRIEVER_EMB_LAYER'))
136
+ retriever_top_k = int(config.get('semantic_search','RETRIEVER_TOP_K'))
137
+ retriever = EmbeddingRetriever(
138
+ document_store=document_store,
139
+ embedding_model=embedding_model,top_k = retriever_top_k,
140
+ emb_extraction_layer=embedding_layer, scale_score =True,
141
+ model_format=embedding_model_format, use_gpu = True)
142
+ document_store.update_embeddings(retriever)
143
+ st.session_state['document_store'] = document_store
144
+
145
+ querycheck = QueryCheck()
146
+
147
  reader_model = config.get('semantic_search','READER')
148
  reader_top_k = retriever_top_k
149
  reader = FARMReader(model_name_or_path=reader_model,
 
176
  end_idx = match[1]
177
  annotated_text = (annotated_text + document[start:start_idx]
178
  + str(annotation(body=document[start_idx:end_idx],
179
+ label="CONTEXT", background="#964448", color='#ffffff')))
180
  start = end_idx
181
 
182
  annotated_text = annotated_text + document[end_idx:]