alfraser committed on
Commit
d7a31e7
·
1 Parent(s): cd5bd85

Added the RAG training script to be run offline

Browse files
Files changed (1) hide show
  1. src/training/train_rag.py +138 -0
src/training/train_rag.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Not used as part of the streamlit app, but run offline to train the architecture for the rag architecture.
3
+ """
4
+
5
+ import argparse
6
+ import chromadb
7
+ import os
8
+ import shutil
9
+ import time
10
+
11
+ from chromadb import Settings
12
+ from src.common import data_dir
13
+ from src.datatypes import *
14
+
15
+
16
# Set up and parse expected arguments to the script
parser=argparse.ArgumentParser(prog="train_rag",
                               description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture")
# Optional source store: when given and different from out_vectors it is copied before writing
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
parser.add_argument('-overwrite', action='store_true', help="Overwrite the in_vectors store from blank (defaults to append and skip existing)")
parser.add_argument('-delete', action='store_true', help="Delete product set from vector store")
# Parsed eagerly at import time -- this module is an offline script, not an importable library
args = parser.parse_args()
25
+
26
+
27
def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the source vector store must be duplicated before writing.

    A copy is needed only when a source store was supplied and it is not the
    same store as the output.
    """
    source = args.in_vectors
    return source is not None and source != args.out_vectors
33
+
34
+
35
def copy_vector_store(args) -> None:
    """
    Duplicate the source vector store directory into the output location.
    Both stores live under the shared vector_stores data directory.
    """
    stores_root = os.path.join(data_dir, 'vector_stores')
    shutil.copytree(os.path.join(stores_root, args.in_vectors),
                    os.path.join(stores_root, args.out_vectors))
39
+
40
+
41
def empty_or_create_out_vector_store(args) -> None:
    """
    Ensure the output vector store exists and is empty by resetting it
    through a reset-enabled client.
    """
    reset_capable_client = get_out_vector_store_client(args, True)
    reset_capable_client.reset()
44
+
45
+
46
def get_product_collection_from_client(client:chromadb.Client) -> chromadb.Collection:
    """
    Fetch (creating it if absent) the shared 'products' collection,
    configured for cosine-distance similarity search.
    """
    collection = client.get_or_create_collection(name='products', metadata={'hnsw:space': 'cosine'})
    return collection
48
+
49
+
50
def connect_to_product_db(args) -> None:
    """
    Connect to the requested product DB which will load the products.
    On failure this will raise an exception which will propagate to the command
    line, which is fine as this is a script, not part of the running app
    """
    if DataLoader.active_db != args.products_db:
        # Switching to a different DB; presumably set_db_name also triggers a
        # data load -- TODO confirm against DataLoader's implementation
        DataLoader.set_db_name(args.products_db)
    else:
        # Already pointing at the requested DB, so force the load explicitly
        DataLoader.load_data()
60
+
61
+
62
def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Build a persistent Chroma client rooted at the output store's directory.

    Reset support is disabled by default and only enabled for the destructive
    set-up path that wipes the store.
    """
    settings = Settings()
    settings.allow_reset = allow_reset
    store_path = os.path.join(data_dir, 'vector_stores', args.out_vectors)
    return chromadb.PersistentClient(path=store_path, settings=settings)
67
+
68
+
69
def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Carry out all set-up required before vectorising and return a client
    for the output vector store.

    Non-destructive work (connecting to the product DB) is done first so that
    a failure there leaves any existing vector store untouched.
    """
    connect_to_product_db(args)  # Do this first as non-destructive

    # Now do possibly destructive setup
    if args.overwrite:
        # BUG FIX: the original called this with no argument, which raised
        # TypeError whenever -overwrite was used
        empty_or_create_out_vector_store(args)
    elif need_to_copy_vector_store(args):
        copy_vector_store(args)

    return get_out_vector_store_client(args)
79
+
80
+
81
def document_for_product(product: Product) -> str:
    """
    Builds a string document for vectorisation from a product
    """
    sentences = [
        f"The {product.name} is a {product.category.singular_name}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {product.name} features {join_items_comma_and(product.features)}.",
        product.description,
    ]
    return " ".join(sentences)
90
+
91
+
92
def vectorise(vector_client: chromadb.Client) -> None:
    """
    Add documents representing the products from the products database into the vector store
    Document is a built string from the features of the product
    IDs are loaded as "prod_{id from db}"
    Metadata is loaded with the category
    """
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    # Build the parallel id/document/metadata lists in one pass
    ids, documents, metadata = [], [], []
    for product in products:
        ids.append(f"prod_{product.id}")
        documents.append(document_for_product(product))
        metadata.append({'category': product.category.singular_name})
    print(f"Vectorising {len(products)} products")
    collection.upsert(ids=ids, documents=documents, metadatas=metadata)
106
+
107
+
108
def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Carry out all set-up required before deleting product vectors and return
    a client for the output vector store.
    """
    # Non-destructive step first: attach to the product database
    connect_to_product_db(args)
    # Duplicate the source store when deletion targets a different output store
    if need_to_copy_vector_store(args):
        copy_vector_store(args)
    return get_out_vector_store_client(args)
116
+
117
+
118
def delete_vectors(vector_client: chromadb.Client) -> None:
    """
    Remove every product's document from the vector store, matching on the
    "prod_{id}" identifiers used when the vectors were loaded.
    """
    doomed_ids = [f"prod_{product.id}" for product in Product.all_as_list()]
    get_product_collection_from_client(vector_client).delete(ids=doomed_ids)
123
+
124
+
125
def train(args):
    """
    Entry point: either delete the product vectors or (re)vectorise the
    products, depending on the -delete flag.
    """
    if args.delete:
        prepare, act = prepare_to_delete_vectors, delete_vectors
    else:
        prepare, act = prepare_to_vectorise, vectorise
    act(prepare(args))
132
+
133
+
134
if __name__ == "__main__":
    # Time the full run end-to-end; `args` was parsed at module level above
    start = time.time()
    train(args)
    end = time.time()
    print(f"Training took {end-start:.2f} seconds")