# pwcGraphRAG / search_handlers.py
import asyncio
from pathlib import Path
import pandas as pd
from typing import Tuple, Optional
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage
import graphrag.api as api
class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that mirrors success messages to a Streamlit placeholder."""

    def __init__(self, placeholder):
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        self.placeholder.success(message)
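
# Example (sketch, not part of the original module): wiring the reporter to a
# Streamlit placeholder. Assumes a standard Streamlit app; `st.empty()` returns
# a placeholder exposing the .success()/.markdown()/.error() calls used below.
#
#   import streamlit as st
#   placeholder = st.empty()
#   reporter = StreamlitProgressReporter(placeholder)
#   reporter.success("Index artifacts loaded")
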
def _resolve_parquet_files(
root_dir: str,
config: GraphRagConfig,
parquet_list: list[str],
optional_list: list[str],
) -> dict[str, pd.DataFrame]:
"""Read parquet files to a dataframe dict."""
dataframe_dict = {}
pipeline_config = create_pipeline_config(config)
storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
for parquet_file in parquet_list:
df_key = parquet_file.split(".")[0]
df_value = asyncio.run(
_load_table_from_storage(name=parquet_file, storage=storage_obj)
)
dataframe_dict[df_key] = df_value
for optional_file in optional_list:
file_exists = asyncio.run(storage_obj.has(optional_file))
df_key = optional_file.split(".")[0]
if file_exists:
df_value = asyncio.run(
_load_table_from_storage(name=optional_file, storage=storage_obj)
)
dataframe_dict[df_key] = df_value
else:
dataframe_dict[df_key] = None
return dataframe_dict
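
# Example (sketch, illustrative values only): loading the artifacts a search
# needs. `my_config` and "./ragtest" are placeholders, not values from this repo.
#
#   dfs = _resolve_parquet_files(
#       root_dir="./ragtest",
#       config=my_config,
#       parquet_list=["create_final_nodes.parquet"],
#       optional_list=["create_final_covariates.parquet"],
#   )
#   nodes = dfs["create_final_nodes"]            # pd.DataFrame
#   covariates = dfs["create_final_covariates"]  # pd.DataFrame or None
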
def run_global_search(
config_filepath: Optional[str],
data_dir: Optional[str],
root_dir: str,
community_level: int,
response_type: str,
streaming: bool,
query: str,
progress_placeholder,
) -> Tuple[str, dict]:
"""Perform a global search with a given query."""
root = Path(root_dir).resolve()
config = load_config(root, config_filepath)
reporter = StreamlitProgressReporter(progress_placeholder)
    # An explicit data directory overrides the configured storage root.
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
parquet_list=[
"create_final_nodes.parquet",
"create_final_entities.parquet",
"create_final_community_reports.parquet",
],
optional_list=[],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]
if streaming:
async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
try:
async for stream_chunk in api.global_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
community_level=community_level,
response_type=response_type,
query=query,
):
                    # The first streamed chunk carries context data; later chunks are response text.
                    if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
progress_placeholder.markdown(full_response)
except Exception as e:
progress_placeholder.error(f"Error during streaming search: {e}")
return None, None
return full_response, context_data
        result = asyncio.run(run_streaming_search())
        # run_streaming_search returns (None, None) on error, never a bare None,
        # so check the first element of the tuple.
        if result[0] is None:
            return "", {}  # Graceful fallback
        return result
# Non-streaming logic
try:
response, context_data = asyncio.run(
api.global_search(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
community_level=community_level,
response_type=response_type,
query=query,
)
)
reporter.success(f"Global Search Response:\n{response}")
return response, context_data
except Exception as e:
progress_placeholder.error(f"Error during global search: {e}")
return "", {} # Graceful fallback
def run_local_search(
config_filepath: Optional[str],
data_dir: Optional[str],
root_dir: str,
community_level: int,
response_type: str,
streaming: bool,
query: str,
progress_placeholder,
) -> Tuple[str, dict]:
"""Perform a local search with a given query."""
root = Path(root_dir).resolve()
config = load_config(root, config_filepath)
reporter = StreamlitProgressReporter(progress_placeholder)
config.storage.base_dir = data_dir or config.storage.base_dir
resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
parquet_list=[
"create_final_nodes.parquet",
"create_final_community_reports.parquet",
"create_final_text_units.parquet",
"create_final_relationships.parquet",
"create_final_entities.parquet",
],
optional_list=["create_final_covariates.parquet"],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_covariates: Optional[pd.DataFrame] = dataframe_dict["create_final_covariates"]
if streaming:
async def run_streaming_search():
full_response = ""
context_data = None
get_context_data = True
async for stream_chunk in api.local_search_streaming(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
covariates=final_covariates,
community_level=community_level,
response_type=response_type,
query=query,
):
                # The first streamed chunk carries context data; later chunks are response text.
                if get_context_data:
context_data = stream_chunk
get_context_data = False
else:
full_response += stream_chunk
progress_placeholder.markdown(full_response)
return full_response, context_data
return asyncio.run(run_streaming_search())
response, context_data = asyncio.run(
api.local_search(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
covariates=final_covariates,
community_level=community_level,
response_type=response_type,
query=query,
)
)
reporter.success(f"Local Search Response:\n{response}")
return response, context_data
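
# Example (sketch, illustrative values only): run_local_search mirrors
# run_global_search's interface; covariates (claims) are included automatically
# when create_final_covariates.parquet exists in storage. Assumes `streamlit`
# imported as st.
#
#   response, context = run_local_search(
#       config_filepath=None,
#       data_dir=None,
#       root_dir="./ragtest",
#       community_level=2,
#       response_type="Multiple Paragraphs",
#       streaming=False,
#       query="Who is mentioned alongside the main character?",
#       progress_placeholder=st.empty(),
#   )
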
def run_drift_search(
config_filepath: Optional[str],
data_dir: Optional[str],
root_dir: str,
community_level: int,
    response_type: str,  # accepted for interface parity; DRIFT search does not use it
streaming: bool,
query: str,
progress_placeholder,
) -> Tuple[str, dict]:
"""Perform a DRIFT search with a given query."""
root = Path(root_dir).resolve()
config = load_config(root, config_filepath)
reporter = StreamlitProgressReporter(progress_placeholder)
config.storage.base_dir = data_dir or config.storage.base_dir
resolve_paths(config)
dataframe_dict = _resolve_parquet_files(
root_dir=root_dir,
config=config,
parquet_list=[
"create_final_nodes.parquet",
"create_final_entities.parquet",
"create_final_community_reports.parquet",
"create_final_text_units.parquet",
"create_final_relationships.parquet",
],
        optional_list=[],  # DRIFT search does not take covariates
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]
final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
# Note: DRIFT search doesn't support streaming
if streaming:
progress_placeholder.warning(
"Streaming is not supported for DRIFT search. Using standard search instead."
)
response, context_data = asyncio.run(
api.drift_search(
config=config,
nodes=final_nodes,
entities=final_entities,
community_reports=final_community_reports,
text_units=final_text_units,
relationships=final_relationships,
community_level=community_level,
query=query,
)
)
reporter.success(f"DRIFT Search Response:\n{response}")
return response, context_data
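
# Example (sketch, illustrative values only): dispatching on a user-selected
# search mode in a Streamlit app. The labels and widgets are not part of this
# module. Assumes `streamlit` imported as st.
#
#   handlers = {
#       "Global": run_global_search,
#       "Local": run_local_search,
#       "DRIFT": run_drift_search,
#   }
#   mode = st.selectbox("Search type", list(handlers))
#   response, context = handlers[mode](
#       config_filepath=None,
#       data_dir=None,
#       root_dir="./ragtest",
#       community_level=2,
#       response_type="Multiple Paragraphs",
#       streaming=False,
#       query=st.text_input("Query"),
#       progress_placeholder=st.empty(),
#   )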