Spaces:
Runtime error
Runtime error
Added the RAG training script to be run offline
Browse files- src/training/train_rag.py +138 -0
src/training/train_rag.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Not used as part of the Streamlit app; run offline to build the vector store for the RAG architecture.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import chromadb
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
import time
|
10 |
+
|
11 |
+
from chromadb import Settings
|
12 |
+
from src.common import data_dir
|
13 |
+
from src.datatypes import *
|
14 |
+
|
15 |
+
|
16 |
+
# Set up and parse the expected command-line arguments to the script.
# NOTE: parsing happens at module level, so `args` is available to the rest of
# the script, and importing this module also parses sys.argv.
parser = argparse.ArgumentParser(
    prog="train_rag",
    description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture",
)
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
# -overwrite acts on the *output* store (it is reset before loading); without
# it the products are upserted into whatever the output store already holds.
parser.add_argument('-overwrite', action='store_true', help="Reset the out_vectors store and rebuild it from blank (defaults to upserting into the existing store)")
parser.add_argument('-delete', action='store_true', help="Delete product set from vector store")
args = parser.parse_args()
|
25 |
+
|
26 |
+
|
27 |
+
def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the input vector store must be copied to the output location.

    A copy is needed only when an input store was named and it is not the same
    name as the output store.

    :param args: parsed arguments providing in_vectors and out_vectors
    :return: True when a copy is required, otherwise False
    """
    has_input_store = args.in_vectors is not None
    return has_input_store and args.in_vectors != args.out_vectors
|
33 |
+
|
34 |
+
|
35 |
+
def copy_vector_store(args) -> None:
    """
    Copy the input vector store directory tree to the output store location.

    Both stores live under the 'vector_stores' folder of data_dir.

    :param args: parsed arguments providing in_vectors and out_vectors
    """
    store_root = os.path.join(data_dir, 'vector_stores')
    shutil.copytree(
        os.path.join(store_root, args.in_vectors),
        os.path.join(store_root, args.out_vectors),
    )
|
39 |
+
|
40 |
+
|
41 |
+
def empty_or_create_out_vector_store(args) -> None:
    """
    Open the output vector store with resets enabled and reset it.

    :param args: parsed arguments identifying the output store
    """
    get_out_vector_store_client(args, allow_reset=True).reset()
|
44 |
+
|
45 |
+
|
46 |
+
def get_product_collection_from_client(client: chromadb.Client) -> chromadb.Collection:
    """
    Fetch the 'products' collection from the client, creating it if it is missing.

    The collection is requested with the cosine HNSW space metadata.

    :param client: Chroma client to fetch the collection from
    :return: the 'products' collection
    """
    collection_metadata = {'hnsw:space': 'cosine'}
    return client.get_or_create_collection(name='products', metadata=collection_metadata)
|
48 |
+
|
49 |
+
|
50 |
+
def connect_to_product_db(args) -> None:
    """
    Point the DataLoader at the requested product DB so its products are available.

    If the requested DB is already the active one the data is (re)loaded
    explicitly; otherwise the DB name is switched (presumably set_db_name also
    loads the data - confirm against DataLoader).

    Any failure is left to propagate to the command line, which is fine as
    this is a script and not part of the running app.

    :param args: parsed arguments providing products_db
    """
    if DataLoader.active_db == args.products_db:
        DataLoader.load_data()
        return

    DataLoader.set_db_name(args.products_db)
|
60 |
+
|
61 |
+
|
62 |
+
def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Build a persistent Chroma client for the output vector store.

    The store lives in the 'vector_stores' folder of data_dir, under the name
    given by out_vectors.

    :param args: parsed arguments providing out_vectors
    :param allow_reset: value for the client's allow_reset setting
    :return: a persistent client bound to the output store path
    """
    store_path = os.path.join(data_dir, 'vector_stores', args.out_vectors)
    return chromadb.PersistentClient(
        path=store_path,
        settings=Settings(allow_reset=allow_reset),
    )
|
67 |
+
|
68 |
+
|
69 |
+
def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Connect to the product DB and get the output vector store ready for loading.

    The product DB connection is made first because it is non-destructive. Only
    afterwards is the output store reset (when overwrite is set) or seeded with
    a copy of the input store (when a different input store was named).

    :param args: parsed command-line arguments
    :return: a client for the output vector store
    """
    connect_to_product_db(args)  # Do this first as non-destructive

    # Now do possibly destructive setup
    if args.overwrite:
        # The helper needs the parsed arguments to locate the output store;
        # calling it with no arguments raised a TypeError.
        empty_or_create_out_vector_store(args)
    elif need_to_copy_vector_store(args):
        copy_vector_store(args)

    return get_out_vector_store_client(args)
|
79 |
+
|
80 |
+
|
81 |
+
def document_for_product(product: Product) -> str:
    """
    Compose the text document that is vectorised for a single product.

    The document is a category sentence, a price-and-rating sentence and a
    features sentence, followed by the product description, all separated by
    single spaces.

    :param product: the product to describe
    :return: the document string
    """
    name = product.name
    kind = product.category.singular_name
    feature_list = join_items_comma_and(product.features)
    sentences = (
        f"The {name} is a {kind}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {name} features {feature_list}.",
    )
    return f"{' '.join(sentences)} {product.description}"
|
90 |
+
|
91 |
+
|
92 |
+
def vectorise(vector_client: chromadb.Client) -> None:
    """
    Upsert one document per product from the products database into the vector store.

    For every product:
      - the ID is "prod_{id from db}"
      - the document is the string built by document_for_product
      - the metadata holds the product's category

    :param vector_client: client for the vector store to load into
    """
    collection = get_product_collection_from_client(vector_client)
    product_list = Product.all_as_list()

    # Build the three parallel lists in a single pass over the products
    ids, documents, metadata = [], [], []
    for item in product_list:
        ids.append(f"prod_{item.id}")
        documents.append(document_for_product(item))
        metadata.append({'category': item.category.singular_name})

    print(f"Vectorising {len(product_list)} products")
    collection.upsert(ids=ids, documents=documents, metadatas=metadata)
|
106 |
+
|
107 |
+
|
108 |
+
def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Connect to the product DB and get the output vector store ready for deletion.

    The product DB connection comes first as it is non-destructive; the input
    store is then copied to the output location when a different input store
    was named.

    :param args: parsed command-line arguments
    :return: a client for the output vector store
    """
    # Non-destructive step first
    connect_to_product_db(args)

    # Possibly destructive setup follows
    copy_required = need_to_copy_vector_store(args)
    if copy_required:
        copy_vector_store(args)

    return get_out_vector_store_client(args)
|
116 |
+
|
117 |
+
|
118 |
+
def delete_vectors(vector_client: chromadb.Client) -> None:
    """
    Delete the entries for every product in the products database from the vector store.

    Entries are addressed by ID, using the "prod_{id from db}" form.

    :param vector_client: client for the vector store to delete from
    """
    collection = get_product_collection_from_client(vector_client)
    doomed_ids = []
    for item in Product.all_as_list():
        doomed_ids.append(f"prod_{item.id}")
    collection.delete(ids=doomed_ids)
|
123 |
+
|
124 |
+
|
125 |
+
def train(args):
    """
    Run the requested operation against the output vector store.

    With the delete flag set, the products are removed from the store;
    otherwise they are vectorised into it.

    :param args: parsed command-line arguments
    """
    if not args.delete:
        vectorise(prepare_to_vectorise(args))
        return

    delete_vectors(prepare_to_delete_vectors(args))
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == "__main__":
    # Time the whole run and report the elapsed wall-clock seconds
    started_at = time.time()
    train(args)
    elapsed = time.time() - started_at
    print(f"Training took {elapsed:.2f} seconds")
|