File size: 4,794 Bytes
be2f825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import logging
import os
import shutil
from typing import List

from txtai.embeddings import Embeddings

# Module-level logging setup: route records at INFO and above through the root
# logger, and log from this module via its own named logger.
# NOTE(review): basicConfig runs as a side effect of importing this module and
# configures the *root* logger (per the logging docs it does nothing if the root
# logger already has handlers); consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EmbeddingsManager:
    """Create, persist and query txtai embeddings indexes under a base directory.

    Each index is stored in its own sub-directory of ``base_path`` containing the
    saved txtai embeddings (``embeddings``) and the indexed documents
    (``document_list.json``), which are used to map search hits back to text.

    A single ``Embeddings`` instance is shared by every call: ``create_index``
    re-indexes it and ``query_index`` reloads it from disk, so concurrent calls
    on one manager would race on that shared state.
    """

    def __init__(self, base_path: str = "./indexes", model_path: str = "avsolatorio/GIST-all-MiniLM-L6-v2"):
        """
        Initializes the EmbeddingsManager.

        Args:
            base_path (str): Base directory to store indices; created if missing.
            model_path (str): Path or identifier for the embeddings model, passed
                to txtai as its ``path`` setting.
        """
        self.base_path = base_path
        os.makedirs(self.base_path, exist_ok=True)
        self.model_path = model_path
        self.embeddings = Embeddings({"path": self.model_path})
        logger.info(f"Embeddings model loaded from '{self.model_path}'. Base path set to '{self.base_path}'.")

    def _index_path(self, index_id: str) -> str:
        """
        Returns the directory for ``index_id``, rejecting ids that escape ``base_path``.

        Args:
            index_id (str): Identifier of the index (may contain sub-directories).

        Returns:
            str: ``os.path.join(base_path, index_id)``.

        Raises:
            ValueError: If ``index_id`` resolves to ``base_path`` itself (e.g. ``""``)
                or to a location outside it (e.g. ``".."`` or an absolute path).
        """
        index_path = os.path.join(self.base_path, index_id)
        # os.path.join discards base_path for absolute ids and ".." walks upward,
        # so compare fully resolved paths instead of trusting the raw string.
        base_real = os.path.realpath(self.base_path)
        target_real = os.path.realpath(index_path)
        if target_real == base_real or os.path.commonpath([base_real, target_real]) != base_real:
            logger.error(f"Invalid index_id '{index_id}': resolves outside '{self.base_path}'.")
            raise ValueError(f"Invalid index_id '{index_id}'.")
        return index_path

    def create_index(self, index_id: str, documents: List[str]) -> None:
        """
        Creates a new embeddings index with the provided documents.

        Args:
            index_id (str): Unique identifier for the index.
            documents (List[str]): Documents to be indexed; any iterable of strings
                is accepted. A document's position becomes its id in the index.

        Raises:
            ValueError: If the index already exists or ``index_id`` is invalid.
            Exception: For any other errors during indexing or saving (the original
                error is chained as ``__cause__``). A partially written index
                directory is removed before raising.
        """
        index_path = self._index_path(index_id)
        if os.path.exists(index_path):
            logger.error(f"Index with index_id '{index_id}' already exists at '{index_path}'.")
            raise ValueError(f"Index with index_id '{index_id}' already exists.")

        # Materialize once: the documents are enumerated for indexing and then
        # serialized to JSON, which would fail on a one-shot iterator.
        documents = list(documents)
        created_dir = False
        try:
            # txtai accepts (id, data, tags) tuples; ids are list positions so
            # query_index can map hits back into document_list.json.
            document_tuples = [(i, text, None) for i, text in enumerate(documents)]
            self.embeddings.index(document_tuples)
            logger.info(f"Documents indexed for index_id '{index_id}'.")

            # No exist_ok: if the directory appeared after the check above, fail
            # instead of writing into (and later cleaning up) another index.
            os.makedirs(index_path)
            created_dir = True

            embeddings_path = os.path.join(index_path, "embeddings")
            self.embeddings.save(embeddings_path)
            logger.info(f"Embeddings saved to '{embeddings_path}'.")

            document_list_path = os.path.join(index_path, "document_list.json")
            with open(document_list_path, "w", encoding='utf-8') as f:
                json.dump(documents, f, ensure_ascii=False, indent=4)
            logger.info(f"Document list saved to '{document_list_path}'.")

            logger.info(f"Index '{index_id}' created and saved successfully.")
        except Exception as e:
            logger.exception(f"Failed to create index '{index_id}': {e}")
            if created_dir:
                # A half-written directory would make every retry fail with
                # "already exists" and would make query_index fail, so remove it.
                shutil.rmtree(index_path, ignore_errors=True)
            raise Exception(f"Failed to create index '{index_id}': {e}") from e

    def query_index(self, index_id: str, query: str, num_results: int = 5) -> List[str]:
        """
        Queries an existing embeddings index.

        Args:
            index_id (str): Unique identifier for the index to query.
            query (str): The search query.
            num_results (int): Number of top results to return.

        Returns:
            List[str]: Top matching documents, in the order returned by the search.

        Raises:
            ValueError: If ``index_id`` is invalid.
            FileNotFoundError: If the index or its document list does not exist.
            Exception: For any other errors during querying (the original error is
                chained as ``__cause__``).
        """
        index_path = self._index_path(index_id)
        if not os.path.exists(index_path):
            logger.error(f"Index '{index_id}' not found at '{index_path}'.")
            raise FileNotFoundError(f"Index '{index_id}' not found.")

        embeddings_path = os.path.join(index_path, "embeddings")
        document_list_path = os.path.join(index_path, "document_list.json")
        # Checked outside the try so this documented FileNotFoundError is not
        # re-wrapped into a generic Exception, and before the embeddings load.
        if not os.path.exists(document_list_path):
            logger.error(f"Document list not found at '{document_list_path}'.")
            raise FileNotFoundError(f"Document list not found for index '{index_id}'.")

        try:
            self.embeddings.load(embeddings_path)
            logger.info(f"Embeddings loaded from '{embeddings_path}' for index '{index_id}'.")

            with open(document_list_path, "r", encoding='utf-8') as f:
                document_list = json.load(f)
            logger.info(f"Document list loaded from '{document_list_path}'.")

            results = self.embeddings.search(query, num_results)
            # Assumes each hit is an (id, score) tuple (no content storage is
            # configured); ids are the list positions assigned in create_index.
            queried_texts = [document_list[hit[0]] for hit in results]
            logger.info(f"Query executed successfully on index '{index_id}'. Retrieved {len(queried_texts)} results.")

            return queried_texts
        except Exception as e:
            logger.exception(f"Failed to query index '{index_id}': {e}")
            raise Exception(f"Failed to query index '{index_id}': {e}") from e