# NOTE(review): the following lines were web-scrape residue ("Spaces: Sleeping"
# status badges), not code; kept here as a comment so the file stays parseable.
import asyncio | |
from pathlib import Path | |
import pandas as pd | |
from typing import Tuple, Optional | |
from graphrag.config import GraphRagConfig, load_config, resolve_paths | |
from graphrag.index.create_pipeline_config import create_pipeline_config | |
from graphrag.logging import PrintProgressReporter | |
from graphrag.utils.storage import _create_storage, _load_table_from_storage | |
import graphrag.api as api | |
class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that mirrors success messages into a Streamlit placeholder.

    Extends the console-based ``PrintProgressReporter`` (constructed with an
    empty prefix) so that ``success`` renders in the provided Streamlit
    placeholder widget instead of only printing.
    """

    def __init__(self, placeholder):
        # Empty prefix: all console output from the base class is unprefixed.
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        """Render *message* as a Streamlit success banner."""
        self.placeholder.success(message)
def _resolve_parquet_files(
    root_dir: str,
    config: GraphRagConfig,
    parquet_list: list[str],
    optional_list: list[str],
) -> dict[str, pd.DataFrame]:
    """Read parquet files to a dataframe dict.

    Required files in *parquet_list* are always loaded; files in
    *optional_list* map to ``None`` when absent from storage. Keys are the
    filename up to the first dot (e.g. ``"create_final_nodes"``).

    Improvement over the original: all storage reads run inside a single
    ``asyncio.run`` call instead of spinning up a fresh event loop for
    every file.
    """
    pipeline_config = create_pipeline_config(config)
    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)

    async def _load_all() -> dict[str, pd.DataFrame]:
        dataframe_dict: dict[str, pd.DataFrame] = {}
        for parquet_file in parquet_list:
            df_key = parquet_file.split(".")[0]
            dataframe_dict[df_key] = await _load_table_from_storage(
                name=parquet_file, storage=storage_obj
            )
        for optional_file in optional_list:
            df_key = optional_file.split(".")[0]
            if await storage_obj.has(optional_file):
                dataframe_dict[df_key] = await _load_table_from_storage(
                    name=optional_file, storage=storage_obj
                )
            else:
                # Optional artifact missing (e.g. covariates were not
                # produced by the pipeline) — callers must handle None.
                dataframe_dict[df_key] = None
        return dataframe_dict

    return asyncio.run(_load_all())
def run_global_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a global search with a given query.

    Loads the pipeline's parquet outputs, then runs the GraphRAG global
    search either in streaming mode (rendering partial output into
    *progress_placeholder* as it arrives) or as a single batch call.

    Returns ``(response, context_data)``; on any failure the error is shown
    in the placeholder and ``("", {})`` is returned as a graceful fallback.
    """
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    # Allow the caller to point at an alternate output directory.
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
        ],
        optional_list=[],
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]

    if streaming:

        async def run_streaming_search() -> Tuple[str, dict]:
            full_response = ""
            context_data = None
            get_context_data = True
            try:
                async for stream_chunk in api.global_search_streaming(
                    config=config,
                    nodes=final_nodes,
                    entities=final_entities,
                    community_reports=final_community_reports,
                    community_level=community_level,
                    response_type=response_type,
                    query=query,
                ):
                    # The first chunk the API yields is the context payload;
                    # every subsequent chunk is response text.
                    if get_context_data:
                        context_data = stream_chunk
                        get_context_data = False
                    else:
                        full_response += stream_chunk
                        progress_placeholder.markdown(full_response)
            except Exception as e:
                progress_placeholder.error(f"Error during streaming search: {e}")
                # BUG FIX: this used to return (None, None), which the caller
                # tested with `result is None` — a check a tuple never passes,
                # so (None, None) leaked to the caller. Return the documented
                # graceful fallback directly instead.
                return "", {}
            return full_response, context_data

        return asyncio.run(run_streaming_search())

    # Non-streaming logic
    try:
        response, context_data = asyncio.run(
            api.global_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                community_level=community_level,
                response_type=response_type,
                query=query,
            )
        )
    except Exception as e:
        progress_placeholder.error(f"Error during global search: {e}")
        return "", {}  # Graceful fallback
    reporter.success(f"Global Search Response:\n{response}")
    return response, context_data
def run_local_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a local search with a given query.

    Loads the pipeline's parquet outputs (covariates are optional and may be
    ``None``), then runs the GraphRAG local search in streaming or batch
    mode.

    Returns ``(response, context_data)``; on any failure the error is shown
    in the placeholder and ``("", {})`` is returned — consistent with the
    graceful fallback in ``run_global_search`` (the original propagated
    exceptions here).
    """
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    # Allow the caller to point at an alternate output directory.
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
            "create_final_entities.parquet",
        ],
        optional_list=["create_final_covariates.parquet"],
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    # Covariates are optional pipeline output; None when not produced.
    final_covariates: Optional[pd.DataFrame] = dataframe_dict[
        "create_final_covariates"
    ]

    if streaming:

        async def run_streaming_search() -> Tuple[str, dict]:
            full_response = ""
            context_data = None
            get_context_data = True
            try:
                async for stream_chunk in api.local_search_streaming(
                    config=config,
                    nodes=final_nodes,
                    entities=final_entities,
                    community_reports=final_community_reports,
                    text_units=final_text_units,
                    relationships=final_relationships,
                    covariates=final_covariates,
                    community_level=community_level,
                    response_type=response_type,
                    query=query,
                ):
                    # First chunk is the context payload; the rest is text.
                    if get_context_data:
                        context_data = stream_chunk
                        get_context_data = False
                    else:
                        full_response += stream_chunk
                        progress_placeholder.markdown(full_response)
            except Exception as e:
                progress_placeholder.error(f"Error during streaming search: {e}")
                return "", {}
            return full_response, context_data

        return asyncio.run(run_streaming_search())

    try:
        response, context_data = asyncio.run(
            api.local_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=community_level,
                response_type=response_type,
                query=query,
            )
        )
    except Exception as e:
        progress_placeholder.error(f"Error during local search: {e}")
        return "", {}  # Graceful fallback
    reporter.success(f"Local Search Response:\n{response}")
    return response, context_data
def run_drift_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a DRIFT search with a given query.

    DRIFT search has no streaming API; when *streaming* is requested a
    warning is shown and the standard (batch) search runs instead.
    *response_type* is accepted for signature parity with the other search
    helpers but is not forwarded — ``api.drift_search`` takes no such
    argument here.

    Returns ``(response, context_data)``; on any failure the error is shown
    in the placeholder and ``("", {})`` is returned — consistent with the
    graceful fallback in ``run_global_search`` (the original propagated
    exceptions here).
    """
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    # Allow the caller to point at an alternate output directory.
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
        ],
        optional_list=[],  # Covariates are not supported by DRIFT search
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]

    if streaming:
        progress_placeholder.warning(
            "Streaming is not supported for DRIFT search. Using standard search instead."
        )

    try:
        response, context_data = asyncio.run(
            api.drift_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                community_level=community_level,
                query=query,
            )
        )
    except Exception as e:
        progress_placeholder.error(f"Error during DRIFT search: {e}")
        return "", {}  # Graceful fallback
    reporter.success(f"DRIFT Search Response:\n{response}")
    return response, context_data