Update chroma_db_utils.py
Browse files- chroma_db_utils.py +0 -167
chroma_db_utils.py
CHANGED
@@ -1,170 +1,3 @@
|
|
1 |
-
# import os
|
2 |
-
# import chromadb
|
3 |
-
# import numpy as np
|
4 |
-
# from typing import List, Tuple
|
5 |
-
# from gemini_embedding import GeminiEmbeddingFunction
|
6 |
-
|
7 |
-
# def create_chroma_db(documents: List[str], dataset_name: str, base_path: str = "chroma_db"):
|
8 |
-
# """
|
9 |
-
# Creates a Chroma database using the provided documents.
|
10 |
-
# Automatically generates path and collection name based on dataset_name.
|
11 |
-
# """
|
12 |
-
# path = os.path.join(base_path, dataset_name)
|
13 |
-
# name = f"{dataset_name}_collection"
|
14 |
-
|
15 |
-
# if not os.path.exists(path):
|
16 |
-
# os.makedirs(path)
|
17 |
-
|
18 |
-
# chroma_client = chromadb.PersistentClient(path=path)
|
19 |
-
# db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())
|
20 |
-
|
21 |
-
# for i, doc in enumerate(documents):
|
22 |
-
# db.add(documents=[doc], ids=[str(i)])
|
23 |
-
|
24 |
-
# return db
|
25 |
-
|
26 |
-
# def load_chroma_collection(dataset_name: str, base_path: str = "chroma_db"):
|
27 |
-
# """
|
28 |
-
# Loads an existing Chroma collection.
|
29 |
-
# """
|
30 |
-
# path = os.path.join(base_path, dataset_name)
|
31 |
-
# name = f"{dataset_name}_collection"
|
32 |
-
|
33 |
-
# chroma_client = chromadb.PersistentClient(path=path)
|
34 |
-
# return chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())
|
35 |
-
|
36 |
-
# def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
|
37 |
-
# """
|
38 |
-
# Calculate cosine similarity between two vectors.
|
39 |
-
# Returns a value between -1 and 1, where 1 means most similar.
|
40 |
-
# """
|
41 |
-
# dot_product = np.dot(vec1, vec2)
|
42 |
-
# norm1 = np.linalg.norm(vec1)
|
43 |
-
# norm2 = np.linalg.norm(vec2)
|
44 |
-
# return dot_product / (norm1 * norm2)
|
45 |
-
|
46 |
-
# def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
47 |
-
# """
|
48 |
-
# Retrieves relevant passages using explicit cosine similarity calculation.
|
49 |
-
# """
|
50 |
-
# # Get query embedding
|
51 |
-
# query_embedding = db._embedding_function([query])[0]
|
52 |
-
|
53 |
-
# # Get all document embeddings
|
54 |
-
# all_docs = db.get(include=['embeddings', 'documents'])
|
55 |
-
# doc_embeddings = all_docs['embeddings']
|
56 |
-
# documents = all_docs['documents']
|
57 |
-
|
58 |
-
# # Calculate cosine similarity for each document
|
59 |
-
# similarities = []
|
60 |
-
# for doc_embedding in doc_embeddings:
|
61 |
-
# similarity = cosine_similarity(query_embedding, doc_embedding)
|
62 |
-
# similarities.append(similarity)
|
63 |
-
|
64 |
-
# # Sort documents by similarity
|
65 |
-
# doc_similarities = list(zip(documents, similarities))
|
66 |
-
# doc_similarities.sort(key=lambda x: x[1], reverse=True)
|
67 |
-
|
68 |
-
# # Take top n results
|
69 |
-
# top_results = doc_similarities[:n_results]
|
70 |
-
|
71 |
-
# # Print results for debugging
|
72 |
-
# print(f"Number of relevant passages retrieved: {len(top_results)}")
|
73 |
-
# for i, (doc, similarity) in enumerate(top_results):
|
74 |
-
# print(f"Passage {i+1} (Cosine Similarity: {similarity:.4f}): {doc[:100]}...")
|
75 |
-
|
76 |
-
# # Return just the documents
|
77 |
-
# return [doc for doc, _ in top_results]
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
# in memory
|
86 |
-
|
87 |
-
|
88 |
-
# import chromadb
|
89 |
-
# from typing import List
|
90 |
-
# from gemini_embedding import GeminiEmbeddingFunction # Ensure this is correctly implemented
|
91 |
-
# import time
|
92 |
-
# from chromadb.config import Settings
|
93 |
-
|
94 |
-
# def create_chroma_db(chunks: List[str]):
|
95 |
-
# """Create and return an in-memory ChromaDB collection."""
|
96 |
-
# try:
|
97 |
-
# # Initialize in-memory ChromaDB with current recommended configuration
|
98 |
-
# client = chromadb.Client()
|
99 |
-
|
100 |
-
# # Create collection with unique name to avoid conflicts
|
101 |
-
# collection_name = f"temp_collection_{int(time.time())}"
|
102 |
-
# collection = client.create_collection(name=collection_name)
|
103 |
-
|
104 |
-
# # Add documents with unique IDs
|
105 |
-
# collection.add(
|
106 |
-
# documents=chunks,
|
107 |
-
# ids=[f"doc_{i}" for i in range(len(chunks))]
|
108 |
-
# )
|
109 |
-
|
110 |
-
# # Verify the data was added
|
111 |
-
# verify_count = collection.count()
|
112 |
-
# print(f"Verified: Added {verify_count} documents to collection {collection_name}")
|
113 |
-
|
114 |
-
# # Test query to ensure collection is working
|
115 |
-
# test_results = collection.query(
|
116 |
-
# query_texts=["test"],
|
117 |
-
# n_results=1
|
118 |
-
# )
|
119 |
-
# print("Verified: Collection is queryable")
|
120 |
-
|
121 |
-
# return collection
|
122 |
-
|
123 |
-
# except Exception as e:
|
124 |
-
# print(f"Error creating ChromaDB: {str(e)}")
|
125 |
-
# return None
|
126 |
-
|
127 |
-
# def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
128 |
-
# """
|
129 |
-
# Retrieves relevant passages using ChromaDB's similarity search.
|
130 |
-
# """
|
131 |
-
# try:
|
132 |
-
# if db is None:
|
133 |
-
# print("Database not initialized")
|
134 |
-
# return []
|
135 |
-
|
136 |
-
# # Verify collection has documents
|
137 |
-
# count = db.count()
|
138 |
-
# if count == 0:
|
139 |
-
# print("Collection is empty")
|
140 |
-
# return []
|
141 |
-
|
142 |
-
# # Query the database
|
143 |
-
# results = db.query(
|
144 |
-
# query_texts=[query],
|
145 |
-
# n_results=min(n_results, count) # Ensure we don't request more than we have
|
146 |
-
# )
|
147 |
-
|
148 |
-
# # Ensure results exist
|
149 |
-
# if not results["documents"]:
|
150 |
-
# print("No relevant passages found.")
|
151 |
-
# return []
|
152 |
-
|
153 |
-
# documents = results["documents"][0] # First result batch
|
154 |
-
# distances = results["distances"][0] # Corresponding distances
|
155 |
-
|
156 |
-
# # Debug output
|
157 |
-
# print(f"Number of relevant passages retrieved: {len(documents)}")
|
158 |
-
# for i, (doc, distance) in enumerate(zip(documents, distances)):
|
159 |
-
# similarity = 1 - distance # Convert distance to similarity
|
160 |
-
# print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")
|
161 |
-
|
162 |
-
# return documents
|
163 |
-
# except Exception as e:
|
164 |
-
# print(f"Error in get_relevant_passage: {str(e)}")
|
165 |
-
# return []
|
166 |
-
|
167 |
-
|
168 |
import chromadb
|
169 |
from chromadb.config import Settings
|
170 |
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import chromadb
|
2 |
from chromadb.config import Settings
|
3 |
from typing import List
|