from langgraph.types import Send from planning_ai.chains.fix_chain import fix_chain from planning_ai.chains.hallucination_chain import hallucination_chain from planning_ai.logging import logger from planning_ai.states import DocumentState, OverallState MAX_ATTEMPTS = 3 def check_hallucination(state: DocumentState): """Checks for hallucinations in the summary of a document. This function uses the `hallucination_chain` to evaluate the summary of a document. If the hallucination score is 1, it indicates no hallucination, and the summary is considered fixed. If the iteration count exceeds 5, the process is terminated. Args: state (DocumentState): The current state of the document, including its summary and iteration count. Returns: dict: A dictionary containing either a list of fixed summaries or hallucinations that need to be addressed. """ logger.info(f"Checking hallucinations for document {state['filename']}") if state["processed"] or (state["refinement_attempts"] >= MAX_ATTEMPTS): logger.error(f"Max attempts exceeded for document: {state['filename']}") return {"documents": [{**state, "failed": True, "processed": True}]} elif not state["is_hallucinated"]: logger.info(f"Finished processing document: {state['filename']}") return {"documents": [{**state, "processed": True}]} try: response = hallucination_chain.invoke( {"document": state["document"], "summary": state["summary"]} ) is_hallucinated = response.score == 0 refinement_attempts = state["refinement_attempts"] + 1 except Exception as e: logger.error(f"Failed to decode JSON {state['filename']}: {e}") return { "documents": [ { **state, "summary": "", "refinement_attempts": 0, "is_hallucinated": True, "failed": True, "processed": True, } ] } out = { **state, "hallucination": response, "refinement_attempts": refinement_attempts, "is_hallucinated": is_hallucinated, } logger.info(f"Hallucination for {state['filename']}: {is_hallucinated}") return ( {"documents": [{**out, "processed": False}]} if is_hallucinated else {"documents": [{**out, "processed": True}]} ) def fix_hallucination(state: DocumentState): """Attempts to fix hallucinations in a document's summary. This function uses the `fix_chain` to correct hallucinations identified in a summary. The corrected summary is then updated in the document state. Args: state (DocumentState): The current state of the document, including its summary and hallucination details. Returns: dict: A dictionary containing the updated summaries after attempting to fix hallucinations. """ logger.warning(f"Fixing hallucinations for document {state['filename']}") try: response = fix_chain.invoke( { "context": state["document"], "summary": state["summary"], "explanation": state["hallucination"].explanation, } ) except Exception as e: logger.error(f"Failed to decode JSON {state['filename']}: {e}.") return { "documents": [ { **state, "summary": "", "refinement_attempts": 0, "is_hallucinated": True, "failed": True, "processed": True, } ] } return {"documents": [{**state, "summary": response}]} def map_check(state: OverallState): """Maps the check_hallucination function to each document in the overall state. Args: state (OverallState): The overall state containing multiple documents. Returns: list: A list of Send objects, each representing a request to check for hallucinations in a document. """ return [Send("check_hallucination", doc) for doc in state["documents"]] def map_fix(state: OverallState): """Maps the fix_hallucination function to each hallucinated document that is not processed. Args: state (OverallState): The overall state containing multiple documents. Returns: list: A list of Send objects, each representing a request to fix hallucinations in a document that is hallucinated and not yet processed. """ return [ Send("fix_hallucination", doc) for doc in state["documents"] if doc["is_hallucinated"] and not doc["processed"] ]