jessica45 commited on
Commit
64103af
·
verified ·
1 Parent(s): 1a10750

Update chroma_db_utils.py

Browse files
Files changed (1) hide show
  1. chroma_db_utils.py +0 -167
chroma_db_utils.py CHANGED
@@ -1,170 +1,3 @@
1
- # import os
2
- # import chromadb
3
- # import numpy as np
4
- # from typing import List, Tuple
5
- # from gemini_embedding import GeminiEmbeddingFunction
6
-
7
- # def create_chroma_db(documents: List[str], dataset_name: str, base_path: str = "chroma_db"):
8
- # """
9
- # Creates a Chroma database using the provided documents.
10
- # Automatically generates path and collection name based on dataset_name.
11
- # """
12
- # path = os.path.join(base_path, dataset_name)
13
- # name = f"{dataset_name}_collection"
14
-
15
- # if not os.path.exists(path):
16
- # os.makedirs(path)
17
-
18
- # chroma_client = chromadb.PersistentClient(path=path)
19
- # db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())
20
-
21
- # for i, doc in enumerate(documents):
22
- # db.add(documents=[doc], ids=[str(i)])
23
-
24
- # return db
25
-
26
- # def load_chroma_collection(dataset_name: str, base_path: str = "chroma_db"):
27
- # """
28
- # Loads an existing Chroma collection.
29
- # """
30
- # path = os.path.join(base_path, dataset_name)
31
- # name = f"{dataset_name}_collection"
32
-
33
- # chroma_client = chromadb.PersistentClient(path=path)
34
- # return chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())
35
-
36
- # def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
37
- # """
38
- # Calculate cosine similarity between two vectors.
39
- # Returns a value between -1 and 1, where 1 means most similar.
40
- # """
41
- # dot_product = np.dot(vec1, vec2)
42
- # norm1 = np.linalg.norm(vec1)
43
- # norm2 = np.linalg.norm(vec2)
44
- # return dot_product / (norm1 * norm2)
45
-
46
- # def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
47
- # """
48
- # Retrieves relevant passages using explicit cosine similarity calculation.
49
- # """
50
- # # Get query embedding
51
- # query_embedding = db._embedding_function([query])[0]
52
-
53
- # # Get all document embeddings
54
- # all_docs = db.get(include=['embeddings', 'documents'])
55
- # doc_embeddings = all_docs['embeddings']
56
- # documents = all_docs['documents']
57
-
58
- # # Calculate cosine similarity for each document
59
- # similarities = []
60
- # for doc_embedding in doc_embeddings:
61
- # similarity = cosine_similarity(query_embedding, doc_embedding)
62
- # similarities.append(similarity)
63
-
64
- # # Sort documents by similarity
65
- # doc_similarities = list(zip(documents, similarities))
66
- # doc_similarities.sort(key=lambda x: x[1], reverse=True)
67
-
68
- # # Take top n results
69
- # top_results = doc_similarities[:n_results]
70
-
71
- # # Print results for debugging
72
- # print(f"Number of relevant passages retrieved: {len(top_results)}")
73
- # for i, (doc, similarity) in enumerate(top_results):
74
- # print(f"Passage {i+1} (Cosine Similarity: {similarity:.4f}): {doc[:100]}...")
75
-
76
- # # Return just the documents
77
- # return [doc for doc, _ in top_results]
78
-
79
-
80
-
81
-
82
-
83
-
84
-
85
- # in memory
86
-
87
-
88
- # import chromadb
89
- # from typing import List
90
- # from gemini_embedding import GeminiEmbeddingFunction # Ensure this is correctly implemented
91
- # import time
92
- # from chromadb.config import Settings
93
-
94
- # def create_chroma_db(chunks: List[str]):
95
- # """Create and return an in-memory ChromaDB collection."""
96
- # try:
97
- # # Initialize in-memory ChromaDB with current recommended configuration
98
- # client = chromadb.Client()
99
-
100
- # # Create collection with unique name to avoid conflicts
101
- # collection_name = f"temp_collection_{int(time.time())}"
102
- # collection = client.create_collection(name=collection_name)
103
-
104
- # # Add documents with unique IDs
105
- # collection.add(
106
- # documents=chunks,
107
- # ids=[f"doc_{i}" for i in range(len(chunks))]
108
- # )
109
-
110
- # # Verify the data was added
111
- # verify_count = collection.count()
112
- # print(f"Verified: Added {verify_count} documents to collection {collection_name}")
113
-
114
- # # Test query to ensure collection is working
115
- # test_results = collection.query(
116
- # query_texts=["test"],
117
- # n_results=1
118
- # )
119
- # print("Verified: Collection is queryable")
120
-
121
- # return collection
122
-
123
- # except Exception as e:
124
- # print(f"Error creating ChromaDB: {str(e)}")
125
- # return None
126
-
127
- # def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
128
- # """
129
- # Retrieves relevant passages using ChromaDB's similarity search.
130
- # """
131
- # try:
132
- # if db is None:
133
- # print("Database not initialized")
134
- # return []
135
-
136
- # # Verify collection has documents
137
- # count = db.count()
138
- # if count == 0:
139
- # print("Collection is empty")
140
- # return []
141
-
142
- # # Query the database
143
- # results = db.query(
144
- # query_texts=[query],
145
- # n_results=min(n_results, count) # Ensure we don't request more than we have
146
- # )
147
-
148
- # # Ensure results exist
149
- # if not results["documents"]:
150
- # print("No relevant passages found.")
151
- # return []
152
-
153
- # documents = results["documents"][0] # First result batch
154
- # distances = results["distances"][0] # Corresponding distances
155
-
156
- # # Debug output
157
- # print(f"Number of relevant passages retrieved: {len(documents)}")
158
- # for i, (doc, distance) in enumerate(zip(documents, distances)):
159
- # similarity = 1 - distance # Convert distance to similarity
160
- # print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")
161
-
162
- # return documents
163
- # except Exception as e:
164
- # print(f"Error in get_relevant_passage: {str(e)}")
165
- # return []
166
-
167
-
168
  import chromadb
169
  from chromadb.config import Settings
170
  from typing import List
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import chromadb
2
  from chromadb.config import Settings
3
  from typing import List