"""
Not used as part of the streamlit app, but run offline to train the architecture for the rag architecture.
"""
import argparse
import chromadb
import os
import shutil
import time
from chromadb import Settings
from src.common import data_dir
from src.datatypes import *
# Command-line interface: declare the arguments this script accepts, then
# parse them once at module level so every helper can receive the result.
parser = argparse.ArgumentParser(
    prog="train_rag",
    description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture",
)
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
parser.add_argument('-overwrite', action='store_true', help="Overwrite the in_vectors store from blank (defaults to append and skip existing)")
parser.add_argument('-delete', action='store_true', help="Delete product set from vector store")
args = parser.parse_args()
def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the input store must be duplicated to the output location.

    A copy is only needed when an input store was named AND it differs from
    the output store name; otherwise we work on the output store directly.
    """
    have_source = args.in_vectors is not None
    return have_source and args.in_vectors != args.out_vectors
def copy_vector_store(args) -> None:
    """Copy the existing in_vectors Chroma store to the out_vectors location."""
    stores_dir = os.path.join(data_dir, 'vector_stores')
    shutil.copytree(
        os.path.join(stores_dir, f"{args.in_vectors}_chroma"),
        os.path.join(stores_dir, f"{args.out_vectors}_chroma"),
    )
def empty_or_create_out_vector_store(args) -> None:
    """Create the output vector store if absent and wipe any existing content."""
    # Request a reset-capable client, since reset() is refused otherwise.
    resettable_client = get_out_vector_store_client(args, True)
    resettable_client.reset()
def get_product_collection_from_client(client: chromadb.Client) -> chromadb.Collection:
    """Fetch (creating if absent) the 'products' collection, configured for cosine distance."""
    return client.get_or_create_collection(
        name='products',
        metadata={'hnsw:space': 'cosine'},
    )
def connect_to_product_db(args) -> None:
    """
    Point the DataLoader at the requested product DB so products are loaded.

    Any failure raises and propagates straight to the command line — this is
    an offline script, not part of the running app, so that is acceptable.
    """
    if DataLoader.active_db == args.products_db:
        # Already on the requested DB — just (re)load its data.
        DataLoader.load_data()
    else:
        DataLoader.set_db_name(args.products_db)
def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Open a persistent Chroma client on the output store directory.

    allow_reset controls whether the returned client permits reset(); it is
    only enabled by callers that intend to wipe the store.
    """
    settings = Settings()
    settings.allow_reset = allow_reset
    store_path = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma")
    return chromadb.PersistentClient(path=store_path, settings=settings)
def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Prepare both the product DB and the output vector store for vectorising.

    Connects to the product DB first as that step is non-destructive, then
    either resets the output store (-overwrite) or copies an existing input
    store to the output location, and finally returns a client on the
    output store.
    """
    connect_to_product_db(args)  # Do this first as non-destructive
    # Now do possibly destructive setup
    if args.overwrite:
        # BUG FIX: was called with no argument, raising TypeError whenever
        # -overwrite was supplied; the function requires `args`.
        empty_or_create_out_vector_store(args)
    elif need_to_copy_vector_store(args):
        copy_vector_store(args)
    return get_out_vector_store_client(args)
def document_for_product(product: Product) -> str:
    """
    Build the text document that will be vectorised for a single product.

    Combines the category, price/rating, feature list and free-text
    description into one space-separated string.
    """
    name = product.name
    sentences = [
        f"The {name} is a {product.category.singular_name}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {name} features {join_items_comma_and(product.features)}.",
        product.description,
    ]
    return " ".join(sentences)
def vectorise(vector_client: chromadb.Client) -> None:
    """
    Upsert one document per product from the products database into the
    vector store.

    Documents are strings built from each product's attributes, IDs take the
    form "prod_{id from db}", and the metadata records the category name.
    """
    collection = get_product_collection_from_client(vector_client)
    products = Product.all_as_list()
    print(f"Vectorising {len(products)} products")
    collection.upsert(
        ids=[f"prod_{p.id}" for p in products],
        documents=[document_for_product(p) for p in products],
        metadatas=[{'category': p.category.singular_name} for p in products],
    )
def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Prepare for deleting product vectors: connect to the product DB, copy an
    input store to the output location when one was named, and return a
    client on the output store.
    """
    connect_to_product_db(args)  # Non-destructive, so safe to do first
    # Possibly destructive: duplicate the source store before touching it
    if need_to_copy_vector_store(args):
        copy_vector_store(args)
    return get_out_vector_store_client(args)
def delete_vectors(vector_client: chromadb.Client) -> None:
    """Remove every product in the product DB from the store's 'products' collection."""
    collection = get_product_collection_from_client(vector_client)
    doomed_ids = [f"prod_{product.id}" for product in Product.all_as_list()]
    collection.delete(ids=doomed_ids)
def train(args):
    """Entry point: either delete product vectors or (re)vectorise the products."""
    if args.delete:
        delete_vectors(prepare_to_delete_vectors(args))
    else:
        vectorise(prepare_to_vectorise(args))
if __name__ == "__main__":
start = time.time()
train(args)
end = time.time()
print(f"Training took {end-start:.2f} seconds")