XThomasBU commited on
Commit
ea7b686
1 Parent(s): a2ac5f7

condnl checks replaved with dict lookups

Browse files
code/modules/retriever/retriever.py CHANGED
@@ -6,19 +6,19 @@ from modules.retriever.colbert_retriever import ColbertRetriever
6
  class Retriever:
7
  def __init__(self, config):
8
  self.config = config
 
 
 
 
 
9
  self._create_retriever()
10
 
11
  def _create_retriever(self):
12
- if self.config["vectorstore"]["db_option"] == "FAISS":
13
- self.retriever = FaissRetriever()
14
- elif self.config["vectorstore"]["db_option"] == "Chroma":
15
- self.retriever = ChromaRetriever()
16
- elif self.config["vectorstore"]["db_option"] == "RAGatouille":
17
- self.retriever = ColbertRetriever()
18
- else:
19
- raise ValueError(
20
- "Invalid db_option: {}".format(self.config["vectorstore"]["db_option"])
21
- )
22
 
23
  def _return_retriever(self, db):
24
  return self.retriever.return_retriever(db, self.config)
 
6
  class Retriever:
7
  def __init__(self, config):
8
  self.config = config
9
+ self.retriever_classes = {
10
+ "FAISS": FaissRetriever,
11
+ "Chroma": ChromaRetriever,
12
+ "RAGatouille": ColbertRetriever,
13
+ }
14
  self._create_retriever()
15
 
16
  def _create_retriever(self):
17
+ db_option = self.config["vectorstore"]["db_option"]
18
+ retriever_class = self.retriever_classes.get(db_option)
19
+ if not retriever_class:
20
+ raise ValueError(f"Invalid db_option: {db_option}")
21
+ self.retriever = retriever_class()
 
 
 
 
 
22
 
23
  def _return_retriever(self, db):
24
  return self.retriever.return_retriever(db, self.config)
code/modules/vectorstore/store_manager.py CHANGED
@@ -86,6 +86,8 @@ class VectorStoreManager:
86
  ):
87
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
88
  self.embedding_model = self.create_embedding_model()
 
 
89
 
90
  self.logger.info("Initializing vector_db")
91
  self.logger.info(
@@ -132,6 +134,8 @@ class VectorStoreManager:
132
  start_time = time.time() # Start time for loading database
133
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
134
  self.embedding_model = self.create_embedding_model()
 
 
135
  self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
136
  end_time = time.time() # End time for loading database
137
  self.logger.info(
 
86
  ):
87
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
88
  self.embedding_model = self.create_embedding_model()
89
+ else:
90
+ self.embedding_model = None
91
 
92
  self.logger.info("Initializing vector_db")
93
  self.logger.info(
 
134
  start_time = time.time() # Start time for loading database
135
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma"]:
136
  self.embedding_model = self.create_embedding_model()
137
+ else:
138
+ self.embedding_model = None
139
  self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
140
  end_time = time.time() # End time for loading database
141
  self.logger.info(
code/modules/vectorstore/vectorstore.py CHANGED
@@ -7,6 +7,11 @@ class VectorStore:
7
  def __init__(self, config):
8
  self.config = config
9
  self.vectorstore = None
 
 
 
 
 
10
 
11
  def _create_database(
12
  self,
@@ -16,32 +21,32 @@ class VectorStore:
16
  document_metadata,
17
  embedding_model,
18
  ):
19
- if self.config["vectorstore"]["db_option"] == "FAISS":
20
- self.vectorstore = FaissVectorStore(self.config)
21
- self.vectorstore.create_database(document_chunks, embedding_model)
22
- elif self.config["vectorstore"]["db_option"] == "Chroma":
23
- self.vectorstore = ChromaVectorStore(self.config)
24
- self.vectorstore.create_database(document_chunks, embedding_model)
25
- elif self.config["vectorstore"]["db_option"] == "RAGatouille":
26
- self.vectorstore = ColbertVectorStore(self.config)
27
  self.vectorstore.create_database(
28
  documents, document_names, document_metadata
29
  )
30
  else:
31
- raise ValueError(
32
- "Invalid db_option: {}".format(self.config["vectorstore"]["db_option"])
33
- )
34
 
35
  def _load_database(self, embedding_model):
36
- if self.config["vectorstore"]["db_option"] == "FAISS":
37
- self.vectorstore = FaissVectorStore(self.config)
38
- return self.vectorstore.load_database(embedding_model)
39
- elif self.config["vectorstore"]["db_option"] == "Chroma":
40
- self.vectorstore = ChromaVectorStore(self.config)
41
- return self.vectorstore.load_database(embedding_model)
42
- elif self.config["vectorstore"]["db_option"] == "RAGatouille":
43
- self.vectorstore = ColbertVectorStore(self.config)
44
  return self.vectorstore.load_database()
 
 
45
 
46
  def _as_retriever(self):
47
  return self.vectorstore.as_retriever()
 
7
  def __init__(self, config):
8
  self.config = config
9
  self.vectorstore = None
10
+ self.vectorstore_classes = {
11
+ "FAISS": FaissVectorStore,
12
+ "Chroma": ChromaVectorStore,
13
+ "RAGatouille": ColbertVectorStore,
14
+ }
15
 
16
  def _create_database(
17
  self,
 
21
  document_metadata,
22
  embedding_model,
23
  ):
24
+ db_option = self.config["vectorstore"]["db_option"]
25
+ vectorstore_class = self.vectorstore_classes.get(db_option)
26
+ if not vectorstore_class:
27
+ raise ValueError(f"Invalid db_option: {db_option}")
28
+
29
+ self.vectorstore = vectorstore_class(self.config)
30
+
31
+ if db_option == "RAGatouille":
32
  self.vectorstore.create_database(
33
  documents, document_names, document_metadata
34
  )
35
  else:
36
+ self.vectorstore.create_database(document_chunks, embedding_model)
 
 
37
 
38
  def _load_database(self, embedding_model):
39
+ db_option = self.config["vectorstore"]["db_option"]
40
+ vectorstore_class = self.vectorstore_classes.get(db_option)
41
+ if not vectorstore_class:
42
+ raise ValueError(f"Invalid db_option: {db_option}")
43
+
44
+ self.vectorstore = vectorstore_class(self.config)
45
+
46
+ if db_option == "RAGatouille":
47
  return self.vectorstore.load_database()
48
+ else:
49
+ return self.vectorstore.load_database(embedding_model)
50
 
51
  def _as_retriever(self):
52
  return self.vectorstore.as_retriever()