Spaces:
Sleeping
Sleeping
File size: 2,256 Bytes
5a34799 207bed4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict
class VectorDatabase:
    """A simple vector database for storing and retrieving contextual embeddings.

    Texts are embedded with a SentenceTransformer model and stored in a FAISS
    flat L2 index; the raw text/summary pairs are kept in ``self.data`` as a
    parallel list so search hits can be mapped back to their source records.
    """

    def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2", dim: int = 384):
        """
        Args:
            embedding_model_name (str): Name of the SentenceTransformer model to load.
            dim (int): Dimensionality of the embeddings the model produces
                (must match the model's output size; 384 for all-MiniLM-L6-v2).
        """
        self.model = SentenceTransformer(embedding_model_name)
        self.index = faiss.IndexFlatL2(dim)
        self.data = []  # Raw records ({"text", "summary"}), row-aligned with the FAISS index.

    def add_data(self, texts: List[str], summaries: List[str]) -> None:
        """
        Adds data to the vector database.

        Args:
            texts (List[str]): The original texts to be stored.
            summaries (List[str]): Summarized versions of the texts.

        Raises:
            ValueError: If texts and summaries differ in length.
        """
        # Validate up front: zip() would silently truncate the shorter list
        # while every embedding was still added, desynchronizing the FAISS
        # index from self.data.
        if len(texts) != len(summaries):
            raise ValueError(
                f"texts ({len(texts)}) and summaries ({len(summaries)}) "
                "must have the same length"
            )
        if not texts:
            return  # Nothing to add; skip the model call on an empty batch.
        embeddings = self.model.encode(texts)
        self.index.add(np.array(embeddings).astype("float32"))
        for text, summary in zip(texts, summaries):
            self.data.append({"text": text, "summary": summary})

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Searches the vector database for the top-k similar results.

        Args:
            query (str): The query text.
            top_k (int): Number of results to return.

        Returns:
            List[Dict]: Matched context/summary records, nearest first. May be
                shorter than top_k when the index holds fewer entries.
        """
        query_embedding = self.model.encode([query])
        distances, indices = self.index.search(
            np.array(query_embedding).astype("float32"), top_k
        )
        # FAISS pads missing results with -1. The original `i < len(...)`
        # check let -1 through, and Python's data[-1] wrongly returned the
        # LAST record. Keep only valid, non-negative row indices.
        return [self.data[i] for i in indices[0] if 0 <= i < len(self.data)]

    def to_dict(self) -> Dict:
        """
        Converts the internal state of the vector database to a dictionary format.

        Returns:
            Dict: {"data": stored records, "index": embedding matrix as
                nested Python lists}.
        """
        return {
            "data": self.data,
            "index": self.index.reconstruct_n(0, self.index.ntotal).tolist(),
        }

    def from_dict(self, state: Dict) -> None:
        """
        Restores the internal state of the vector database from a dictionary format.

        Args:
            state (Dict): The state to restore, as produced by ``to_dict``.
        """
        self.data = state["data"]
        # Drop any existing vectors first: the original implementation
        # appended to the live index, so restoring twice (or into a non-empty
        # DB) duplicated rows while self.data was overwritten.
        self.index.reset()
        embeddings = np.array(state["index"], dtype="float32")
        # Reshape so an empty state (shape (0,) array) still has the
        # rank-2 (n, d) layout FAISS requires.
        self.index.add(embeddings.reshape(-1, self.index.d))