"""Streamlit-facing helpers for running GraphRAG global, local, and DRIFT searches."""

import asyncio
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd

import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter

# NOTE: the leading-underscore storage helpers are private to GraphRAG and
# may change between releases.
from graphrag.utils.storage import _create_storage, _load_table_from_storage


class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that mirrors success messages into a Streamlit placeholder."""

    def __init__(self, placeholder):
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        self.placeholder.success(message)
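
# Example wiring (hypothetical): the placeholder can be any object exposing
# the Streamlit container methods this module calls (.success, .error,
# .warning, .markdown), e.g.:
#
#   import streamlit as st
#   placeholder = st.empty()
#   reporter = StreamlitProgressReporter(placeholder)
#   reporter.success("Index loaded")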


def _resolve_parquet_files(
    root_dir: str,
    config: GraphRagConfig,
    parquet_list: list[str],
    optional_list: list[str],
) -> dict[str, pd.DataFrame]:
    """Read parquet files to a dataframe dict."""
    dataframe_dict = {}
    pipeline_config = create_pipeline_config(config)
    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)

    # Required tables: every file in parquet_list must exist in storage.
    for parquet_file in parquet_list:
        df_key = parquet_file.split(".")[0]  # key = file name sans .parquet
        df_value = asyncio.run(
            _load_table_from_storage(name=parquet_file, storage=storage_obj)
        )
        dataframe_dict[df_key] = df_value

    # Optional tables: absent files map to None instead of raising.
    for optional_file in optional_list:
        file_exists = asyncio.run(storage_obj.has(optional_file))
        df_key = optional_file.split(".")[0]
        if file_exists:
            df_value = asyncio.run(
                _load_table_from_storage(name=optional_file, storage=storage_obj)
            )
            dataframe_dict[df_key] = df_value
        else:
            dataframe_dict[df_key] = None

    return dataframe_dict
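
# Example (hypothetical project layout under ./ragtest):
#
#   config = load_config(Path("./ragtest").resolve(), None)
#   dfs = _resolve_parquet_files(
#       root_dir="./ragtest",
#       config=config,
#       parquet_list=["create_final_entities.parquet"],
#       optional_list=["create_final_covariates.parquet"],
#   )
#   entities = dfs["create_final_entities"]      # always a DataFrame
#   covariates = dfs["create_final_covariates"]  # DataFrame or None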


def run_global_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a global search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)

    # Let the caller point at a specific output directory, then resolve any
    # relative paths against the project root.
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
        ],
        optional_list=[],
    )

    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]

    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            try:
                async for stream_chunk in api.global_search_streaming(
                    config=config,
                    nodes=final_nodes,
                    entities=final_entities,
                    community_reports=final_community_reports,
                    community_level=community_level,
                    response_type=response_type,
                    query=query,
                ):
                    # The first streamed chunk carries the context data;
                    # every later chunk is a response token.
                    if get_context_data:
                        context_data = stream_chunk
                        get_context_data = False
                    else:
                        full_response += stream_chunk
                        progress_placeholder.markdown(full_response)
            except Exception as e:
                progress_placeholder.error(f"Error during streaming search: {e}")
                return None, None

            return full_response, context_data

        # The coroutine returns (None, None) on error, never None, so unpack
        # and test the response instead of the tuple itself.
        full_response, context_data = asyncio.run(run_streaming_search())
        if full_response is None:
            return "", {}  # Graceful fallback after a streaming error
        return full_response, context_data

    # Non-streaming logic
    try:
        response, context_data = asyncio.run(
            api.global_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                community_level=community_level,
                response_type=response_type,
                query=query,
            )
        )
        reporter.success(f"Global Search Response:\n{response}")
        return response, context_data
    except Exception as e:
        progress_placeholder.error(f"Error during global search: {e}")
        return "", {}  # Graceful fallback


def run_local_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a local search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)

    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
            "create_final_entities.parquet",
        ],
        optional_list=["create_final_covariates.parquet"],
    )

    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_covariates: Optional[pd.DataFrame] = dataframe_dict["create_final_covariates"]
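    # covariates stays None when the index run produced no
    # create_final_covariates.parquet; the covariate (claims) context is then
    # simply omitted from the search.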

    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            try:
                async for stream_chunk in api.local_search_streaming(
                    config=config,
                    nodes=final_nodes,
                    entities=final_entities,
                    community_reports=final_community_reports,
                    text_units=final_text_units,
                    relationships=final_relationships,
                    covariates=final_covariates,
                    community_level=community_level,
                    response_type=response_type,
                    query=query,
                ):
                    # First chunk is context data; later chunks are tokens.
                    if get_context_data:
                        context_data = stream_chunk
                        get_context_data = False
                    else:
                        full_response += stream_chunk
                        progress_placeholder.markdown(full_response)
            except Exception as e:
                progress_placeholder.error(f"Error during streaming search: {e}")
                return None, None
            return full_response, context_data

        full_response, context_data = asyncio.run(run_streaming_search())
        if full_response is None:
            return "", {}  # Graceful fallback after a streaming error
        return full_response, context_data

    # Non-streaming logic
    try:
        response, context_data = asyncio.run(
            api.local_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=community_level,
                response_type=response_type,
                query=query,
            )
        )
        reporter.success(f"Local Search Response:\n{response}")
        return response, context_data
    except Exception as e:
        progress_placeholder.error(f"Error during local search: {e}")
        return "", {}  # Graceful fallback


def run_drift_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a DRIFT search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)

    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)

    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
        ],
        optional_list=[],  # DRIFT search does not consume covariates
    )

    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]

    # Note: DRIFT search doesn't support streaming
    if streaming:
        progress_placeholder.warning(
            "Streaming is not supported for DRIFT search. Using standard search instead."
        )

    try:
        response, context_data = asyncio.run(
            api.drift_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                community_level=community_level,
                query=query,
            )
        )
        reporter.success(f"DRIFT Search Response:\n{response}")
        return response, context_data
    except Exception as e:
        progress_placeholder.error(f"Error during DRIFT search: {e}")
        return "", {}  # Graceful fallback
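

# Minimal smoke test (hypothetical project root; a printing stand-in replaces
# the Streamlit container so this module can be exercised from the CLI):
if __name__ == "__main__":

    class _PrintPlaceholder:
        """Duck-typed stand-in for a Streamlit st.empty() container."""

        def success(self, message):
            print(f"[success] {message}")

        def error(self, message):
            print(f"[error] {message}")

        def warning(self, message):
            print(f"[warning] {message}")

        def markdown(self, message):
            print(message)

    response, context = run_local_search(
        config_filepath=None,
        data_dir=None,
        root_dir="./ragtest",  # hypothetical GraphRAG project root
        community_level=2,
        response_type="Multiple Paragraphs",
        streaming=False,
        query="Who are the key entities and how are they related?",
        progress_placeholder=_PrintPlaceholder(),
    )
    print(f"Context tables: {list(context) if isinstance(context, dict) else context}")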