legal-ease / base /create_collection.py
shivi's picture
update app theme
38f8c33
# Run this script independently (`python create_collection.py`) to create a Qdrant collection. A Qdrant collection is a set of vectors among which you can search.
# For the search feature to work, every legal document that should be searchable must first be converted to its embedding representation and inserted into a Qdrant collection.
import os
import cohere
from datasets import load_dataset
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.http import models as rest
from constants import (
ENGLISH_EMBEDDING_MODEL,
MULTILINGUAL_EMBEDDING_MODEL,
USE_MULTILINGUAL_EMBEDDING,
CREATE_QDRANT_COLLECTION_NAME,
)
# Connection settings and credentials come from the process environment;
# a variable that is not set resolves to None.
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
def get_embedding_size():
    """
    Return the vector dimensionality for the embedding model selected by `USE_MULTILINGUAL_EMBEDDING`.

    Returns:
        embedding_size (`int`):
            768 when the multilingual model is selected, 4096 otherwise.
    """
    # The sizes are hard-coded per model choice; they must be kept in sync with
    # the model names configured in `constants`.
    return 768 if USE_MULTILINGUAL_EMBEDDING else 4096
def create_qdrant_collection(vector_size):
    """
    (Re)create the Qdrant collection named `CREATE_QDRANT_COLLECTION_NAME`.

    The collection is created through `qdrant_client.recreate_collection`, so an
    existing collection with the same name is replaced. It holds one vector per
    legal document.

    Args:
        vector_size (`int`):
            Dimensionality of the embeddings that will be stored in the collection.
    """
    # The multilingual embedding model was trained with dot-product similarity,
    # so the collection uses DOT for it; otherwise cosine similarity is used.
    distance_measure = (
        rest.Distance.DOT if USE_MULTILINGUAL_EMBEDDING else rest.Distance.COSINE
    )
    print("CREATE_QDRANT_COLLECTION_NAME:", CREATE_QDRANT_COLLECTION_NAME)
    vector_params = models.VectorParams(size=vector_size, distance=distance_measure)
    qdrant_client.recreate_collection(
        collection_name=CREATE_QDRANT_COLLECTION_NAME,
        vectors_config=vector_params,
    )
def embed_legal_docs(legal_docs):
    """
    Create the embeddings and point ids used to represent the legal documents in the search collection.

    Args:
        legal_docs (`List[str]`):
            The documents to embed.

    Returns:
        doc_embeddings (`List[List[float]]`):
            One embedding per document, in the same order as `legal_docs`.
        doc_ids (`List[int]`):
            Sequential ids (0, 1, 2, ...) aligned with `doc_embeddings`; used as the
            point ids in the Qdrant collection.
    """
    # Nothing to embed: return immediately instead of making a network call.
    # `len()` (rather than truthiness) keeps this safe for array-like inputs.
    if len(legal_docs) == 0:
        return [], []

    if USE_MULTILINGUAL_EMBEDDING:
        model_name = MULTILINGUAL_EMBEDDING_MODEL
    else:
        model_name = ENGLISH_EMBEDDING_MODEL
    legal_docs_embeds = cohere_client.embed(
        texts=legal_docs,
        model=model_name,
    )
    # Normalise every component to a built-in float before handing vectors to Qdrant.
    doc_embeddings = [
        [float(value) for value in vector] for vector in legal_docs_embeds.embeddings
    ]
    # Derive ids from the embeddings list itself instead of iterating the response
    # object, whose iteration behaviour depends on the Cohere SDK version.
    doc_ids = list(range(len(doc_embeddings)))
    return doc_embeddings, doc_ids
def upsert_data_in_collection(vectors, ids, payload):
    """
    Insert (or overwrite) points in the Qdrant collection `CREATE_QDRANT_COLLECTION_NAME`.

    Args:
        vectors (`List`):
            One embedding per document to add to the collection.
        ids (`List`):
            Unique ids used as identifiers for the points (documents) in the collection.
        payload (`List`):
            Additional information or metadata for each document being added.

    Returns:
        update_result:
            Whatever `qdrant_client.upsert` returns, or `None` if the upsert failed.
    """
    try:
        return qdrant_client.upsert(
            collection_name=CREATE_QDRANT_COLLECTION_NAME,
            points=rest.Batch(
                ids=ids,
                vectors=vectors,
                payloads=payload,
            ),
        )
    except Exception as err:
        # Best effort: signal failure with None as before, but report the cause
        # instead of hiding it, and let KeyboardInterrupt/SystemExit propagate
        # (the previous bare `except:` swallowed them).
        print(f"Failed to upsert documents into Qdrant collection: {err!r}")
        return None
def fetch_legal_documents_and_payload(
    dataset_name="joelito/covid19_emergency_event", split="train"
):
    """
    Load the legal documents to be searched, plus the payload (metadata) stored with each of them.

    Args:
        dataset_name (`str`, defaults to `"joelito/covid19_emergency_event"`):
            Dataset path/name passed to `datasets.load_dataset`.
        split (`str`, defaults to `"train"`):
            Which split of the dataset to load.

    Returns:
        A tuple `(payload, legal_docs)` -- note that the payload comes first:
        payload (`List[Dict]`):
            One dict per dataset row (all columns, including `text`), used as the
            Qdrant point payload.
        legal_docs (`List[str]`):
            The `text` column of the dataset; these are the strings that get embedded.
    """
    legal_dataset = load_dataset(dataset_name, split=split)
    legal_docs = legal_dataset["text"]
    # Prepare the payload (additional information / metadata for each document being inserted).
    payload = list(legal_dataset)
    return payload, legal_docs
if __name__ == "__main__":
    # Create the Cohere and Qdrant clients. They are bound at module level because
    # the helper functions above read `cohere_client` / `qdrant_client` as globals.
    cohere_client = cohere.Client(COHERE_API_KEY)
    qdrant_client = QdrantClient(
        host=QDRANT_HOST,
        prefer_grpc=True,
        api_key=QDRANT_API_KEY,
    )
    # Fetch the embedding size for whichever model is configured.
    vector_size = get_embedding_size()
    # Create (or replace) the collection in Qdrant.
    create_qdrant_collection(vector_size)
    # Load the documents to be inserted into the collection.
    payload, legal_docs = fetch_legal_documents_and_payload()
    # Create embeddings and ids for the documents before insertion.
    doc_embeddings, doc_ids = embed_legal_docs(legal_docs)
    # Insert/update the documents in the newly created collection.
    update_result = upsert_data_in_collection(doc_embeddings, doc_ids, payload)
    # Count the stored points exactly. `vectors_count` from the collection info is
    # an approximate statistic (and deprecated in newer qdrant-client releases).
    stored_points = qdrant_client.count(
        collection_name=CREATE_QDRANT_COLLECTION_NAME,
        exact=True,
    ).count
    # Report failure whenever the upsert errored OR the point count does not match;
    # previously a count mismatch after a "successful" upsert printed nothing.
    if update_result is not None and stored_points == len(legal_docs):
        print("All documents have been successfully added to Qdrant Collection!")
    else:
        print("Failed to add documents to Qdrant collection")