Spaces:
Starting
Starting
test commit
Browse files- rag/rag_pipeline.py +12 -4
rag/rag_pipeline.py
CHANGED
@@ -12,8 +12,14 @@ import chromadb
|
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
|
|
15 |
class RAGPipeline:
|
16 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
17 |
self.study_json = study_json
|
18 |
self.collection_name = collection_name
|
19 |
self.use_semantic_splitter = use_semantic_splitter
|
@@ -65,12 +71,14 @@ class RAGPipeline:
|
|
65 |
|
66 |
# Parse documents into nodes for embedding
|
67 |
nodes = node_parser.get_nodes_from_documents(self.documents)
|
68 |
-
|
69 |
# Initialize ChromaVectorStore with the existing collection
|
70 |
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
71 |
|
72 |
# Create the VectorStoreIndex using the ChromaVectorStore
|
73 |
-
self.index = VectorStoreIndex(
|
|
|
|
|
74 |
|
75 |
def query(
|
76 |
self, context: str, prompt_template: PromptTemplate = None
|
@@ -88,7 +96,7 @@ class RAGPipeline:
|
|
88 |
"If you're unsure about a source, use [?]. "
|
89 |
"Ensure that EVERY statement from the context is properly cited."
|
90 |
)
|
91 |
-
|
92 |
# This is a hack to index all the documents in the store :)
|
93 |
n_documents = len(self.index.docstore.docs)
|
94 |
print(f"n_documents: {n_documents}")
|
|
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
15 |
+
|
16 |
class RAGPipeline:
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
study_json,
|
20 |
+
collection_name="study_files_rag_collection",
|
21 |
+
use_semantic_splitter=False,
|
22 |
+
):
|
23 |
self.study_json = study_json
|
24 |
self.collection_name = collection_name
|
25 |
self.use_semantic_splitter = use_semantic_splitter
|
|
|
71 |
|
72 |
# Parse documents into nodes for embedding
|
73 |
nodes = node_parser.get_nodes_from_documents(self.documents)
|
74 |
+
|
75 |
# Initialize ChromaVectorStore with the existing collection
|
76 |
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
77 |
|
78 |
# Create the VectorStoreIndex using the ChromaVectorStore
|
79 |
+
self.index = VectorStoreIndex(
|
80 |
+
nodes, vector_store=vector_store, embed_model=self.embedding_model
|
81 |
+
)
|
82 |
|
83 |
def query(
|
84 |
self, context: str, prompt_template: PromptTemplate = None
|
|
|
96 |
"If you're unsure about a source, use [?]. "
|
97 |
"Ensure that EVERY statement from the context is properly cited."
|
98 |
)
|
99 |
+
|
100 |
# This is a hack to index all the documents in the store :)
|
101 |
n_documents = len(self.index.docstore.docs)
|
102 |
print(f"n_documents: {n_documents}")
|