Spaces:
Sleeping
Sleeping
fix: simplify structured outputs which improves accuracy
Browse files
- planning_ai/chains/map_chain.py +8 -4
- planning_ai/chains/policy_chain.py +12 -2
- planning_ai/chains/prompts/map.txt +1 -1
- planning_ai/chains/prompts/policy.txt +1 -3
- planning_ai/graph.py +10 -26
- planning_ai/nodes/hallucination_node.py +33 -70
- planning_ai/nodes/map_node.py +29 -29
- planning_ai/nodes/reduce_node.py +72 -97
- planning_ai/states.py +8 -4
planning_ai/chains/map_chain.py
CHANGED
@@ -2,7 +2,7 @@ from enum import Enum, auto
|
|
2 |
from typing import Optional, Set, Type
|
3 |
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
-
from pydantic import BaseModel, create_model
|
6 |
|
7 |
from planning_ai.common.utils import Paths
|
8 |
from planning_ai.llms.llm import LLM
|
@@ -39,9 +39,11 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
|
|
39 |
Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
|
40 |
"""
|
41 |
|
|
|
42 |
DynamicPolicy = create_model(
|
43 |
"DynamicPolicy",
|
44 |
-
policy=(policy_enum, ...),
|
|
|
45 |
note=(str, ...),
|
46 |
__config__={"extra": "forbid"},
|
47 |
)
|
@@ -49,7 +51,9 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
|
|
49 |
return create_model(
|
50 |
"DynamicBriefSummary",
|
51 |
summary=(str, ...),
|
52 |
-
policies=(Optional[list[DynamicPolicy]], ...),
|
|
|
|
|
53 |
__module__=__name__,
|
54 |
__config__={"extra": "forbid"},
|
55 |
)
|
@@ -82,7 +86,7 @@ if __name__ == "__main__":
|
|
82 |
the major settlement of Cambourne has been created - now over the projected 3,000 homes and
|
83 |
Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
|
84 |
"""
|
85 |
-
test_themes = {"
|
86 |
|
87 |
dynamic_map_chain = create_dynamic_map_chain(test_themes, prompt=map_template)
|
88 |
result = dynamic_map_chain.invoke({"context": test_document, "themes": test_themes})
|
|
|
2 |
from typing import Optional, Set, Type
|
3 |
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
+
from pydantic import BaseModel, Field, create_model
|
6 |
|
7 |
from planning_ai.common.utils import Paths
|
8 |
from planning_ai.llms.llm import LLM
|
|
|
39 |
Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
|
40 |
"""
|
41 |
|
42 |
+
# NOTE: For some reason GPT4o goes mental if we use too much structure
|
43 |
DynamicPolicy = create_model(
|
44 |
"DynamicPolicy",
|
45 |
+
# policy=(policy_enum, ...),
|
46 |
+
policy=(str, ...),
|
47 |
note=(str, ...),
|
48 |
__config__={"extra": "forbid"},
|
49 |
)
|
|
|
51 |
return create_model(
|
52 |
"DynamicBriefSummary",
|
53 |
summary=(str, ...),
|
54 |
+
# policies=(Optional[list[DynamicPolicy]], ...),
|
55 |
+
policies=(Optional[list[str]], ...),
|
56 |
+
notes=(Optional[list[str]], ...),
|
57 |
__module__=__name__,
|
58 |
__config__={"extra": "forbid"},
|
59 |
)
|
|
|
86 |
the major settlement of Cambourne has been created - now over the projected 3,000 homes and
|
87 |
Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
|
88 |
"""
|
89 |
+
test_themes = {"Homes", "Great Places"}
|
90 |
|
91 |
dynamic_map_chain = create_dynamic_map_chain(test_themes, prompt=map_template)
|
92 |
result = dynamic_map_chain.invoke({"context": test_document, "themes": test_themes})
|
planning_ai/chains/policy_chain.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from langchain_core.output_parsers import StrOutputParser
|
2 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
3 |
|
4 |
from planning_ai.common.utils import Paths
|
5 |
from planning_ai.llms.llm import LLM
|
@@ -8,8 +8,18 @@ with open(Paths.PROMPTS / "policy.txt", "r") as f:
|
|
8 |
policy_template = f.read()
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
policy_prompt = ChatPromptTemplate([("system", policy_template)])
|
12 |
-
policy_chain = policy_prompt |
|
13 |
|
14 |
|
15 |
if __name__ == "__main__":
|
|
|
|
|
1 |
from langchain_core.prompts import ChatPromptTemplate
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
|
4 |
from planning_ai.common.utils import Paths
|
5 |
from planning_ai.llms.llm import LLM
|
|
|
8 |
policy_template = f.read()
|
9 |
|
10 |
|
11 |
+
class PolicyMerger(BaseModel):
|
12 |
+
"""Return condensed details and their associated doc_ids"""
|
13 |
+
|
14 |
+
details: list[str]
|
15 |
+
doc_id: list[list[int]]
|
16 |
+
|
17 |
+
|
18 |
+
SLLM = LLM.with_structured_output(PolicyMerger, strict=True)
|
19 |
+
|
20 |
+
|
21 |
policy_prompt = ChatPromptTemplate([("system", policy_template)])
|
22 |
+
policy_chain = policy_prompt | SLLM
|
23 |
|
24 |
|
25 |
if __name__ == "__main__":
|
planning_ai/chains/prompts/map.txt
CHANGED
@@ -2,7 +2,7 @@ Please analyze the response to the planning application provided below. Your tas
|
|
2 |
|
3 |
1. **Summary**: Provide a concise summary of the response, highlighting the main points and any significant details.
|
4 |
|
5 |
-
2. **Policy Identification**:
|
6 |
|
7 |
3. **Policy Notes**: For each identified policy, extract and list at least one verbatim section from the response that directly relates to it. Ensure the **full** context is retained so the section can be understood independently. Policy notes may overlap. If a note does not have a clear link to the policy, omit both the policy and the note.
|
8 |
|
|
|
2 |
|
3 |
1. **Summary**: Provide a concise summary of the response, highlighting the main points and any significant details.
|
4 |
|
5 |
+
2. **Policy Identification**: Carefully review the response and identify all relevant policies from the provided list. Focus on capturing policies that are explicitly mentioned or strongly implied. Avoid inferring new policies beyond those stated. Select **all** relevant policies, even if they seem minor.
|
6 |
|
7 |
3. **Policy Notes**: For each identified policy, extract and list at least one verbatim section from the response that directly relates to it. Ensure the **full** context is retained so the section can be understood independently. Policy notes may overlap. If a note does not have a clear link to the policy, omit both the policy and the note.
|
8 |
|
planning_ai/chains/prompts/policy.txt
CHANGED
@@ -6,9 +6,7 @@ You are tasked with refining a list of details related to a specific planning po
|
|
6 |
4. Exclude any details that do not pertain to the policy.
|
7 |
5. Disregard generic details that merely restate the policy.
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
For example '- Impact of increased housing density on the character of Cambridge [1][2][11].'.
|
12 |
|
13 |
Theme: {theme}
|
14 |
|
|
|
6 |
4. Exclude any details that do not pertain to the policy.
|
7 |
5. Disregard generic details that merely restate the policy.
|
8 |
|
9 |
+
Ensure that all returned details use proper sentence structure.
|
|
|
|
|
10 |
|
11 |
Theme: {theme}
|
12 |
|
planning_ai/graph.py
CHANGED
@@ -4,13 +4,13 @@ from langgraph.graph import END, StateGraph
|
|
4 |
from planning_ai.nodes.hallucination_node import (
|
5 |
check_hallucination,
|
6 |
fix_hallucination,
|
7 |
-
|
8 |
-
|
9 |
)
|
10 |
from planning_ai.nodes.map_node import (
|
11 |
add_entities,
|
12 |
generate_summary,
|
13 |
-
|
14 |
retrieve_themes,
|
15 |
)
|
16 |
from planning_ai.nodes.reduce_node import generate_final_report
|
@@ -24,31 +24,15 @@ def create_graph():
|
|
24 |
graph.add_node("generate_summary", generate_summary)
|
25 |
graph.add_node("check_hallucination", check_hallucination)
|
26 |
graph.add_node("fix_hallucination", fix_hallucination)
|
27 |
-
graph.add_node("
|
28 |
|
29 |
graph.add_edge(START, "add_entities")
|
30 |
-
graph.add_conditional_edges(
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
graph.
|
36 |
-
"generate_summary",
|
37 |
-
map_hallucinations,
|
38 |
-
["check_hallucination"],
|
39 |
-
)
|
40 |
-
graph.add_conditional_edges(
|
41 |
-
"check_hallucination",
|
42 |
-
map_fix_hallucinations,
|
43 |
-
["fix_hallucination"],
|
44 |
-
)
|
45 |
-
graph.add_conditional_edges(
|
46 |
-
"fix_hallucination",
|
47 |
-
map_hallucinations,
|
48 |
-
["check_hallucination"],
|
49 |
-
)
|
50 |
-
|
51 |
-
graph.add_edge("check_hallucination", "generate_final_summary")
|
52 |
graph.add_edge("generate_final_summary", END)
|
53 |
|
54 |
return graph.compile()
|
|
|
4 |
from planning_ai.nodes.hallucination_node import (
|
5 |
check_hallucination,
|
6 |
fix_hallucination,
|
7 |
+
map_check,
|
8 |
+
map_fix,
|
9 |
)
|
10 |
from planning_ai.nodes.map_node import (
|
11 |
add_entities,
|
12 |
generate_summary,
|
13 |
+
map_documents,
|
14 |
retrieve_themes,
|
15 |
)
|
16 |
from planning_ai.nodes.reduce_node import generate_final_report
|
|
|
24 |
graph.add_node("generate_summary", generate_summary)
|
25 |
graph.add_node("check_hallucination", check_hallucination)
|
26 |
graph.add_node("fix_hallucination", fix_hallucination)
|
27 |
+
graph.add_node("generate_final_report", generate_final_report)
|
28 |
|
29 |
graph.add_edge(START, "add_entities")
|
30 |
+
graph.add_conditional_edges("add_entities", map_documents, ["generate_summary"])
|
31 |
+
graph.add_conditional_edges("generate_summary", map_check, ["check_hallucination"])
|
32 |
+
graph.add_conditional_edges("check_hallucination", map_fix, ["fix_hallucination"])
|
33 |
+
graph.add_conditional_edges("fix_hallucination", map_check, ["check_hallucination"])
|
34 |
+
|
35 |
+
graph.add_edge("check_hallucination", "generate_final_report")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
graph.add_edge("generate_final_summary", END)
|
37 |
|
38 |
return graph.compile()
|
planning_ai/nodes/hallucination_node.py
CHANGED
@@ -2,14 +2,12 @@ import json
|
|
2 |
import logging
|
3 |
|
4 |
from langchain_core.exceptions import OutputParserException
|
|
|
5 |
from langgraph.types import Send
|
6 |
from pydantic import BaseModel
|
7 |
|
8 |
from planning_ai.chains.fix_chain import fix_template
|
9 |
-
from planning_ai.chains.hallucination_chain import
|
10 |
-
HallucinationChecker,
|
11 |
-
hallucination_chain,
|
12 |
-
)
|
13 |
from planning_ai.chains.map_chain import create_dynamic_map_chain
|
14 |
from planning_ai.states import DocumentState, OverallState
|
15 |
|
@@ -19,12 +17,7 @@ logging.basicConfig(
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
|
22 |
-
|
23 |
-
summary: str
|
24 |
-
policies: None
|
25 |
-
|
26 |
-
|
27 |
-
ITERATIONS = 2
|
28 |
|
29 |
|
30 |
def check_hallucination(state: DocumentState):
|
@@ -43,47 +36,35 @@ def check_hallucination(state: DocumentState):
|
|
43 |
that need to be addressed.
|
44 |
"""
|
45 |
logger.warning(f"Checking hallucinations for document {state['filename']}")
|
46 |
-
|
47 |
-
if state["
|
48 |
-
state["
|
49 |
-
state
|
50 |
-
|
|
|
|
|
51 |
|
52 |
try:
|
53 |
response = hallucination_chain.invoke(
|
54 |
{"document": state["document"], "summary": state["summary"].summary}
|
55 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
except (OutputParserException, json.JSONDecodeError) as e:
|
57 |
logger.error(f"Failed to decode JSON: {e}.")
|
58 |
-
state
|
59 |
-
state["hallucination"] = HallucinationChecker(score=1, explanation="INVALID")
|
60 |
-
state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
|
61 |
-
return {"documents": [state]}
|
62 |
-
if response.score == 1:
|
63 |
-
return {"documents": [{**state, "hallucination": response}]}
|
64 |
-
|
65 |
-
return {
|
66 |
-
"documents": [
|
67 |
-
{**state, "hallucination": response, "iteration": state["iteration"] + 1}
|
68 |
-
]
|
69 |
-
}
|
70 |
-
|
71 |
-
|
72 |
-
def map_hallucinations(state: OverallState):
|
73 |
-
"""Maps summaries to the `check_hallucination` function.
|
74 |
-
|
75 |
-
This function prepares a list of summaries to be checked for hallucinations by
|
76 |
-
sending them to the `check_hallucination` function. Allows summaries to be checked
|
77 |
-
in parallel.
|
78 |
-
|
79 |
-
Args:
|
80 |
-
state (OverallState): The overall state containing all summaries.
|
81 |
-
|
82 |
-
Returns:
|
83 |
-
list: A list of Send objects directing each summary to the check_hallucination
|
84 |
-
function.
|
85 |
-
"""
|
86 |
-
return [Send("check_hallucination", document) for document in state["documents"]]
|
87 |
|
88 |
|
89 |
def fix_hallucination(state: DocumentState):
|
@@ -112,35 +93,17 @@ def fix_hallucination(state: DocumentState):
|
|
112 |
)
|
113 |
except (OutputParserException, json.JSONDecodeError) as e:
|
114 |
logger.error(f"Failed to decode JSON: {e}.")
|
115 |
-
state
|
116 |
-
|
117 |
-
state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
|
118 |
-
return {"documents": [state]}
|
119 |
-
state["summary"] = response # type: ignore
|
120 |
-
return {"documents": [state]}
|
121 |
|
122 |
|
123 |
-
def
|
124 |
-
""
|
125 |
|
126 |
-
This function filters out hallucinations that need fixing and prepares them to be
|
127 |
-
sent to the `fix_hallucination` function. Allows hallucinations to be fixed in
|
128 |
-
parallel.
|
129 |
|
130 |
-
|
131 |
-
state (OverallState): The overall state containing all hallucinations.
|
132 |
-
|
133 |
-
Returns:
|
134 |
-
list: A list of Send objects directing each hallucination to the
|
135 |
-
fix_hallucination function.
|
136 |
-
"""
|
137 |
-
hallucinations = []
|
138 |
-
if "documents" in state:
|
139 |
-
hallucinations = [
|
140 |
-
document
|
141 |
-
for document in state["documents"]
|
142 |
-
if document["hallucination"].score != 1
|
143 |
-
]
|
144 |
return [
|
145 |
-
Send("fix_hallucination",
|
|
|
|
|
146 |
]
|
|
|
2 |
import logging
|
3 |
|
4 |
from langchain_core.exceptions import OutputParserException
|
5 |
+
from langgraph.constants import END
|
6 |
from langgraph.types import Send
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
from planning_ai.chains.fix_chain import fix_template
|
10 |
+
from planning_ai.chains.hallucination_chain import hallucination_chain
|
|
|
|
|
|
|
11 |
from planning_ai.chains.map_chain import create_dynamic_map_chain
|
12 |
from planning_ai.states import DocumentState, OverallState
|
13 |
|
|
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
19 |
|
20 |
+
MAX_ATTEMPTS = 3
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
def check_hallucination(state: DocumentState):
|
|
|
36 |
that need to be addressed.
|
37 |
"""
|
38 |
logger.warning(f"Checking hallucinations for document {state['filename']}")
|
39 |
+
|
40 |
+
if (state["refinement_attempts"] >= MAX_ATTEMPTS) or state["processed"]:
|
41 |
+
logger.warning(f"Max attempts exceeded for document: {state['filename']}")
|
42 |
+
return {"documents": [{**state, "failed": True, "processed": True}]}
|
43 |
+
elif not state["is_hallucinated"]:
|
44 |
+
logger.warning(f"Finished processing document: {state['filename']}")
|
45 |
+
return {"documents": [{**state, "processed": True}]}
|
46 |
|
47 |
try:
|
48 |
response = hallucination_chain.invoke(
|
49 |
{"document": state["document"], "summary": state["summary"].summary}
|
50 |
)
|
51 |
+
is_hallucinated = response.score == 0
|
52 |
+
refinement_attempts = state["refinement_attempts"] + 1
|
53 |
+
out = {
|
54 |
+
**state,
|
55 |
+
"hallucination": response,
|
56 |
+
"refinement_attempts": refinement_attempts,
|
57 |
+
"is_hallucinated": is_hallucinated,
|
58 |
+
}
|
59 |
+
logger.warning(f"Hallucination: {is_hallucinated}")
|
60 |
+
return (
|
61 |
+
{"documents": [{**out, "processed": False}]}
|
62 |
+
if is_hallucinated
|
63 |
+
else {"documents": [{**out, "processed": True}]}
|
64 |
+
)
|
65 |
except (OutputParserException, json.JSONDecodeError) as e:
|
66 |
logger.error(f"Failed to decode JSON: {e}.")
|
67 |
+
return {"documents": [{**state, "failed": True, "processed": True}]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
|
70 |
def fix_hallucination(state: DocumentState):
|
|
|
93 |
)
|
94 |
except (OutputParserException, json.JSONDecodeError) as e:
|
95 |
logger.error(f"Failed to decode JSON: {e}.")
|
96 |
+
return {"documents": [{**state, "failed": True, "processed": True}]}
|
97 |
+
return {"documents": [{**state, "summary": response}]}
|
|
|
|
|
|
|
|
|
98 |
|
99 |
|
100 |
+
def map_check(state: OverallState):
    """Fan every document out to the ``check_hallucination`` node.

    Args:
        state (OverallState): Overall state; its ``"documents"`` entry is
            iterated.

    Returns:
        list: One ``Send("check_hallucination", doc)`` per document, in the
            same order as ``state["documents"]``.
    """
    dispatches = []
    for document in state["documents"]:
        dispatches.append(Send("check_hallucination", document))
    return dispatches
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
+
def map_fix(state: OverallState):
    """Route documents that still need repair to the ``fix_hallucination`` node.

    A document is dispatched only when it is flagged ``is_hallucinated`` and
    has not yet been marked ``processed``.

    Args:
        state (OverallState): Overall state; its ``"documents"`` entry is
            iterated.

    Returns:
        list: One ``Send("fix_hallucination", doc)`` per qualifying document
            (empty when nothing qualifies).
    """
    dispatches = []
    for document in state["documents"]:
        # Skip anything that is either clean or already finished.
        if not document["is_hallucinated"] or document["processed"]:
            continue
        dispatches.append(Send("fix_hallucination", document))
    return dispatches
|
planning_ai/nodes/map_node.py
CHANGED
@@ -3,6 +3,7 @@ import logging
|
|
3 |
|
4 |
import spacy
|
5 |
from langchain_core.exceptions import OutputParserException
|
|
|
6 |
from langgraph.types import Send
|
7 |
from presidio_analyzer import AnalyzerEngine
|
8 |
from presidio_anonymizer import AnonymizerEngine
|
@@ -19,26 +20,14 @@ logging.basicConfig(
|
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
|
22 |
-
class BasicSummaryBroken(BaseModel):
|
23 |
-
summary: str
|
24 |
-
policies: None
|
25 |
-
|
26 |
-
|
27 |
analyzer = AnalyzerEngine()
|
28 |
anonymizer = AnonymizerEngine()
|
29 |
|
30 |
nlp = spacy.load("en_core_web_lg")
|
31 |
|
32 |
|
33 |
-
def _return_summary_error(state):
|
34 |
-
state["iteration"] = 99
|
35 |
-
state["hallucination"] = HallucinationChecker(score=1, explanation="INVALID")
|
36 |
-
state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
|
37 |
-
return {"documents": [state]}
|
38 |
-
|
39 |
-
|
40 |
def retrieve_themes(state: DocumentState) -> DocumentState:
|
41 |
-
result = themes_chain.invoke({"document": state["document"]})
|
42 |
if not result.themes:
|
43 |
state["themes"] = set()
|
44 |
return state
|
@@ -102,29 +91,40 @@ def generate_summary(state: DocumentState) -> dict:
|
|
102 |
state = retrieve_themes(state)
|
103 |
|
104 |
if not state["themes"]:
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
map_chain = create_dynamic_map_chain(themes=state["themes"], prompt=map_template)
|
107 |
try:
|
108 |
response = map_chain.invoke({"context": state["document"].page_content})
|
109 |
except (OutputParserException, json.JSONDecodeError) as e:
|
110 |
logger.error(f"Failed to decode JSON: {e}.")
|
111 |
-
return
|
112 |
logger.warning(f"Summary generation completed for document: {state['filename']}")
|
113 |
-
return {
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
Args:
|
123 |
-
state (OverallState): The overall state containing all documents and their filenames.
|
124 |
|
125 |
-
|
126 |
-
list: A list of Send objects directing each document to the `generate_summary`
|
127 |
-
function.
|
128 |
-
"""
|
129 |
logger.warning("Mapping documents to generate summaries.")
|
130 |
return [Send("generate_summary", document) for document in state["documents"]]
|
|
|
3 |
|
4 |
import spacy
|
5 |
from langchain_core.exceptions import OutputParserException
|
6 |
+
from langgraph.constants import END
|
7 |
from langgraph.types import Send
|
8 |
from presidio_analyzer import AnalyzerEngine
|
9 |
from presidio_anonymizer import AnonymizerEngine
|
|
|
20 |
logger = logging.getLogger(__name__)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
analyzer = AnalyzerEngine()
|
24 |
anonymizer = AnonymizerEngine()
|
25 |
|
26 |
nlp = spacy.load("en_core_web_lg")
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def retrieve_themes(state: DocumentState) -> DocumentState:
|
30 |
+
result = themes_chain.invoke({"document": state["document"].page_content})
|
31 |
if not result.themes:
|
32 |
state["themes"] = set()
|
33 |
return state
|
|
|
91 |
state = retrieve_themes(state)
|
92 |
|
93 |
if not state["themes"]:
|
94 |
+
logger.error("No themes found.")
|
95 |
+
return {
|
96 |
+
"documents": [
|
97 |
+
{
|
98 |
+
**state,
|
99 |
+
"summary": "",
|
100 |
+
"processed": True,
|
101 |
+
"is_hallucinated": True,
|
102 |
+
"failed": True,
|
103 |
+
"refinement_attempts": 0,
|
104 |
+
}
|
105 |
+
]
|
106 |
+
}
|
107 |
map_chain = create_dynamic_map_chain(themes=state["themes"], prompt=map_template)
|
108 |
try:
|
109 |
response = map_chain.invoke({"context": state["document"].page_content})
|
110 |
except (OutputParserException, json.JSONDecodeError) as e:
|
111 |
logger.error(f"Failed to decode JSON: {e}.")
|
112 |
+
return {"documents": [{**state, "failed": True, "processed": True}]}
|
113 |
logger.warning(f"Summary generation completed for document: {state['filename']}")
|
114 |
+
return {
|
115 |
+
"documents": [
|
116 |
+
{
|
117 |
+
**state,
|
118 |
+
"summary": response,
|
119 |
+
"refinement_attempts": 0,
|
120 |
+
"is_hallucinated": True, # start true to ensure cycle begins
|
121 |
+
"failed": False,
|
122 |
+
"processed": False,
|
123 |
+
}
|
124 |
+
]
|
125 |
+
}
|
126 |
|
|
|
|
|
127 |
|
128 |
+
def map_documents(state: OverallState) -> list[Send]:
    """Fan every document out to the ``generate_summary`` node.

    Args:
        state (OverallState): Overall state; its ``"documents"`` entry is
            iterated.

    Returns:
        list[Send]: One ``Send("generate_summary", document)`` per document.
    """
    logger.warning("Mapping documents to generate summaries.")
    dispatches = []
    for entry in state["documents"]:
        dispatches.append(Send("generate_summary", entry))
    return dispatches
|
planning_ai/nodes/reduce_node.py
CHANGED
@@ -15,46 +15,6 @@ logging.basicConfig(
|
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
|
18 |
-
def extract_policies_from_docs(docs):
|
19 |
-
policies = {"themes": [], "policies": [], "details": [], "stance": []}
|
20 |
-
for doc in docs:
|
21 |
-
if not doc["summary"].policies:
|
22 |
-
continue
|
23 |
-
for policy in doc["summary"].policies:
|
24 |
-
for theme, p in THEMES_AND_POLICIES.items():
|
25 |
-
if policy.policy.name in p:
|
26 |
-
policies["themes"].append(theme)
|
27 |
-
policies["policies"].append(policy.policy.name)
|
28 |
-
policies["details"].append(
|
29 |
-
f"{policy.note} [{doc['document'].metadata['index']}]"
|
30 |
-
)
|
31 |
-
policies["stance"].append(
|
32 |
-
doc["document"].metadata["representations_support/object"]
|
33 |
-
)
|
34 |
-
df = pl.DataFrame(policies)
|
35 |
-
grouped = df.group_by(["themes", "policies", "stance"]).agg(pl.col("details"))
|
36 |
-
return grouped
|
37 |
-
|
38 |
-
|
39 |
-
def filter_final_documents(state: OverallState):
|
40 |
-
return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
|
41 |
-
|
42 |
-
|
43 |
-
def filter_docs(final_docs):
|
44 |
-
out_docs = []
|
45 |
-
for doc in final_docs:
|
46 |
-
if (
|
47 |
-
(doc["summary"].summary != "INVALID")
|
48 |
-
and (doc["themes"] != set())
|
49 |
-
and (doc["iteration"] != 99)
|
50 |
-
):
|
51 |
-
doc["summary"].summary = (
|
52 |
-
f"Document ID: [{doc['document'].metadata['index']}]\n\n{doc['summary'].summary}"
|
53 |
-
)
|
54 |
-
out_docs.append(doc)
|
55 |
-
return out_docs
|
56 |
-
|
57 |
-
|
58 |
def save_summaries_to_json(docs):
|
59 |
"""Saves summaries to JSON files.
|
60 |
|
@@ -69,12 +29,12 @@ def save_summaries_to_json(docs):
|
|
69 |
"entities": doc["entities"],
|
70 |
"themes": list(doc["themes"]),
|
71 |
"summary": doc["summary"].model_dump()["summary"],
|
72 |
-
"policies": [
|
73 |
-
|
74 |
-
|
75 |
-
],
|
76 |
-
"iteration": doc["iteration"],
|
77 |
"hallucination": doc["hallucination"].model_dump(),
|
|
|
|
|
78 |
}
|
79 |
for doc in docs
|
80 |
]
|
@@ -84,6 +44,41 @@ def save_summaries_to_json(docs):
|
|
84 |
json.dump(doc, f)
|
85 |
|
86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
def batch_generate_executive_summaries(summaries):
|
88 |
"""Processes summaries to generate final responses.
|
89 |
|
@@ -101,7 +96,7 @@ def batch_generate_executive_summaries(summaries):
|
|
101 |
batch_size = 50
|
102 |
for i in range(0, len(summaries_text), batch_size):
|
103 |
logger.warning(
|
104 |
-
f"Processing batches... {i/50}/{len(summaries_text)//batch_size}"
|
105 |
)
|
106 |
batch = summaries_text[i : i + batch_size]
|
107 |
response = reduce_chain.invoke({"context": batch})
|
@@ -110,62 +105,42 @@ def batch_generate_executive_summaries(summaries):
|
|
110 |
|
111 |
|
112 |
def generate_policy_output(policy_groups):
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
)
|
121 |
-
|
122 |
-
|
123 |
-
{
|
124 |
-
"theme": policy["themes"][0],
|
125 |
-
"policy": policy["policies"][0],
|
126 |
-
"points": pchain_out,
|
127 |
-
}
|
128 |
-
)
|
129 |
-
else:
|
130 |
-
policies_object.append(
|
131 |
-
{
|
132 |
-
"theme": policy["themes"][0],
|
133 |
-
"policy": policy["policies"][0],
|
134 |
-
"points": pchain_out,
|
135 |
-
}
|
136 |
-
)
|
137 |
-
return policies_support, policies_object
|
138 |
-
|
139 |
-
|
140 |
-
def format_themes(policies):
|
141 |
-
themes = ""
|
142 |
-
for theme, policies in pl.DataFrame(policies).group_by("theme"):
|
143 |
-
themes += f"### {theme[0]}\n\n"
|
144 |
-
for row in policies.iter_rows(named=True):
|
145 |
-
themes += f"\n#### {row['policy']}\n\n"
|
146 |
-
themes += f"{row['points']}\n"
|
147 |
-
themes += "\n"
|
148 |
-
return themes
|
149 |
|
150 |
|
151 |
def generate_final_report(state: OverallState):
|
152 |
-
|
153 |
-
final_docs = filter_final_documents(state)
|
154 |
-
logger.warning(f"Number of final docs: {len(final_docs)}")
|
155 |
-
|
156 |
if len(final_docs) == state["n_docs"]:
|
157 |
-
|
158 |
-
|
159 |
|
160 |
-
policy_groups = extract_policies_from_docs(docs)
|
161 |
-
policies_support, policies_object = generate_policy_output(policy_groups)
|
162 |
|
163 |
-
|
164 |
-
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def save_summaries_to_json(docs):
|
19 |
"""Saves summaries to JSON files.
|
20 |
|
|
|
29 |
"entities": doc["entities"],
|
30 |
"themes": list(doc["themes"]),
|
31 |
"summary": doc["summary"].model_dump()["summary"],
|
32 |
+
"policies": doc["policies"],
|
33 |
+
"notes": doc["notes"],
|
34 |
+
"refinement_attempts": doc["refinement_attempts"],
|
|
|
|
|
35 |
"hallucination": doc["hallucination"].model_dump(),
|
36 |
+
"is_hallucinated": doc["is_hallucinated"],
|
37 |
+
"failed": doc["failed"],
|
38 |
}
|
39 |
for doc in docs
|
40 |
]
|
|
|
44 |
json.dump(doc, f)
|
45 |
|
46 |
|
47 |
+
def extract_policies_from_docs(docs):
    """Flatten each document's policy/note pairs into one long-format table.

    For every document, ``summary.policies`` and ``summary.notes`` are read as
    parallel lists (entry *i* of one belongs to entry *i* of the other). Each
    pair whose policy appears under a theme in ``THEMES_AND_POLICIES`` adds
    one row.

    Args:
        docs: Iterable of document states. Each must provide
            ``doc["summary"].policies``, ``doc["summary"].notes`` and
            ``doc["document"].metadata`` with the keys ``"index"`` and
            ``"representations_support/object"``.

    Returns:
        pl.DataFrame: Columns ``doc_id``, ``themes``, ``policies``,
            ``details`` and ``stance``; one row per matched pair.
    """
    policies = {"doc_id": [], "themes": [], "policies": [], "details": [], "stance": []}
    for doc in docs:
        summary = doc["summary"]
        if not summary.policies or not summary.notes:
            continue
        # The two lists are only meaningful when aligned; zipping lists of
        # different lengths would silently mispair policies and notes, so
        # the document is skipped and the problem is logged instead of
        # dropping into an interactive debugger.
        if len(summary.policies) != len(summary.notes):
            logger.error(
                "Skipping document %s: %d policies but %d notes.",
                doc["document"].metadata.get("index"),
                len(summary.policies),
                len(summary.notes),
            )
            continue
        metadata = doc["document"].metadata
        for policy, note in zip(summary.policies, summary.notes):
            for theme, theme_policies in THEMES_AND_POLICIES.items():
                if policy in theme_policies:
                    policies["doc_id"].append(metadata["index"])
                    policies["themes"].append(theme)
                    policies["policies"].append(policy)
                    policies["details"].append(note)
                    policies["stance"].append(
                        metadata["representations_support/object"]
                    )
    return pl.DataFrame(policies)
|
70 |
+
|
71 |
+
|
72 |
+
def add_doc_id(final_docs):
    """Prefix each document's summary text with its document ID.

    The summary object is modified in place: its ``summary`` attribute becomes
    ``"Document ID: [<index>]"``, a blank line, then the previous text, where
    ``<index>`` is ``doc["document"].metadata["index"]``.

    Args:
        final_docs: Iterable of document states exposing ``doc["summary"]``
            and ``doc["document"].metadata["index"]``.

    Returns:
        list: The same document objects, in input order.
    """
    labelled = []
    for entry in final_docs:
        summary_obj = entry["summary"]
        index = entry["document"].metadata["index"]
        summary_obj.summary = f"Document ID: [{index}]\n\n{summary_obj.summary}"
        labelled.append(entry)
    return labelled
|
80 |
+
|
81 |
+
|
82 |
def batch_generate_executive_summaries(summaries):
|
83 |
"""Processes summaries to generate final responses.
|
84 |
|
|
|
96 |
batch_size = 50
|
97 |
for i in range(0, len(summaries_text), batch_size):
|
98 |
logger.warning(
|
99 |
+
f"Processing batches... {int(i/50)+1}/{(len(summaries_text)//batch_size)+1}"
|
100 |
)
|
101 |
batch = summaries_text[i : i + batch_size]
|
102 |
response = reduce_chain.invoke({"context": batch})
|
|
|
105 |
|
106 |
|
107 |
def generate_policy_output(policy_groups):
    """Condense the details of each (theme, policy, stance) group via the LLM.

    Rows of ``policy_groups`` are grouped by ``themes``, ``policies`` and
    ``stance``; the ``details`` and ``doc_id`` values of each group are
    collected into lists and passed to ``policy_chain``. The chain's output
    fields are merged over the group row, replacing the raw ``details`` and
    ``doc_id`` lists.

    Args:
        policy_groups (pl.DataFrame): Long-format table with the columns
            ``themes``, ``policies``, ``stance``, ``details`` and ``doc_id``.

    Returns:
        pl.DataFrame: One row per group, holding the group keys plus the
            fields returned by ``policy_chain``.
    """
    grouped = policy_groups.group_by(["themes", "policies", "stance"]).agg(
        pl.col("details"), pl.col("doc_id")
    )
    out = []
    for policy in grouped.rows(named=True):
        logger.warning(f"Processing policies: {policy['policies']}...")
        reduced = policy_chain.invoke(
            {
                "theme": policy["themes"],
                "policy": policy["policies"],
                "stance": policy["stance"],
                "details": policy["details"],
                "doc_id": policy["doc_id"],
            }
        )
        # model_dump() is the Pydantic v2 spelling of the deprecated .dict()
        # and matches how summaries are serialised elsewhere in this module.
        out.append(policy | reduced.model_dump())
    return pl.DataFrame(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
|
128 |
def generate_final_report(state: OverallState):
    """Build the final report once every document has finished processing.

    Args:
        state (OverallState): Overall state; reads ``"documents"`` and
            ``"n_docs"``.

    Returns:
        dict | None: The result of ``final_output`` when the number of
            documents marked ``processed`` equals ``state["n_docs"]``;
            otherwise ``None``.
    """
    final_docs = [doc for doc in state["documents"] if doc["processed"]]
    if len(final_docs) != state["n_docs"]:
        # Some documents are still outstanding; return nothing for now.
        return None
    # Use the module logger (not the root ``logging`` module) so this message
    # is emitted under the same logger name as the rest of the file.
    logger.warning(f"Generating final report... ({len(final_docs)} documents)")
    return final_output(final_docs)
|
133 |
|
|
|
|
|
134 |
|
135 |
+
def final_output(final_docs):
    """Assemble the executive summary and policy table from finished documents.

    Documents flagged ``failed`` are dropped first. The rest get their ID
    prefixed onto the summary, then feed both the per-policy reduction and
    the batched executive summaries, which are reduced once more into a
    single executive text.

    Args:
        final_docs: Iterable of processed document states; each must carry a
            ``"failed"`` flag.

    Returns:
        dict: ``{"executive": ..., "documents": ..., "policies": ...}``.
    """
    succeeded = []
    for candidate in final_docs:
        if candidate["failed"]:
            continue
        succeeded.append(candidate)

    # Order matters: IDs are stamped onto the summaries before anything reads them.
    succeeded = add_doc_id(succeeded)

    policies = generate_policy_output(extract_policies_from_docs(succeeded))

    batch_executive = batch_generate_executive_summaries(succeeded)
    joined_batches = "\n\n".join(batch_executive)
    executive = reduce_chain.invoke({"context": joined_batches})

    return {"executive": executive, "documents": succeeded, "policies": policies}
|
planning_ai/states.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from pathlib import Path
|
2 |
from typing import Annotated, TypedDict
|
3 |
|
|
|
4 |
from langchain_core.documents import Document
|
5 |
from pydantic import BaseModel
|
6 |
|
@@ -14,16 +15,19 @@ class DocumentState(TypedDict):
|
|
14 |
|
15 |
entities: list[dict]
|
16 |
themes: set[str]
|
|
|
17 |
summary: BaseModel
|
18 |
hallucination: HallucinationChecker
|
19 |
|
20 |
-
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
class OverallState(TypedDict):
|
|
|
24 |
executive: str
|
25 |
-
|
26 |
-
policies_support: str
|
27 |
-
policies_object: str
|
28 |
|
29 |
n_docs: int
|
|
|
1 |
from pathlib import Path
|
2 |
from typing import Annotated, TypedDict
|
3 |
|
4 |
+
import polars as pl
|
5 |
from langchain_core.documents import Document
|
6 |
from pydantic import BaseModel
|
7 |
|
|
|
15 |
|
16 |
entities: list[dict]
|
17 |
themes: set[str]
|
18 |
+
|
19 |
summary: BaseModel
|
20 |
hallucination: HallucinationChecker
|
21 |
|
22 |
+
is_hallucinated: bool
|
23 |
+
refinement_attempts: int
|
24 |
+
failed: bool
|
25 |
+
processed: bool
|
26 |
|
27 |
|
28 |
class OverallState(TypedDict):
|
29 |
+
documents: Annotated[list, filename_reducer]
|
30 |
executive: str
|
31 |
+
policies: pl.DataFrame
|
|
|
|
|
32 |
|
33 |
n_docs: int
|