from prompts import get_classification_prompt, get_query_generation_prompt
from utils_code import initialize_openai_creds, create_llm
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import (
    LLMSynonymRetriever,
    VectorContextRetriever,
    PGRetriever,
)
from transformers import pipeline
from typing import List, Optional, Tuple
import asyncio
import os


class PARetriever(BaseRetriever):
    """Custom retriever that performs query rewriting, vector search, and BM25 search, without knowledge graph search."""

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional Hugging Face classifier model name
        device: str = "cpu",  # Device for the classifier pipeline ("cpu" for the Hugging Face demo)
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,  # Verbose flag
        fixed_params: Optional[dict] = None,  # Fixed retrieval parameters that bypass classification
        categories_list: Optional[List[str]] = None,  # List of categories for query classification
        param_mappings: Optional[dict] = None,  # Custom parameter mappings keyed by classifier label
    ) -> None:
"""Initialize PARetriever parameters.""" | |
self._vector_retriever = vector_retriever | |
self._bm25_retriever = bm25_retriever | |
self._llm = llm | |
self._rewriter = rewriter | |
self._mode = mode | |
self._reranker_model_name = reranker_model_name | |
self._reranker = None # Initialize reranker as None | |
self.verbose = verbose | |
self.fixed_params = fixed_params | |
self.categories_list = categories_list | |
self.param_mappings = param_mappings or { | |
"label_0": {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1}, | |
"label_1": {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2}, | |
"label_2": {"top_k": 10, "max_keywords_per_query": 5, "max_knowledge_sequence": 3} | |
} | |
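        # Example (illustrative): a classifier that emits "label_2" for complex
        # queries triggers broader retrieval (top_k=10, 5 keywords, KG paths of
        # length 3) than one that emits "label_0" (top_k=5, 3 keywords, length 1).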
        # Initialize the classifier if provided
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode; expected 'AND' or 'OR'.")

    def classify_query_and_get_params(self, query: str) -> Tuple[Optional[str], dict]:
        """Classify the query and determine adaptive parameters, or use fixed parameters."""
        if self.fixed_params:
            # Use the fixed parameters if provided
            params = self.fixed_params
            classification_result = "Fixed"
            if self.verbose:
                print(f"Using fixed parameters: {params}")
        else:
            params = {
                "top_k": 5,  # Default top-k
                "max_keywords_per_query": 4,  # Default max keywords
                "max_knowledge_sequence": 2,  # Default max knowledge sequence
            }
            classification_result = None
            if self.classifier:
                classification = self.classifier(query)[0]
                label = classification["label"]  # Get the classification label directly
                classification_result = label  # Store the classification result
                if self.verbose:
                    print(f"Query Classification: {classification['label']} with score {classification['score']}")
                # Use the custom mappings, or fall back to the defaults above
                if label in self.param_mappings:
                    params = self.param_mappings[label]
                elif self.verbose:
                    print(f"Warning: No mapping found for label {label}, using default parameters.")
        self._classification_result = classification_result
        return classification_result, params
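
    # For example: with fixed_params={"top_k": 7, ...} this returns
    # ("Fixed", fixed_params); with a classifier that labels the query
    # "label_1" it returns ("label_1", self.param_mappings["label_1"]).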

    def classify_query(self, query_str: str) -> Optional[str]:
        """Classify the query into one of the predefined categories using the LLM, or skip if no categories are provided."""
        if not self.categories_list:
            if self.verbose:
                print("No categories provided, skipping query classification.")
            return None
        # Generate the classification prompt using the external helper
        classification_prompt = get_classification_prompt(self.categories_list) + f" Query: '{query_str}'"
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()
        # Return the category only if it's in the categories list
        return category if category in self.categories_list else None

    def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
        """Generate query variations using the LLM, taking the category into account if applicable."""
        # Build the query-generation prompt using the external helper
        query_gen_prompt = get_query_generation_prompt(query_str, num_queries)
        response = self._llm.complete(query_gen_prompt)
        queries = [query.strip() for query in response.text.split("\n") if query.strip()]
        if category:
            # Append the category itself as an extra search query
            queries.append(category)
        return queries
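
    # For example, with num_queries=3 and category "§ 5-301 Bias Audit", the
    # result is the three LLM rewrites plus the category string as a fourth query.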

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently."""
        tasks = []
        for query in queries:
            for retriever in retrievers:
                tasks.append(retriever.aretrieve(query))
        task_results = await asyncio.gather(*tasks)
        # One task per (query, retriever) pair, in creation order; key each
        # result list by (query, task index) so no result is dropped when
        # multiple retrievers are used.
        results_dict = {}
        for i, query_result in enumerate(task_results):
            query = queries[i // len(retrievers)]
            results_dict[(query, i)] = query_result
        return results_dict
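
    # e.g. queries=["q1", "q2"] with two retrievers yields four entries keyed
    # ("q1", 0), ("q1", 1), ("q2", 2), ("q2", 3), each a List[NodeWithScore].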

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse results from the vector and BM25 retrievers via reciprocal rank fusion."""
        k = 60.0  # `k` controls the impact of outlier rankings
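        # Reciprocal Rank Fusion: a node's fused score is the sum of
        # 1 / (rank + k) over every result list it appears in, so nodes ranked
        # highly by several retrievers (or several query rewrites) rise to the
        # top, while a larger k damps the influence of any single high rank.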
        fused_scores = {}
        text_to_node = {}
        for nodes_with_scores in results_dict.values():
            for rank, node_with_score in enumerate(
                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            ):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                if text not in fused_scores:
                    fused_scores[text] = 0.0
                fused_scores[text] += 1.0 / (rank + k)
        reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
        reranked_nodes: List[NodeWithScore] = []
        for text, score in reranked_results.items():
            if text in text_to_node:
                node = text_to_node[text]
                node.score = score
                reranked_nodes.append(node)
            elif self.verbose:
                print(f"Warning: Text not found in `text_to_node`: {text}")
        return reranked_nodes[:similarity_top_k]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the given query."""
        category = None
        if self._rewriter:
            category = self.classify_query(query_bundle.query_str)
            if self.verbose and category:
                print(f"Classified Category: {category}")
        classification_result, params = self.classify_query_and_get_params(query_bundle.query_str)
        self._classification_result = classification_result
        top_k = params["top_k"]
        if self._reranker_model_name:
            self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
            if self.verbose:
                print(f"Initialized reranker with top_n: {top_k}")
        # Scale the number of query rewrites with the retrieval breadth
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")
        if self._rewriter:
            queries = self.generate_queries(query_bundle.query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_bundle.query_str]
        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)
        if not active_retrievers:
            raise ValueError("No active retriever provided!")
        results = asyncio.run(self.run_queries(queries, active_retrievers))
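        # Note: asyncio.run() starts a fresh event loop; if _retrieve is called
        # from code already running inside an event loop (e.g. a notebook),
        # something like nest_asyncio would be needed instead.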
        if self.verbose:
            print(f"Fusion Results: {results}")
        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)
        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]
        if self._rewriter:
            # Deduplicate nodes retrieved by multiple query rewrites
            unique_nodes = {}
            for node in final_results:
                content = node.node.get_content()
                if content not in unique_nodes:
                    unique_nodes[content] = node
            final_results = list(unique_nodes.values())
        if self.verbose:
            print(f"Final Results: {final_results}")
        return final_results

    def get_classification_result(self) -> Optional[str]:
        """Return the most recent classification result, if any."""
        return getattr(self, "_classification_result", None)
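

# Minimal usage sketch (illustrative; assumes an `llm` and a `vector_index`
# built as in the demo script below):
#
#   pa = PARetriever(
#       llm=llm,
#       vector_retriever=VectorIndexRetriever(index=vector_index, similarity_top_k=10),
#       rewriter=True,
#   )
#   nodes = pa.retrieve("What is a bias audit?")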


class HyPARetriever(PARetriever):
    """Custom retriever that extends PARetriever with knowledge graph (KG) search."""

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Knowledge graph index
        property_index: bool = True,  # Whether to use the property graph for retrieval
        pg_filters=None,
        **kwargs,  # Any additional arguments are passed through to PARetriever
    ):
        # Initialize PARetriever to reuse all of its functionality
        super().__init__(
            llm=llm,
            vector_retriever=vector_retriever,
            bm25_retriever=bm25_retriever,
            **kwargs,
        )
        # Knowledge graph (KG) specific components
        self._pg_filters = pg_filters
        self._kg_index = kg_index
        self.property_index = property_index

    def _initialize_kg_retriever(self, params):
        """Initialize the KG retriever based on the retrieval mode."""
        graph_index = self._kg_index
        filters = self._pg_filters
        if self._kg_index and not self.property_index:
            # Triplet-style KG index: use KGTableRetriever
            return KGTableRetriever(
                index=self._kg_index,
                retriever_mode="hybrid",
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"],
            )
        elif self._kg_index and self.property_index:
            # Property graph index: combine vector-context and LLM-synonym
            # sub-retrievers (used for the demo)
            vector_retriever = VectorContextRetriever(
                graph_store=graph_index.property_graph_store,
                similarity_top_k=params["max_keywords_per_query"],
                path_depth=params["max_knowledge_sequence"],
                include_text=True,
                filters=filters,
            )
            synonym_retriever = LLMSynonymRetriever(
                graph_store=graph_index.property_graph_store,
                llm=self._llm,
                include_text=True,
                filters=filters,
            )
            return graph_index.as_retriever(sub_retrievers=[vector_retriever, synonym_retriever])
            # return graph_index.as_retriever(similarity_top_k=params["top_k"])
        return None

    def _combine_with_kg_results(self, vector_bm25_results, kg_results):
        """Combine KG results with the vector and BM25 results."""
        vector_ids = {n.node.id_ for n in vector_bm25_results}
        kg_ids = {n.node.id_ for n in kg_results}
        combined_results = {n.node.id_: n for n in vector_bm25_results}
        combined_results.update({n.node.id_: n for n in kg_results})
        if self._mode == "AND":
            result_ids = vector_ids.intersection(kg_ids)
        else:
            result_ids = vector_ids.union(kg_ids)
        return [combined_results[rid] for rid in result_ids]
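
    # e.g. mode="OR" (the default) keeps every node returned by either source,
    # while mode="AND" keeps only nodes that both vector/BM25 and KG retrieval
    # agreed on.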

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes with KG integration."""
        # Call PARetriever's _retrieve to get the vector and BM25 results
        final_results = super()._retrieve(query_bundle)
        # If we have a KG index, initialize the KG retriever with the
        # classifier-chosen parameters
        if self._kg_index:
            _, params = self.classify_query_and_get_params(query_bundle.query_str)
            kg_retriever = self._initialize_kg_retriever(params)
            if kg_retriever:
                kg_nodes = kg_retriever.retrieve(query_bundle)
                # Only combine KG and vector/BM25 results if property_index is True
                if self.property_index:
                    final_results = self._combine_with_kg_results(final_results, kg_nodes)
        return final_results
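

# Usage mirrors PARetriever plus the KG wiring (illustrative sketch):
#
#   hypa = HyPARetriever(
#       llm=llm,
#       kg_index=graph_index,   # KnowledgeGraphIndex or property graph index
#       property_index=False,   # False -> KGTableRetriever path
#       vector_retriever=vector_retriever,
#       bm25_retriever=bm25_retriever,
#   )
#   nodes = hypa.retrieve("What is a bias audit?")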


# ---------------------------------------------------------------------------
# Demo script (separate file; imports the module above as `retrievers`)
# ---------------------------------------------------------------------------
import os
from dotenv import load_dotenv
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.core import VectorStoreIndex, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import KGTableRetriever, VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.readers.file import PyMuPDFReader
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.core import KnowledgeGraphIndex
from utils_code import initialize_openai_creds, create_llm
from retrievers import PARetriever, HyPARetriever


def load_documents():
    """Load and return documents from the specified file paths."""
    loader = PyMuPDFReader()
    documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
    documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
    return documents1 + documents2


def create_indices(documents, llm, embed_model):
    """Create and return a VectorStoreIndex (and optionally a KnowledgeGraphIndex) from documents."""
    splitter = SentenceSplitter(chunk_size=512)
    vector_index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embed_model,
        transformations=[splitter],
    )
    # Knowledge graph construction is disabled for the demo; uncomment to build it.
    # graph_index = KnowledgeGraphIndex.from_documents(
    #     documents,
    #     max_triplets_per_chunk=5,
    #     llm=llm,
    #     embed_model=embed_model,
    #     include_embeddings=True,
    #     transformations=[splitter],
    # )
    return vector_index, None  # return (vector_index, graph_index) when the KG build is enabled


def create_retrievers(vector_index, graph_index, llm, category_list):
    """Create and return the PA and HyPA retrievers."""
    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)
    PA_retriever = PARetriever(
        llm=llm,
        categories_list=category_list,
        rewriter=True,
        vector_retriever=vector_retriever,
        bm25_retriever=bm25_retriever,
        classifier_model="rk68/distilbert-q-classifier-3",
        verbose=False,
    )
    # With graph_index=None (KG build disabled above), HyPARetriever falls back
    # to vector + BM25 retrieval only.
    HyPA_retriever = HyPARetriever(
        llm=llm,
        categories_list=category_list,
        rewriter=True,
        kg_index=graph_index,
        vector_retriever=vector_retriever,
        bm25_retriever=bm25_retriever,
        classifier_model="rk68/distilbert-q-classifier-3",
        verbose=False,
        property_index=False,
    )
    return PA_retriever, HyPA_retriever


def create_chat_engine(retriever, memory):
    """Create and return a ContextChatEngine using the provided retriever and memory."""
    return ContextChatEngine.from_defaults(
        retriever=retriever,
        memory=memory,
        verbose=False,
    )


def main():
    # Load environment variables and initialize the LLM credentials
    load_dotenv()
    gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
    llm_gpt35 = create_llm(gpt35_creds=gpt35_creds, gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)
    # Set global settings for the embedding model and LLM
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
    Settings.embed_model = embed_model
    Settings.llm = llm_gpt35
    category_list = [
        '§ 5-301 Bias Audit',
        '§ 5-302 Data Requirements',
        '§ 5-303 Published Results',
        '§ 5-304 Notice to Candidates and Employees',
    ]
    # Load documents and create the indices
    documents = load_documents()
    vector_index, graph_index = create_indices(documents, llm_gpt35, embed_model)
    # Create the retrievers
    PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list)
    # Give each chat engine its own memory buffer so their histories don't mix
    PA_memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    HyPA_memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    # Create the chat engines
    PA_chat_engine = create_chat_engine(PA_retriever, PA_memory)
    HyPA_chat_engine = create_chat_engine(HyPA_retriever, HyPA_memory)
    # Sample question and responses
    question = "What is a bias audit?"
    PA_response = PA_chat_engine.chat(question)
    HyPA_response = HyPA_chat_engine.chat(question)
    # Print the responses in a readable format
    print("\n" + "=" * 50)
    print(f"Question: {question}")
    print("=" * 50)
    print("\n------- PA Retriever Response -------")
    print(PA_response)
    print("\n------- HyPA Retriever Response -------")
    print(HyPA_response)
    print("=" * 50 + "\n")


if __name__ == '__main__':
    main()