Spaces:
Build error
Build error
| from typing import List, Tuple, Dict, Any | |
| import time | |
| from tqdm.notebook import tqdm | |
| from rich import print | |
| from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params | |
| from llama_index.finetuning import EmbeddingQAFinetuneDataset | |
| from weaviate_interface import WeaviateClient | |
def _doc_id_ranks(response: List[Dict[str, Any]]) -> Dict[str, int]:
    '''
    Map each doc_id in a search response to the 1-based rank of its FIRST occurrence.

    Assumes each result is a mapping with a 'doc_id' key (as accessed below).
    setdefault keeps the earliest (best) rank when a doc_id repeats; a plain dict
    comprehension would keep the last rank and understate the reciprocal rank.
    '''
    ranks: Dict[str, int] = {}
    for rank, result in enumerate(response, 1):
        ranks.setdefault(result['doc_id'], rank)
    return ranks


def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
                         class_name: str,
                         retriever: WeaviateClient,
                         retrieve_limit: int = 5,
                         chunk_size: int = 256,
                         hnsw_config_keys: List[str] | None = None,
                         display_properties: List[str] | None = None,
                         dir_outpath: str = './eval_results',
                         include_miss_info: bool = False,
                         user_def_params: Dict[str, Any] | None = None
                         ) -> Dict[str, str|int|float] | Tuple[Dict[str, str|int|float], List[Dict[str, Any]]]:
    '''
    Evaluate keyword and vector retrieval of a retriever against a labeled QA dataset.

    For every query in ``dataset.queries`` a keyword search and a vector search are
    issued against ``class_name``. For each search type, a query counts as a hit when
    the first relevant doc_id listed for it in ``dataset.relevant_docs`` appears among
    the retrieved results. The reciprocal of that doc's 1-based rank is added to the
    search type's MRR total. Raw counts are then turned into final scores by
    ``calc_hit_rate_scores`` / ``calc_mrr_scores``. The results are saved as a text
    file in ``dir_outpath`` via ``record_results``.

    If retrieval or scoring raises for a query, the exception is printed and the query
    is skipped. The query is still counted in 'total_questions' but not in
    'total_misses'.

    Args:
    -----
    dataset: EmbeddingQAFinetuneDataset
        Dataset to be used for evaluation
    class_name: str
        Name of Class on Weaviate host to be used for retrieval
    retriever: WeaviateClient
        WeaviateClient object to be used for retrieval
    retrieve_limit: int=5
        Number of documents to retrieve from Weaviate host
    chunk_size: int=256
        Number of tokens used to chunk text. This value is purely for results
        recording purposes and does not affect results.
    hnsw_config_keys: List[str]=None
        HNSW config keys passed to ``add_params`` for results recording.
        Defaults to ['maxConnections', 'efConstruction', 'ef'].
    display_properties: List[str]=None
        List of properties to be returned from Weaviate host for display in response.
        Defaults to ['doc_id', 'guest', 'content']; must include 'doc_id' for scoring.
    dir_outpath: str='./eval_results'
        Directory path for saving results. Directory will be created if it does not
        already exist.
    include_miss_info: bool=False
        Option to include queries and their associated kw and vector response values
        for queries that are "total misses"
    user_def_params : dict=None
        Option for user to pass in a dictionary of user-defined parameters and their values.

    Returns:
    --------
    results_dict
        Run parameters ('n', 'Retriever', 'chunk_size', plus whatever ``add_params``
        adds) and the 'kw_hit_rate', 'kw_mrr', 'vector_hit_rate', 'vector_mrr',
        'total_misses' and 'total_questions' values.
    (results_dict, miss_info)
        Returned instead when include_miss_info is True. miss_info is a list of dicts
        with keys 'query', 'kw_response' and 'vector_response', one per query that
        neither search type hit.
    '''
    # Resolve defaults per call so the lists are never shared between invocations.
    if hnsw_config_keys is None:
        hnsw_config_keys = ['maxConnections', 'efConstruction', 'ef']
    if display_properties is None:
        display_properties = ['doc_id', 'guest', 'content']

    results_dict = {'n': retrieve_limit,
                    'Retriever': retriever.model_name_or_path,
                    'chunk_size': chunk_size,
                    'kw_hit_rate': 0,
                    'kw_mrr': 0,
                    'vector_hit_rate': 0,
                    'vector_mrr': 0,
                    'total_misses': 0,
                    'total_questions': 0
                    }
    # add hnsw configs and user defined params (if any)
    results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)

    start = time.perf_counter()
    miss_info = []
    for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
        results_dict['total_questions'] += 1
        # Best-effort: a single failing query (e.g. a host/network error) should not
        # abort a long evaluation run, so errors are printed and the query skipped.
        try:
            # make Keyword and Vector calls to Weaviate host
            kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
            # doc_id -> 1-based rank of its first occurrence in each response
            kw_doc_ids = _doc_id_ranks(kw_response)
            vector_doc_ids = _doc_id_ranks(vector_response)
            # only the first relevant doc listed for the query is scored
            doc_id = dataset.relevant_docs[query_id][0]

            hit = False
            if doc_id in kw_doc_ids:
                results_dict['kw_hit_rate'] += 1
                results_dict['kw_mrr'] += 1 / kw_doc_ids[doc_id]
                hit = True
            if doc_id in vector_doc_ids:
                results_dict['vector_hit_rate'] += 1
                results_dict['vector_mrr'] += 1 / vector_doc_ids[doc_id]
                hit = True
            # neither search type found the relevant doc: record it as a total miss
            if not hit:
                results_dict['total_misses'] += 1
                miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})
        except Exception as e:
            print(e)

    # use raw counts to calculate final scores
    calc_hit_rate_scores(results_dict)
    calc_mrr_scores(results_dict)
    elapsed = time.perf_counter() - start
    print(f'Total Processing Time: {round(elapsed/60, 2)} minutes')
    record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
    if include_miss_info:
        return results_dict, miss_info
    return results_dict