Spaces:
Runtime error
Runtime error
import os | |
import cohere | |
from typing import List | |
from qdrant_client import QdrantClient | |
from qdrant_client import models | |
from .constants import ( | |
MULTILINGUAL_EMBEDDING_MODEL, | |
ENGLISH_EMBEDDING_MODEL, | |
SEARCH_QDRANT_COLLECTION_NAME, | |
TRANSLATE_BASED_ON_USER_QUERY, | |
TEXT_GENERATION_MODEL, | |
USE_MULTILINGUAL_EMBEDDING, | |
) | |
# Read connection settings from the environment; a variable that is not set yields `None`.
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")

# Module-level Cohere and Qdrant clients, created once at import time.
cohere_client = cohere.Client(COHERE_API_KEY)
qdrant_client = QdrantClient(
    host=QDRANT_HOST,
    port=443,
    api_key=QDRANT_API_KEY,
    prefer_grpc=False,
)
def embed_user_query(user_query: str) -> List:
    """
    Turn the user's search query into a vector with Cohere's Embed API.

    The multilingual model is used when `USE_MULTILINGUAL_EMBEDDING` is truthy,
    otherwise the English model is used.

    Args:
        user_query (`str`):
            The text typed by the user, on which the Qdrant search will be based.

    Returns:
        query_embedding (`List`):
            The embedding of `user_query`, i.e. the first entry of the Embed response.
    """
    chosen_model = (
        MULTILINGUAL_EMBEDDING_MODEL
        if USE_MULTILINGUAL_EMBEDDING
        else ENGLISH_EMBEDDING_MODEL
    )
    # A single text is sent, so the single resulting embedding is at index 0.
    response = cohere_client.embed(texts=[user_query], model=chosen_model)
    return response.embeddings[0]
def search_docs_for_query(
    query_embedding: List,
    num_results: int,
    user_query: str,
    languages: List,
    match_text: List,
) -> List:
    """
    Perform search on the collection of documents for the given user query using Qdrant's search API.

    The optional filters narrow the search: when full-text matching is enabled a document
    must contain the query text, and when languages are given the document's language must
    be one of them. If both are supplied, both restrictions apply.

    Args:
        query_embedding (`List`):
            A vector representing the user query.
        num_results (`int`):
            The maximum number of search results to return.
        user_query (`str`):
            The user input; used as the text to match when `match_text` is truthy.
        languages (`List[str]`):
            Language names (keys of the internal name-to-code mapping, e.g. `"English"`)
            used to filter the search results. An empty list or `None` disables language filtering.
        match_text (`List`):
            A truthy value (non-empty list) enables the full-text-match filter on the `text` field.

    Returns:
        results (`List[ScoredPoint]`):
            A list of `ScoredPoint` objects returned via Qdrant's search API.

    Raises:
        KeyError: If an entry of `languages` is not a known language name.
    """
    language_mapping = {
        "Dutch": "nl",
        "English": "en",
        "French": "fr",
        "Hungarian": "hu",
        "Italian": "it",
        "Norwegian": "nb",
        "Polish": "pl",
    }

    # The text condition is a hard requirement, so it belongs in `must`. Putting it in
    # the same `should` (OR) list as the language conditions would let documents through
    # that match a language but not the text, widening the results instead of narrowing them.
    must_conditions = []
    if match_text:
        must_conditions.append(
            models.FieldCondition(
                key="text",
                match=models.MatchText(text=user_query),
            )
        )

    # A document has a single language, so the selected languages are alternatives (OR).
    should_conditions = [
        models.FieldCondition(
            key="language",
            match=models.MatchValue(value=language_mapping[lang]),
        )
        for lang in (languages or [])
    ]

    # Clauses of one filter are combined with AND; an unused clause is passed as `None`.
    query_filter = models.Filter(
        must=must_conditions or None,
        should=should_conditions or None,
    )

    # perform search and get results
    results = qdrant_client.search(
        collection_name=SEARCH_QDRANT_COLLECTION_NAME,
        query_filter=query_filter,
        search_params=models.SearchParams(hnsw_ef=128, exact=False),
        query_vector=query_embedding,
        limit=num_results,
    )
    return results
def translate_search_result(input_sentence: str, user_query: str) -> str:
    """
    Translate a given input sentence to the required target language. The required target language is `English` by default.
    The target language can be changed to match the language of the user's search query by setting
    `TRANSLATE_BASED_ON_USER_QUERY` to `True`.

    Args:
        input_sentence (`str`):
            The sentence which needs to be translated into the required target language.
        user_query (`str`):
            The user input based on which the target language for translation will be determined
            if `TRANSLATE_BASED_ON_USER_QUERY` is set to `True`.

    Returns:
        translation (`str`):
            The translation obtained using Cohere's Generate API. `input_sentence` is returned
            unchanged when it is blank or already in the target language.
    """
    # Blank input has nothing to translate; return it unchanged without calling the Cohere API.
    if not input_sentence or not input_sentence.strip():
        return input_sentence

    source_detection = cohere_client.detect_language(texts=[input_sentence])
    source_language = source_detection.results[0].language_name

    if TRANSLATE_BASED_ON_USER_QUERY:
        target_detection = cohere_client.detect_language(texts=[user_query])
        target_language = target_detection.results[0].language_name
    else:
        target_language = "English"

    if target_language == source_language:
        return input_sentence

    # Tokenize only once a translation is actually needed; the token count sizes the output budget.
    token_count = len(cohere_client.tokenize(text=input_sentence).tokens)

    # Built line by line so the prompt does not depend on source indentation and does not
    # start with a stray double quote.
    prompt = (
        "\n"
        f"Translate this sentence from {source_language} to {target_language}: '{input_sentence}'.\n"
        "Don't include the above prompt in the final translation. "
        "The final output should only include the translation of the input sentence.\n"
    )

    response = cohere_client.generate(
        model=TEXT_GENERATION_MODEL,
        prompt=prompt,
        # Allow the translation to be up to three times as long as the input, in tokens.
        max_tokens=token_count * 3,
        temperature=0.6,
        stop_sequences=["--"],
    )
    translation = response.generations[0].text
    return translation
def cross_lingual_document_search(
    user_input: str, num_results: int, languages, text_match
) -> List:
    """
    Wrapper function for performing search on the collection of documents for the given user query.
    Prepares the query embedding, retrieves search results, and pads the result list with empty
    strings when fewer than `num_results` documents are found.

    Args:
        user_input (`str`):
            The user input based on which search will be performed.
        num_results (`int`):
            The number of expected search results.
        languages (`List[str]`):
            The list of languages based on which search results must be filtered.
        text_match (`List`):
            A field based on which it is decided whether to perform full-text-match while performing search.

    Returns:
        final_results (`List[str]`):
            The `text` payload of each search hit, followed by as many empty strings as
            needed for the list to contain at least `num_results` entries.
    """
    # create an embedding for the input query
    query_embedding = embed_user_query(user_input)

    # retrieve search results
    hits = search_docs_for_query(
        query_embedding,
        num_results,
        user_input,
        languages,
        text_match,
    )
    final_results = [hit.payload["text"] for hit in hits]

    # pad with empty strings when fewer results than expected were found
    missing = num_results - len(final_results)
    if missing > 0:
        final_results.extend([""] * missing)
    return final_results