File size: 5,048 Bytes
d7a31e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc94398
 
d7a31e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2d8e03
d7a31e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Not used as part of the streamlit app, but run offline to train the architecture for the rag architecture.
"""

import argparse
import chromadb
import os
import shutil
import time

from chromadb import Settings
from src.common import data_dir
from src.datatypes import *


# Set up and parse expected arguments to the script.
# NOTE: parse_args() runs at import time, so importing this module outside the
# command line will consume sys.argv — acceptable here because this file is an
# offline script, not part of the streamlit app.
parser=argparse.ArgumentParser(prog="train_rag",
                               description="Train a Chroma DB document store from the product dataset provided for use in a RAG LLM architecture")
# Optional: when given (and different from out_vectors) the store is copied first.
parser.add_argument('-in_vectors', help="Optional existing Chroma DB to load and extend")
parser.add_argument('-out_vectors', required=True, help="The name for the output Chroma DB")
parser.add_argument('-products_db', required=True, help="The products sqlite to vectorise")
# -overwrite and -delete are mutually relevant flags: -delete wins in train();
# otherwise -overwrite resets the output store before vectorising.
parser.add_argument('-overwrite', action='store_true', help="Overwrite the in_vectors store from blank (defaults to append and skip existing)")
parser.add_argument('-delete', action='store_true', help="Delete product set from vector store")
args = parser.parse_args()


def need_to_copy_vector_store(args) -> bool:
    """
    Decide whether the input vector store must be copied to the output
    location before work begins.

    A copy is only required when an input store was actually named and it
    differs from the output store; copying a store onto itself (or copying
    from nothing) is pointless.
    """
    return args.in_vectors is not None and args.in_vectors != args.out_vectors


def copy_vector_store(args) -> None:
    """
    Duplicate the input Chroma store directory into the output location,
    leaving the original store untouched.
    """
    store_root = os.path.join(data_dir, 'vector_stores')
    source_path = os.path.join(store_root, f"{args.in_vectors}_chroma")
    target_path = os.path.join(store_root, f"{args.out_vectors}_chroma")
    shutil.copytree(source_path, target_path)


def empty_or_create_out_vector_store(args) -> None:
    """
    Guarantee the output vector store exists and is empty: obtain a client
    with reset permission enabled, then wipe its contents.
    """
    get_out_vector_store_client(args, True).reset()


def get_product_collection_from_client(client: chromadb.Client) -> chromadb.Collection:
    """
    Fetch the 'products' collection from the given client, creating it
    (configured for cosine similarity) if it does not yet exist.
    """
    collection_config = {'hnsw:space': 'cosine'}
    return client.get_or_create_collection(name='products', metadata=collection_config)


def connect_to_product_db(args) -> None:
    """
    Connect to the requested product DB which will load the products.
    On failure this will raise an exception which will propagate to the command
    line, which is fine as this is a script, not part of the running app
    """
    # Switch the loader to the requested database only when it is not already
    # the active one; set_db_name presumably (re)loads the data as a side
    # effect — TODO confirm against DataLoader.
    if DataLoader.active_db != args.products_db:
        DataLoader.set_db_name(args.products_db)
    else:
        # Already pointing at the right DB, so just (re)load its data.
        # NOTE(review): if set_db_name does NOT itself load, the branch above
        # would leave products unloaded — verify DataLoader semantics.
        DataLoader.load_data()


def get_out_vector_store_client(args, allow_reset: bool = False) -> chromadb.Client:
    """
    Build a persistent Chroma client rooted at the output store's directory.

    :param allow_reset: when True the returned client permits reset(),
        which the overwrite path needs to blank the store.
    """
    settings = Settings()
    settings.allow_reset = allow_reset
    store_path = os.path.join(data_dir, 'vector_stores', f"{args.out_vectors}_chroma")
    return chromadb.PersistentClient(path=store_path, settings=settings)


def prepare_to_vectorise(args) -> chromadb.Client:
    """
    Perform all setup ahead of vectorising: connect to the product DB, then
    prepare the output vector store (reset it for -overwrite, or copy the
    input store across if one was named).

    :return: a client connected to the output vector store
    """
    connect_to_product_db(args)  # Do this first as non-destructive

    # Now do possibly destructive setup
    if args.overwrite:
        # BUG FIX: empty_or_create_out_vector_store takes args; the original
        # zero-argument call raised TypeError whenever -overwrite was passed.
        empty_or_create_out_vector_store(args)
    elif need_to_copy_vector_store(args):
        copy_vector_store(args)

    return get_out_vector_store_client(args)


def document_for_product(product: Product) -> str:
    """
    Builds a string document for vectorisation from a product: a category
    sentence, a price/rating sentence, a features sentence, and finally the
    product's own description, joined by single spaces.
    """
    kind = product.category.singular_name
    sentences = [
        f"The {product.name} is a {kind}.",
        f"It costs ${product.price} and is rated {product.average_rating} stars.",
        f"The {product.name} features {join_items_comma_and(product.features)}.",
        product.description,
    ]
    return " ".join(sentences)


def vectorise(vector_client: chromadb.Client) -> None:
    """
    Add documents representing the products from the products database into the vector store.
    Each document is a built string from the features of the product;
    IDs are loaded as "prod_{id from db}" and metadata carries the category.
    Existing IDs are updated rather than duplicated (upsert).
    """
    collection = get_product_collection_from_client(vector_client)
    all_products = Product.all_as_list()
    doc_ids, documents, metadatas = [], [], []
    for prod in all_products:
        doc_ids.append(f"prod_{prod.id}")
        documents.append(document_for_product(prod))
        metadatas.append({'category': prod.category.singular_name})
    print(f"Vectorising {len(all_products)} products")
    collection.upsert(ids=doc_ids, documents=documents, metadatas=metadatas)


def prepare_to_delete_vectors(args) -> chromadb.Client:
    """
    Setup ahead of deleting vectors: connect to the product DB first (this
    step is harmless), then optionally copy the input store into place
    before returning a client on the output store.
    """
    # Non-destructive step first.
    connect_to_product_db(args)

    # Possibly destructive: materialising the output store from the input.
    if need_to_copy_vector_store(args):
        copy_vector_store(args)

    return get_out_vector_store_client(args)


def delete_vectors(vector_client: chromadb.Client) -> None:
    """
    Remove every product in the products database from the vector store,
    matching on the same "prod_{id}" IDs written when vectorising.
    """
    ids_to_remove = [f"prod_{p.id}" for p in Product.all_as_list()]
    get_product_collection_from_client(vector_client).delete(ids=ids_to_remove)


def train(args):
    """
    Script entry point: with -delete, remove the product vectors from the
    store; otherwise (the default) vectorise the products into it.
    """
    if args.delete:
        delete_vectors(prepare_to_delete_vectors(args))
    else:
        vectorise(prepare_to_vectorise(args))


if __name__ == "__main__":
    start = time.time()
    train(args)
    end = time.time()
    print(f"Training took {end-start:.2f} seconds")