# Replace the stdlib ``sqlite3`` module with ``pysqlite3`` in ``sys.modules``.
# This must run before ``chromadb`` is imported so that chromadb picks up the
# substituted module (presumably because the system SQLite is too old for
# Chroma -- confirm against the deployment environment).
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import uuid
from collections import defaultdict
from typing import Any, List, Optional

import chromadb
import numpy as np
from chromadb import Collection
from embeddings import Embedding
from PIL.Image import Image
from utils import base64_to_image


class ChromaStore:
    """Thin wrapper around a persistent Chroma client for image retrieval."""

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: Optional[dict] = None,
    ) -> None:
        """Open (or create) a persistent Chroma client at ``storage_path``.

        Args:
            collection_name: Name of the collection used by :meth:`create`.
            storage_path: Directory handed to ``chromadb.PersistentClient``.
            database: Stored on the instance; not used by any method visible
                in this class.
            metadata: Collection metadata. Available options for
                ``'hnsw:space'`` are ``'l2'``, ``'ip'`` or ``'cosine'``.
                Defaults to ``{"hnsw:space": "cosine"}`` when omitted.
        """
        # Build the default per instance instead of sharing one dict object
        # across every ChromaStore created with the default argument.
        if metadata is None:
            metadata = {"hnsw:space": "cosine"}

        self.collection_name = collection_name
        self.metadata = metadata
        self.storage_path = storage_path
        self.database = database
        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True when the client's heartbeat yields an ``int``."""
        return isinstance(self.client.heartbeat(), int)

    def generate_embeddings(
        self, images: List[Image], embedding: Embedding
    ) -> np.ndarray:
        """Encode ``images`` by delegating to ``embedding.encode_images``."""
        return embedding.encode_images(images)

    def create(self) -> Collection:
        """Get the collection named ``self.collection_name``, creating it if absent.

        NOTE(review): ``self.metadata`` is not forwarded here, so the
        ``hnsw:space`` setting stored in ``__init__`` has no effect on the
        collection. Left unchanged because forwarding it would alter the
        distance metric (and therefore the scores) of newly created
        collections -- confirm the intended behavior.
        """
        collection = self.client.get_or_create_collection(
            name=self.collection_name,
        )
        return collection

    def add(
        self,
        collection: Collection,
        embeddings: List[float],
        documents: List[str],
        ids: List[str],
    ) -> None:
        """Add embeddings and documents to a collection.

        Args:
            collection: Created collection.
            embeddings: List of image embeddings.
            documents: List of base64 strings of images.
            ids: List of ids for the images.

        Raises:
            Exception: If ``collection.add`` fails; the original error is
                attached as ``__cause__``.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            # Chain explicitly so the underlying Chroma error stays visible.
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: Collection,
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve relevant images from the Chroma database.

        Args:
            collection: Created collection.
            query_embedding: Query image embedding.
            top_k: Number of images to retrieve.

        Returns:
            List of ``(image, score)`` tuples, where ``score`` is the value
            Chroma reports under ``"distances"``, rounded to 3 decimals.
        """
        result = collection.query(query_embeddings=query_embedding, n_results=top_k)

        # Only the first entry of each result list is read, i.e. the matches
        # for the first query embedding.
        relevant_images = [
            base64_to_image(img_str) for img_str in result["documents"][0]
        ]
        scores = [round(score, 3) for score in result["distances"][0]]
        return list(zip(relevant_images, scores))

    def delete(self, collection_name: str) -> bool:
        """Delete the collection called ``collection_name``.

        Returns:
            True when the deletion call completes without raising.

        Raises:
            Exception: If the client fails to delete the collection; the
                original error is included in ``args`` and as ``__cause__``.
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            raise Exception("Failed to delete collection", e) from e
        return True

    @staticmethod
    def collection_info(collection: Collection) -> defaultdict:
        """Summarize a collection.

        Returns:
            A ``defaultdict(str)`` with ``"count"`` (``collection.count()``)
            and ``"top_10_items"`` (``collection.peek()`` with its default
            limit -- presumably 10, per the key name).
        """
        info = defaultdict(str)
        info["count"] = collection.count()
        info["top_10_items"] = collection.peek()
        return info