# Spaces: Paused
import asyncio
from collections import defaultdict
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np

from aimakerspace.openai_utils.embedding import EmbeddingModel
def cosine_similarity(vector_a: np.ndarray, vector_b: np.ndarray) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        vector_a: First vector.
        vector_b: Second vector; must be compatible with ``vector_a`` for
            ``np.dot``.

    Returns:
        The dot product divided by the product of the two Euclidean norms.
        If either vector has zero norm the similarity is undefined, and
        0.0 is returned instead of NaN so that callers ranking by score
        get a well-defined ordering.
    """
    dot_product = np.dot(vector_a, vector_b)
    denominator = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
    # Guard the division: a zero-length vector would otherwise yield NaN
    # (plus a RuntimeWarning), and NaN scores break sorting.
    if denominator == 0:
        return 0.0
    return dot_product / denominator
class VectorDatabase:
    """In-memory store mapping string keys to a vector plus metadata.

    Each entry of ``self.data`` is a dict of the form
    ``{"vector": <array>, "metadata": <dict>}``, keyed by the string passed
    to :meth:`insert`.
    """

    def __init__(self, embedding_model: Optional[EmbeddingModel] = None):
        """Create an empty database.

        Args:
            embedding_model: Model used to embed text. When omitted (or
                falsy), a default ``EmbeddingModel()`` is constructed.
        """
        # A defaultdict so that indexing an unseen key through
        # ``self.data[key]`` yields an empty record rather than KeyError.
        self.data = defaultdict(lambda: {"vector": np.array([]), "metadata": {}})
        self.embedding_model = embedding_model or EmbeddingModel()

    def insert(
        self,
        key: str,
        vector: np.ndarray,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store ``vector`` and ``metadata`` under ``key``.

        An existing entry with the same key is overwritten. A missing or
        empty ``metadata`` is stored as a fresh empty dict.
        """
        self.data[key]["vector"] = vector
        self.data[key]["metadata"] = metadata or {}

    def search(
        self,
        query_vector: np.ndarray,
        k: int,
        distance_measure: Callable = cosine_similarity,
    ) -> List[Tuple[str, float, dict[str, Any]]]:
        """Return the ``k`` best-scoring entries for ``query_vector``.

        Args:
            query_vector: Vector to compare against every stored vector.
            k: Maximum number of results to return.
            distance_measure: Callable ``(query_vector, stored_vector)``
                returning a score; higher scores rank first.

        Returns:
            ``(key, score, metadata)`` tuples sorted by descending score.
        """
        scores = [
            (key, distance_measure(query_vector, entry["vector"]), entry["metadata"])
            for key, entry in self.data.items()
        ]
        # Sort on the score only; metadata dicts are not orderable.
        return sorted(scores, key=lambda item: item[1], reverse=True)[:k]

    def search_by_text(
        self,
        query_text: str,
        k: int,
        distance_measure: Callable = cosine_similarity,
        return_as_text: bool = False,
        return_metadata: bool = True,
    ) -> List[Union[str, Tuple[Any, ...]]]:
        """Embed ``query_text`` and return the ``k`` best-scoring entries.

        The shape of each result depends on the two flags:

        * ``return_as_text`` and ``return_metadata``: ``(key, metadata)``
        * ``return_as_text`` only: ``key``
        * ``return_metadata`` only (the default): ``(key, score, metadata)``
        * neither: ``(key, score)``

        Args:
            query_text: Text embedded with ``self.embedding_model``.
            k: Maximum number of results to return.
            distance_measure: Scoring callable forwarded to :meth:`search`.
            return_as_text: Drop the score from each result.
            return_metadata: Include the stored metadata in each result.
        """
        query_vector = self.embedding_model.get_embedding(query_text)
        results = self.search(query_vector, k, distance_measure)

        if return_as_text:
            if return_metadata:
                return [(key, metadata) for key, _, metadata in results]
            return [key for key, _, _ in results]
        if return_metadata:
            return results
        return [(key, score) for key, score, _ in results]

    def retrieve_from_key(
        self, key: str
    ) -> Tuple[Optional[np.ndarray], Optional[dict[str, Any]]]:
        """Return ``(vector, metadata)`` for ``key``, or ``(None, None)``.

        Uses ``dict.get`` so that looking up an unknown key does not create
        an empty record through the defaultdict factory.
        """
        entry = self.data.get(key)
        if entry is None:
            return None, None
        return entry["vector"], entry["metadata"]

    async def abuild_from_list(
        self,
        list_of_text: List[str],
        metadata_list: Optional[List[dict[str, Any]]] = None,
    ) -> "VectorDatabase":
        """Embed every text and insert it, using the text itself as the key.

        Args:
            list_of_text: Texts to embed; duplicate texts overwrite each
                other because the text is the key.
            metadata_list: Optional metadata, matched to ``list_of_text`` by
                position. When omitted or empty, every entry gets ``{}``.
                A non-empty list shorter than ``list_of_text`` raises
                ``IndexError`` at the first missing position.

        Returns:
            This database, so the call can be chained.
        """
        embeddings = await self.embedding_model.async_get_embeddings(list_of_text)
        for index, (text, embedding) in enumerate(zip(list_of_text, embeddings)):
            metadata = metadata_list[index] if metadata_list else {}
            self.insert(text, np.array(embedding), metadata)
        return self
# if __name__ == "__main__":
#     list_of_text = [
#         "I like to eat broccoli and bananas.",
#         "I ate a banana and spinach smoothie for breakfast.",
#         "Chinchillas and kittens are cute.",
#         "My sister adopted a kitten yesterday.",
#         "Look at this cute hamster munching on a piece of broccoli.",
#     ]
#     vector_db = VectorDatabase()
#     vector_db = asyncio.run(vector_db.abuild_from_list(list_of_text))
#     k = 2
#     searched_vector = vector_db.search_by_text("I think fruit is awesome!", k=k)
#     print(f"Closest {k} vector(s):", searched_vector)
#     retrieved_vector = vector_db.retrieve_from_key(
#         "I like to eat broccoli and bananas."
#     )
#     print("Retrieved vector:", retrieved_vector)
#     relevant_texts = vector_db.search_by_text(
#         "I think fruit is awesome!", k=k, return_as_text=True
#     )
#     print(f"Closest {k} text(s):", relevant_texts)