Thomas (Tom) Gardos commited on
Commit
5520fda
·
2 Parent(s): 6e49b76 33e5fa6

Merge pull request #35 from DL4DS/hf_index_load

Browse files
code/modules/config/config.yml CHANGED
@@ -3,11 +3,13 @@ log_chunk_dir: '../storage/logs/chunks' # str
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
 
 
6
  embedd_files: False # bool
7
  data_path: '../storage/data' # str
8
  url_file_path: '../storage/data/urls.txt' # str
9
  expand_urls: True # bool
10
- db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
11
  db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
 
3
  device: 'cpu' # str [cuda, cpu]
4
 
5
  vectorstore:
6
+ load_from_HF: True # bool
7
+ HF_path: "XThomasBU/Colbert_Index" # str
8
  embedd_files: False # bool
9
  data_path: '../storage/data' # str
10
  url_file_path: '../storage/data/urls.txt' # str
11
  expand_urls: True # bool
12
+ db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille, RAPTOR]
13
  db_path : '../vectorstores' # str
14
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
15
  search_top_k : 3 # int
code/modules/vectorstore/store_manager.py CHANGED
@@ -143,6 +143,14 @@ class VectorStoreManager:
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
145
 
 
 
 
 
 
 
 
 
146
 
147
  if __name__ == "__main__":
148
  import yaml
@@ -152,7 +160,10 @@ if __name__ == "__main__":
152
  print(config)
153
  print(f"Trying to create database with config: {config}")
154
  vector_db = VectorStoreManager(config)
155
- vector_db.create_database()
 
 
 
156
  print("Created database")
157
 
158
  print(f"Trying to load the database")
 
143
  self.logger.info("Loaded database")
144
  return self.loaded_vector_db
145
 
146
+ def load_from_HF(self):
147
+ start_time = time.time() # Start time for loading database
148
+ self.vector_db._load_from_HF()
149
+ end_time = time.time()
150
+ self.logger.info(
151
+ f"Time taken to load database from Hugging Face: {end_time - start_time} seconds"
152
+ )
153
+
154
 
155
  if __name__ == "__main__":
156
  import yaml
 
160
  print(config)
161
  print(f"Trying to create database with config: {config}")
162
  vector_db = VectorStoreManager(config)
163
+ if config["vectorstore"]["load_from_HF"] and "HF_path" in config["vectorstore"]:
164
+ vector_db.load_from_HF()
165
+ else:
166
+ vector_db.create_database()
167
  print("Created database")
168
 
169
  print(f"Trying to load the database")
code/modules/vectorstore/vectorstore.py CHANGED
@@ -2,6 +2,9 @@ from modules.vectorstore.faiss import FaissVectorStore
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
 
 
 
5
 
6
 
7
  class VectorStore:
@@ -50,6 +53,34 @@ class VectorStore:
50
  else:
51
  return self.vectorstore.load_database(embedding_model)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def _as_retriever(self):
54
  return self.vectorstore.as_retriever()
55
 
 
2
  from modules.vectorstore.chroma import ChromaVectorStore
3
  from modules.vectorstore.colbert import ColbertVectorStore
4
  from modules.vectorstore.raptor import RAPTORVectoreStore
5
+ from huggingface_hub import snapshot_download
6
+ import os
7
+ import shutil
8
 
9
 
10
  class VectorStore:
 
53
  else:
54
  return self.vectorstore.load_database(embedding_model)
55
 
56
+ def _load_from_HF(self):
57
+ # Download the snapshot from Hugging Face Hub
58
+ # Note: Download goes to the cache directory
59
+ snapshot_path = snapshot_download(
60
+ repo_id=self.config["vectorstore"]["HF_path"],
61
+ repo_type="dataset",
62
+ force_download=True,
63
+ )
64
+
65
+ # Move the downloaded files to the desired directory
66
+ target_path = os.path.join(
67
+ self.config["vectorstore"]["db_path"],
68
+ "db_" + self.config["vectorstore"]["db_option"],
69
+ )
70
+
71
+ # Create target path if it doesn't exist
72
+ os.makedirs(target_path, exist_ok=True)
73
+
74
+ # move all files and directories from snapshot_path to target_path
75
+ # target path is used while loading the database
76
+ for item in os.listdir(snapshot_path):
77
+ s = os.path.join(snapshot_path, item)
78
+ d = os.path.join(target_path, item)
79
+ if os.path.isdir(s):
80
+ shutil.copytree(s, d, dirs_exist_ok=True)
81
+ else:
82
+ shutil.copy2(s, d)
83
+
84
  def _as_retriever(self):
85
  return self.vectorstore.as_retriever()
86