|
from application.vectorstore.base import BaseVectorStore |
|
from application.core.settings import settings |
|
from application.vectorstore.document_class import Document |
|
import elasticsearch |
|
|
|
|
|
|
|
|
|
class ElasticsearchStore(BaseVectorStore):
    """Vector store backed by an Elasticsearch index.

    Documents are indexed with a ``text`` field, a dense ``vector`` field and a
    ``metadata`` object. Retrieval combines BM25 text matching with kNN vector
    search, with both legs filtered by this store's source path
    (``metadata.store.keyword``).
    """

    # Class-level client so every instance shares one connection pool.
    _es_connection = None

    def __init__(self, path, embeddings_key, index_name=settings.ELASTIC_INDEX):
        """Create a store scoped to *path* within *index_name*.

        Args:
            path: Store path; the ``"application/indexes/"`` prefix and any
                trailing slash are stripped, and the result is used as the
                metadata filter value for all queries.
            embeddings_key: Credential forwarded to the embeddings backend.
            index_name: Elasticsearch index to read from and write to.

        Raises:
            ValueError: if neither ELASTIC_URL nor ELASTIC_CLOUD_ID is set
                when the shared connection must be created.
        """
        super().__init__()
        self.path = path.replace("application/indexes/", "").rstrip("/")
        self.embeddings_key = embeddings_key
        self.index_name = index_name

        # Lazily create the shared connection on first instantiation.
        if ElasticsearchStore._es_connection is None:
            connection_params = {}
            if settings.ELASTIC_URL:
                connection_params["hosts"] = [settings.ELASTIC_URL]
                connection_params["http_auth"] = (
                    settings.ELASTIC_USERNAME,
                    settings.ELASTIC_PASSWORD,
                )
            elif settings.ELASTIC_CLOUD_ID:
                connection_params["cloud_id"] = settings.ELASTIC_CLOUD_ID
                connection_params["basic_auth"] = (
                    settings.ELASTIC_USERNAME,
                    settings.ELASTIC_PASSWORD,
                )
            else:
                raise ValueError("Please provide either elasticsearch_url or cloud_id.")
            ElasticsearchStore._es_connection = elasticsearch.Elasticsearch(
                **connection_params
            )

        self.docsearch = ElasticsearchStore._es_connection

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url=None,
        cloud_id=None,
        api_key=None,
        username=None,
        password=None,
    ):
        """Build and verify a standalone Elasticsearch client.

        Exactly one of *es_url* or *cloud_id* must be provided. Authentication
        is via *api_key*, or *username*/*password* (api_key wins when both are
        given).

        Returns:
            A connected ``elasticsearch.Elasticsearch`` client.

        Raises:
            ImportError: if the elasticsearch package is not installed.
            ValueError: if both or neither of *es_url*/*cloud_id* are given.
        """
        # NOTE: was previously a plain function without @staticmethod, which
        # made instance calls (`store.connect_to_elasticsearch(...)`) raise
        # TypeError; class-level calls are unaffected by this fix.
        try:
            import elasticsearch
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params = {}
        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either elasticsearch_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(**connection_params)
        # Fail fast if the cluster is unreachable or the credentials are bad.
        es_client.info()
        return es_client

    def search(self, question, k=2, index_name=settings.ELASTIC_INDEX, *args, **kwargs):
        """Hybrid (BM25 + kNN) search for *question* within this store's path.

        Args:
            question: Natural-language query; embedded for the kNN leg and
                matched against ``text`` for the BM25 leg.
            k: Maximum number of documents to return.
            index_name: Currently ignored; kept for interface compatibility —
                the search always runs against ``self.index_name``.

        Returns:
            list[Document]: up to *k* matching documents.
        """
        embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
        vector = embeddings.embed_query(question)
        # Restrict both retrieval legs to documents belonging to this store.
        store_filter = [{"match": {"metadata.store.keyword": self.path}}]
        knn = {
            "filter": store_filter,
            "field": "vector",
            "k": k,
            "num_candidates": 100,
            "query_vector": vector,
        }
        query = {
            "bool": {
                "must": [{"match": {"text": {"query": question}}}],
                "filter": store_filter,
            }
        }
        resp = self.docsearch.search(
            index=self.index_name, query=query, size=k, knn=knn
        )
        return [
            Document(
                page_content=hit["_source"]["text"],
                metadata=hit["_source"]["metadata"],
            )
            for hit in resp["hits"]["hits"]
        ]

    def _create_index_if_not_exists(self, index_name, dims_length):
        """Create *index_name* with a dense-vector mapping when it is missing.

        Args:
            index_name: Target Elasticsearch index.
            dims_length: Dimensionality of the embedding vectors.
        """
        if self._es_connection.indices.exists(index=index_name):
            print(f"Index {index_name} already exists.")
        else:
            index_settings = self.index(dims_length=dims_length)
            self._es_connection.indices.create(index=index_name, **index_settings)

    def index(self, dims_length):
        """Return index-creation kwargs mapping ``vector`` as a cosine dense_vector.

        Args:
            dims_length: Dimensionality of the embedding vectors.

        Returns:
            dict: keyword arguments for ``indices.create``.
        """
        return {
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "dense_vector",
                        "dims": dims_length,
                        "index": True,
                        "similarity": "cosine",
                    },
                }
            }
        }

    def add_texts(
        self,
        texts,
        metadatas=None,
        ids=None,
        refresh_indices=True,
        create_index_if_not_exists=True,
        bulk_kwargs=None,
        **kwargs,
    ):
        """Embed *texts* and bulk-index them into ``self.index_name``.

        Args:
            texts: Iterable of document strings.
            metadatas: Optional per-text metadata dicts (positionally matched).
            ids: Optional document ids; random UUID4 strings are generated
                for any missing ones.
            refresh_indices: Refresh the index after the bulk write.
            create_index_if_not_exists: Create the target index (with a vector
                mapping sized from the first embedding) when it is missing.
            bulk_kwargs: Extra kwargs forwarded to ``elasticsearch.helpers.bulk``.

        Returns:
            list[str]: the ids of the indexed documents (empty for no input).

        Raises:
            elasticsearch.helpers.BulkIndexError: if the bulk write fails.
        """
        import uuid

        from elasticsearch.helpers import BulkIndexError, bulk

        bulk_kwargs = bulk_kwargs or {}
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
        vectors = embeddings.embed_documents(list(texts))

        # Guard: previously `len(vectors[0])` raised IndexError on empty input.
        if not vectors:
            return []

        if create_index_if_not_exists:
            self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=len(vectors[0])
            )

        requests = [
            {
                "_op_type": "index",
                "_index": self.index_name,
                "_id": ids[i],
                "text": text,
                "vector": vector,
                "metadata": metadatas[i] if metadatas else {},
            }
            for i, (text, vector) in enumerate(zip(texts, vectors))
        ]

        try:
            bulk(
                self._es_connection,
                requests,
                stats_only=True,
                refresh=refresh_indices,
                **bulk_kwargs,
            )
            return ids
        except BulkIndexError as e:
            print(f"Error adding texts: {e}")
            first_error = e.errors[0].get("index", {}).get("error", {})
            print(f"First error reason: {first_error.get('reason')}")
            raise

    def delete_index(self):
        """Delete this store's documents (not the whole index) by path filter."""
        self._es_connection.delete_by_query(
            index=self.index_name,
            query={"match": {"metadata.store.keyword": self.path}},
        )
|
|
|
|