cjber committed on
Commit
1af4802
·
1 parent: 134bc43

fix: simplify structured outputs which improves accuracy

Browse files
planning_ai/chains/map_chain.py CHANGED
@@ -2,7 +2,7 @@ from enum import Enum, auto
2
  from typing import Optional, Set, Type
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
- from pydantic import BaseModel, create_model
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
@@ -39,9 +39,11 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
39
  Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
40
  """
41
 
 
42
  DynamicPolicy = create_model(
43
  "DynamicPolicy",
44
- policy=(policy_enum, ...),
 
45
  note=(str, ...),
46
  __config__={"extra": "forbid"},
47
  )
@@ -49,7 +51,9 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
49
  return create_model(
50
  "DynamicBriefSummary",
51
  summary=(str, ...),
52
- policies=(Optional[list[DynamicPolicy]], ...),
 
 
53
  __module__=__name__,
54
  __config__={"extra": "forbid"},
55
  )
@@ -82,7 +86,7 @@ if __name__ == "__main__":
82
  the major settlement of Cambourne has been created - now over the projected 3,000 homes and
83
  Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
84
  """
85
- test_themes = {"Great Places", "Homes"}
86
 
87
  dynamic_map_chain = create_dynamic_map_chain(test_themes, prompt=map_template)
88
  result = dynamic_map_chain.invoke({"context": test_document, "themes": test_themes})
 
2
  from typing import Optional, Set, Type
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
+ from pydantic import BaseModel, Field, create_model
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
 
39
  Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
40
  """
41
 
42
+ # NOTE: For some reason GPT4o goes mental if we use too much structure
43
  DynamicPolicy = create_model(
44
  "DynamicPolicy",
45
+ # policy=(policy_enum, ...),
46
+ policy=(str, ...),
47
  note=(str, ...),
48
  __config__={"extra": "forbid"},
49
  )
 
51
  return create_model(
52
  "DynamicBriefSummary",
53
  summary=(str, ...),
54
+ # policies=(Optional[list[DynamicPolicy]], ...),
55
+ policies=(Optional[list[str]], ...),
56
+ notes=(Optional[list[str]], ...),
57
  __module__=__name__,
58
  __config__={"extra": "forbid"},
59
  )
 
86
  the major settlement of Cambourne has been created - now over the projected 3,000 homes and
87
  Papworth Everard has grown beyond recognition. This in itself is a matter of concern.
88
  """
89
+ test_themes = {"Homes", "Great Places"}
90
 
91
  dynamic_map_chain = create_dynamic_map_chain(test_themes, prompt=map_template)
92
  result = dynamic_map_chain.invoke({"context": test_document, "themes": test_themes})
planning_ai/chains/policy_chain.py CHANGED
@@ -1,5 +1,5 @@
1
- from langchain_core.output_parsers import StrOutputParser
2
  from langchain_core.prompts import ChatPromptTemplate
 
3
 
4
  from planning_ai.common.utils import Paths
5
  from planning_ai.llms.llm import LLM
@@ -8,8 +8,18 @@ with open(Paths.PROMPTS / "policy.txt", "r") as f:
8
  policy_template = f.read()
9
 
10
 
 
 
 
 
 
 
 
 
 
 
11
  policy_prompt = ChatPromptTemplate([("system", policy_template)])
12
- policy_chain = policy_prompt | LLM | StrOutputParser()
13
 
14
 
15
  if __name__ == "__main__":
 
 
1
  from langchain_core.prompts import ChatPromptTemplate
2
+ from pydantic import BaseModel, Field
3
 
4
  from planning_ai.common.utils import Paths
5
  from planning_ai.llms.llm import LLM
 
8
  policy_template = f.read()
9
 
10
 
11
+ class PolicyMerger(BaseModel):
12
+ """Return condensed details and their associated doc_ids"""
13
+
14
+ details: list[str]
15
+ doc_id: list[list[int]]
16
+
17
+
18
+ SLLM = LLM.with_structured_output(PolicyMerger, strict=True)
19
+
20
+
21
  policy_prompt = ChatPromptTemplate([("system", policy_template)])
22
+ policy_chain = policy_prompt | SLLM
23
 
24
 
25
  if __name__ == "__main__":
planning_ai/chains/prompts/map.txt CHANGED
@@ -2,7 +2,7 @@ Please analyze the response to the planning application provided below. Your tas
2
 
3
  1. **Summary**: Provide a concise summary of the response, highlighting the main points and any significant details.
4
 
5
- 2. **Policy Identification**: Thoroughly review the response and identify all relevant policies from the provided list. Focus on capturing policies that are explicitly mentioned or strongly implied. Prioritize general policies over specific ones when both are relevant. Avoid inferring new policies beyond those stated. Select **all** relevant policies, even if they seem minor.
6
 
7
  3. **Policy Notes**: For each identified policy, extract and list at least one verbatim section from the response that directly relates to it. Ensure the **full** context is retained so the section can be understood independently. Policy notes may overlap. If a note does not have a clear link to the policy, omit both the policy and the note.
8
 
 
2
 
3
  1. **Summary**: Provide a concise summary of the response, highlighting the main points and any significant details.
4
 
5
+ 2. **Policy Identification**: Carefully review the response and identify all relevant policies from the provided list. Focus on capturing policies that are explicitly mentioned or strongly implied. Avoid inferring new policies beyond those stated. Select **all** relevant policies, even if they seem minor.
6
 
7
  3. **Policy Notes**: For each identified policy, extract and list at least one verbatim section from the response that directly relates to it. Ensure the **full** context is retained so the section can be understood independently. Policy notes may overlap. If a note does not have a clear link to the policy, omit both the policy and the note.
8
 
planning_ai/chains/prompts/policy.txt CHANGED
@@ -6,9 +6,7 @@ You are tasked with refining a list of details related to a specific planning po
6
  4. Exclude any details that do not pertain to the policy.
7
  5. Disregard generic details that merely restate the policy.
8
 
9
- The remaining bullet points **must** be followed by **up to 5** references to the original document IDs. Each bullet point should include inline citations corresponding to all the numerical IDs associated with the original details.
10
-
11
- For example '- Impact of increased housing density on the character of Cambridge [1][2][11].'.
12
 
13
  Theme: {theme}
14
 
 
6
  4. Exclude any details that do not pertain to the policy.
7
  5. Disregard generic details that merely restate the policy.
8
 
9
+ Ensure that all returned details use proper sentence structure.
 
 
10
 
11
  Theme: {theme}
12
 
planning_ai/graph.py CHANGED
@@ -4,13 +4,13 @@ from langgraph.graph import END, StateGraph
4
  from planning_ai.nodes.hallucination_node import (
5
  check_hallucination,
6
  fix_hallucination,
7
- map_fix_hallucinations,
8
- map_hallucinations,
9
  )
10
  from planning_ai.nodes.map_node import (
11
  add_entities,
12
  generate_summary,
13
- map_summaries,
14
  retrieve_themes,
15
  )
16
  from planning_ai.nodes.reduce_node import generate_final_report
@@ -24,31 +24,15 @@ def create_graph():
24
  graph.add_node("generate_summary", generate_summary)
25
  graph.add_node("check_hallucination", check_hallucination)
26
  graph.add_node("fix_hallucination", fix_hallucination)
27
- graph.add_node("generate_final_summary", generate_final_report)
28
 
29
  graph.add_edge(START, "add_entities")
30
- graph.add_conditional_edges(
31
- "add_entities",
32
- map_summaries,
33
- ["generate_summary"],
34
- )
35
- graph.add_conditional_edges(
36
- "generate_summary",
37
- map_hallucinations,
38
- ["check_hallucination"],
39
- )
40
- graph.add_conditional_edges(
41
- "check_hallucination",
42
- map_fix_hallucinations,
43
- ["fix_hallucination"],
44
- )
45
- graph.add_conditional_edges(
46
- "fix_hallucination",
47
- map_hallucinations,
48
- ["check_hallucination"],
49
- )
50
-
51
- graph.add_edge("check_hallucination", "generate_final_summary")
52
  graph.add_edge("generate_final_summary", END)
53
 
54
  return graph.compile()
 
4
  from planning_ai.nodes.hallucination_node import (
5
  check_hallucination,
6
  fix_hallucination,
7
+ map_check,
8
+ map_fix,
9
  )
10
  from planning_ai.nodes.map_node import (
11
  add_entities,
12
  generate_summary,
13
+ map_documents,
14
  retrieve_themes,
15
  )
16
  from planning_ai.nodes.reduce_node import generate_final_report
 
24
  graph.add_node("generate_summary", generate_summary)
25
  graph.add_node("check_hallucination", check_hallucination)
26
  graph.add_node("fix_hallucination", fix_hallucination)
27
+ graph.add_node("generate_final_report", generate_final_report)
28
 
29
  graph.add_edge(START, "add_entities")
30
+ graph.add_conditional_edges("add_entities", map_documents, ["generate_summary"])
31
+ graph.add_conditional_edges("generate_summary", map_check, ["check_hallucination"])
32
+ graph.add_conditional_edges("check_hallucination", map_fix, ["fix_hallucination"])
33
+ graph.add_conditional_edges("fix_hallucination", map_check, ["check_hallucination"])
34
+
35
+ graph.add_edge("check_hallucination", "generate_final_report")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  graph.add_edge("generate_final_summary", END)
37
 
38
  return graph.compile()
planning_ai/nodes/hallucination_node.py CHANGED
@@ -2,14 +2,12 @@ import json
2
  import logging
3
 
4
  from langchain_core.exceptions import OutputParserException
 
5
  from langgraph.types import Send
6
  from pydantic import BaseModel
7
 
8
  from planning_ai.chains.fix_chain import fix_template
9
- from planning_ai.chains.hallucination_chain import (
10
- HallucinationChecker,
11
- hallucination_chain,
12
- )
13
  from planning_ai.chains.map_chain import create_dynamic_map_chain
14
  from planning_ai.states import DocumentState, OverallState
15
 
@@ -19,12 +17,7 @@ logging.basicConfig(
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
- class BasicSummaryBroken(BaseModel):
23
- summary: str
24
- policies: None
25
-
26
-
27
- ITERATIONS = 2
28
 
29
 
30
  def check_hallucination(state: DocumentState):
@@ -43,47 +36,35 @@ def check_hallucination(state: DocumentState):
43
  that need to be addressed.
44
  """
45
  logger.warning(f"Checking hallucinations for document {state['filename']}")
46
- # Stop trying after 2 iterations
47
- if state["iteration"] > ITERATIONS:
48
- state["iteration"] = 99
49
- state["hallucination"].score = 1
50
- return {"documents": [state]}
 
 
51
 
52
  try:
53
  response = hallucination_chain.invoke(
54
  {"document": state["document"], "summary": state["summary"].summary}
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  except (OutputParserException, json.JSONDecodeError) as e:
57
  logger.error(f"Failed to decode JSON: {e}.")
58
- state["iteration"] = 99
59
- state["hallucination"] = HallucinationChecker(score=1, explanation="INVALID")
60
- state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
61
- return {"documents": [state]}
62
- if response.score == 1:
63
- return {"documents": [{**state, "hallucination": response}]}
64
-
65
- return {
66
- "documents": [
67
- {**state, "hallucination": response, "iteration": state["iteration"] + 1}
68
- ]
69
- }
70
-
71
-
72
- def map_hallucinations(state: OverallState):
73
- """Maps summaries to the `check_hallucination` function.
74
-
75
- This function prepares a list of summaries to be checked for hallucinations by
76
- sending them to the `check_hallucination` function. Allows summaries to be checked
77
- in parrallel.
78
-
79
- Args:
80
- state (OverallState): The overall state containing all summaries.
81
-
82
- Returns:
83
- list: A list of Send objects directing each summary to the check_hallucination
84
- function.
85
- """
86
- return [Send("check_hallucination", document) for document in state["documents"]]
87
 
88
 
89
  def fix_hallucination(state: DocumentState):
@@ -112,35 +93,17 @@ def fix_hallucination(state: DocumentState):
112
  )
113
  except (OutputParserException, json.JSONDecodeError) as e:
114
  logger.error(f"Failed to decode JSON: {e}.")
115
- state["iteration"] = 99
116
- state["hallucination"] = HallucinationChecker(score=1, explanation="INVALID")
117
- state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
118
- return {"documents": [state]}
119
- state["summary"] = response # type: ignore
120
- return {"documents": [state]}
121
 
122
 
123
- def map_fix_hallucinations(state: OverallState):
124
- """Maps hallucinations to the `fix_hallucination` function.
125
 
126
- This function filters out hallucinations that need fixing and prepares them to be
127
- sent to the `fix_hallucination` function. Allows hallucinations to be fixed in
128
- parrallel.
129
 
130
- Args:
131
- state (OverallState): The overall state containing all hallucinations.
132
-
133
- Returns:
134
- list: A list of Send objects directing each hallucination to the
135
- fix_hallucination function.
136
- """
137
- hallucinations = []
138
- if "documents" in state:
139
- hallucinations = [
140
- document
141
- for document in state["documents"]
142
- if document["hallucination"].score != 1
143
- ]
144
  return [
145
- Send("fix_hallucination", hallucination) for hallucination in hallucinations
 
 
146
  ]
 
2
  import logging
3
 
4
  from langchain_core.exceptions import OutputParserException
5
+ from langgraph.constants import END
6
  from langgraph.types import Send
7
  from pydantic import BaseModel
8
 
9
  from planning_ai.chains.fix_chain import fix_template
10
+ from planning_ai.chains.hallucination_chain import hallucination_chain
 
 
 
11
  from planning_ai.chains.map_chain import create_dynamic_map_chain
12
  from planning_ai.states import DocumentState, OverallState
13
 
 
17
  logger = logging.getLogger(__name__)
18
 
19
 
20
+ MAX_ATTEMPTS = 3
 
 
 
 
 
21
 
22
 
23
  def check_hallucination(state: DocumentState):
 
36
  that need to be addressed.
37
  """
38
  logger.warning(f"Checking hallucinations for document {state['filename']}")
39
+
40
+ if (state["refinement_attempts"] >= MAX_ATTEMPTS) or state["processed"]:
41
+ logger.warning(f"Max attempts exceeded for document: {state['filename']}")
42
+ return {"documents": [{**state, "failed": True, "processed": True}]}
43
+ elif not state["is_hallucinated"]:
44
+ logger.warning(f"Finished processing document: {state['filename']}")
45
+ return {"documents": [{**state, "processed": True}]}
46
 
47
  try:
48
  response = hallucination_chain.invoke(
49
  {"document": state["document"], "summary": state["summary"].summary}
50
  )
51
+ is_hallucinated = response.score == 0
52
+ refinement_attempts = state["refinement_attempts"] + 1
53
+ out = {
54
+ **state,
55
+ "hallucination": response,
56
+ "refinement_attempts": refinement_attempts,
57
+ "is_hallucinated": is_hallucinated,
58
+ }
59
+ logger.warning(f"Hallucination: {is_hallucinated}")
60
+ return (
61
+ {"documents": [{**out, "processed": False}]}
62
+ if is_hallucinated
63
+ else {"documents": [{**out, "processed": True}]}
64
+ )
65
  except (OutputParserException, json.JSONDecodeError) as e:
66
  logger.error(f"Failed to decode JSON: {e}.")
67
+ return {"documents": [{**state, "failed": True, "processed": True}]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  def fix_hallucination(state: DocumentState):
 
93
  )
94
  except (OutputParserException, json.JSONDecodeError) as e:
95
  logger.error(f"Failed to decode JSON: {e}.")
96
+ return {"documents": [{**state, "failed": True, "processed": True}]}
97
+ return {"documents": [{**state, "summary": response}]}
 
 
 
 
98
 
99
 
100
+ def map_check(state: OverallState):
101
+ return [Send("check_hallucination", doc) for doc in state["documents"]]
102
 
 
 
 
103
 
104
+ def map_fix(state: OverallState):
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  return [
106
+ Send("fix_hallucination", doc)
107
+ for doc in state["documents"]
108
+ if doc["is_hallucinated"] and not doc["processed"]
109
  ]
planning_ai/nodes/map_node.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
 
4
  import spacy
5
  from langchain_core.exceptions import OutputParserException
 
6
  from langgraph.types import Send
7
  from presidio_analyzer import AnalyzerEngine
8
  from presidio_anonymizer import AnonymizerEngine
@@ -19,26 +20,14 @@ logging.basicConfig(
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
- class BasicSummaryBroken(BaseModel):
23
- summary: str
24
- policies: None
25
-
26
-
27
  analyzer = AnalyzerEngine()
28
  anonymizer = AnonymizerEngine()
29
 
30
  nlp = spacy.load("en_core_web_lg")
31
 
32
 
33
- def _return_summary_error(state):
34
- state["iteration"] = 99
35
- state["hallucination"] = HallucinationChecker(score=1, explanation="INVALID")
36
- state["summary"] = BasicSummaryBroken(summary="INVALID", policies=None)
37
- return {"documents": [state]}
38
-
39
-
40
  def retrieve_themes(state: DocumentState) -> DocumentState:
41
- result = themes_chain.invoke({"document": state["document"]})
42
  if not result.themes:
43
  state["themes"] = set()
44
  return state
@@ -102,29 +91,40 @@ def generate_summary(state: DocumentState) -> dict:
102
  state = retrieve_themes(state)
103
 
104
  if not state["themes"]:
105
- return _return_summary_error(state)
 
 
 
 
 
 
 
 
 
 
 
 
106
  map_chain = create_dynamic_map_chain(themes=state["themes"], prompt=map_template)
107
  try:
108
  response = map_chain.invoke({"context": state["document"].page_content})
109
  except (OutputParserException, json.JSONDecodeError) as e:
110
  logger.error(f"Failed to decode JSON: {e}.")
111
- return _return_summary_error(state)
112
  logger.warning(f"Summary generation completed for document: {state['filename']}")
113
- return {"documents": [{**state, "summary": response, "iteration": 1}]}
114
-
115
-
116
- def map_summaries(state: OverallState) -> list[Send]:
117
- """Maps documents to the `generate_summary` function for processing.
118
-
119
- This function prepares a list of documents to be summarized by sending them to the
120
- `generate_summary` function. It allows for parallel processing of document summaries.
 
 
 
 
121
 
122
- Args:
123
- state (OverallState): The overall state containing all documents and their filenames.
124
 
125
- Returns:
126
- list: A list of Send objects directing each document to the `generate_summary`
127
- function.
128
- """
129
  logger.warning("Mapping documents to generate summaries.")
130
  return [Send("generate_summary", document) for document in state["documents"]]
 
3
 
4
  import spacy
5
  from langchain_core.exceptions import OutputParserException
6
+ from langgraph.constants import END
7
  from langgraph.types import Send
8
  from presidio_analyzer import AnalyzerEngine
9
  from presidio_anonymizer import AnonymizerEngine
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
 
 
 
 
 
23
  analyzer = AnalyzerEngine()
24
  anonymizer = AnonymizerEngine()
25
 
26
  nlp = spacy.load("en_core_web_lg")
27
 
28
 
 
 
 
 
 
 
 
29
  def retrieve_themes(state: DocumentState) -> DocumentState:
30
+ result = themes_chain.invoke({"document": state["document"].page_content})
31
  if not result.themes:
32
  state["themes"] = set()
33
  return state
 
91
  state = retrieve_themes(state)
92
 
93
  if not state["themes"]:
94
+ logger.error("No themes found.")
95
+ return {
96
+ "documents": [
97
+ {
98
+ **state,
99
+ "summary": "",
100
+ "processed": True,
101
+ "is_hallucinated": True,
102
+ "failed": True,
103
+ "refinement_attempts": 0,
104
+ }
105
+ ]
106
+ }
107
  map_chain = create_dynamic_map_chain(themes=state["themes"], prompt=map_template)
108
  try:
109
  response = map_chain.invoke({"context": state["document"].page_content})
110
  except (OutputParserException, json.JSONDecodeError) as e:
111
  logger.error(f"Failed to decode JSON: {e}.")
112
+ return {"documents": [{**state, "failed": True, "processed": True}]}
113
  logger.warning(f"Summary generation completed for document: {state['filename']}")
114
+ return {
115
+ "documents": [
116
+ {
117
+ **state,
118
+ "summary": response,
119
+ "refinement_attempts": 0,
120
+ "is_hallucinated": True, # start true to ensure cycle begins
121
+ "failed": False,
122
+ "processed": False,
123
+ }
124
+ ]
125
+ }
126
 
 
 
127
 
128
+ def map_documents(state: OverallState) -> list[Send]:
 
 
 
129
  logger.warning("Mapping documents to generate summaries.")
130
  return [Send("generate_summary", document) for document in state["documents"]]
planning_ai/nodes/reduce_node.py CHANGED
@@ -15,46 +15,6 @@ logging.basicConfig(
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
- def extract_policies_from_docs(docs):
19
- policies = {"themes": [], "policies": [], "details": [], "stance": []}
20
- for doc in docs:
21
- if not doc["summary"].policies:
22
- continue
23
- for policy in doc["summary"].policies:
24
- for theme, p in THEMES_AND_POLICIES.items():
25
- if policy.policy.name in p:
26
- policies["themes"].append(theme)
27
- policies["policies"].append(policy.policy.name)
28
- policies["details"].append(
29
- f"{policy.note} [{doc['document'].metadata['index']}]"
30
- )
31
- policies["stance"].append(
32
- doc["document"].metadata["representations_support/object"]
33
- )
34
- df = pl.DataFrame(policies)
35
- grouped = df.group_by(["themes", "policies", "stance"]).agg(pl.col("details"))
36
- return grouped
37
-
38
-
39
- def filter_final_documents(state: OverallState):
40
- return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
41
-
42
-
43
- def filter_docs(final_docs):
44
- out_docs = []
45
- for doc in final_docs:
46
- if (
47
- (doc["summary"].summary != "INVALID")
48
- and (doc["themes"] != set())
49
- and (doc["iteration"] != 99)
50
- ):
51
- doc["summary"].summary = (
52
- f"Document ID: [{doc['document'].metadata['index']}]\n\n{doc['summary'].summary}"
53
- )
54
- out_docs.append(doc)
55
- return out_docs
56
-
57
-
58
  def save_summaries_to_json(docs):
59
  """Saves summaries to JSON files.
60
 
@@ -69,12 +29,12 @@ def save_summaries_to_json(docs):
69
  "entities": doc["entities"],
70
  "themes": list(doc["themes"]),
71
  "summary": doc["summary"].model_dump()["summary"],
72
- "policies": [
73
- {"policy": policy["policy"].name, "note": policy["note"]}
74
- for policy in (doc["summary"].model_dump().get("policies", []) or [])
75
- ],
76
- "iteration": doc["iteration"],
77
  "hallucination": doc["hallucination"].model_dump(),
 
 
78
  }
79
  for doc in docs
80
  ]
@@ -84,6 +44,41 @@ def save_summaries_to_json(docs):
84
  json.dump(doc, f)
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def batch_generate_executive_summaries(summaries):
88
  """Processes summaries to generate final responses.
89
 
@@ -101,7 +96,7 @@ def batch_generate_executive_summaries(summaries):
101
  batch_size = 50
102
  for i in range(0, len(summaries_text), batch_size):
103
  logger.warning(
104
- f"Processing batches... {i/50}/{len(summaries_text)//batch_size}"
105
  )
106
  batch = summaries_text[i : i + batch_size]
107
  response = reduce_chain.invoke({"context": batch})
@@ -110,62 +105,42 @@ def batch_generate_executive_summaries(summaries):
110
 
111
 
112
  def generate_policy_output(policy_groups):
113
- policies_support = []
114
- policies_object = []
115
- for _, policy in policy_groups.group_by(["themes", "policies"]):
116
- logger.warning("Processing policies.")
117
- bullets = "* " + "* \n".join(policy["details"][0])
118
- pchain_out = policy_chain.invoke(
119
- {"policy": policy["policies"][0], "bullet_points": bullets}
 
 
 
 
 
 
 
 
120
  )
121
- if policy["stance"][0] == "Support":
122
- policies_support.append(
123
- {
124
- "theme": policy["themes"][0],
125
- "policy": policy["policies"][0],
126
- "points": pchain_out,
127
- }
128
- )
129
- else:
130
- policies_object.append(
131
- {
132
- "theme": policy["themes"][0],
133
- "policy": policy["policies"][0],
134
- "points": pchain_out,
135
- }
136
- )
137
- return policies_support, policies_object
138
-
139
-
140
- def format_themes(policies):
141
- themes = ""
142
- for theme, policies in pl.DataFrame(policies).group_by("theme"):
143
- themes += f"### {theme[0]}\n\n"
144
- for row in policies.iter_rows(named=True):
145
- themes += f"\n#### {row['policy']}\n\n"
146
- themes += f"{row['points']}\n"
147
- themes += "\n"
148
- return themes
149
 
150
 
151
  def generate_final_report(state: OverallState):
152
- logger.warning("Generating final summary")
153
- final_docs = filter_final_documents(state)
154
- logger.warning(f"Number of final docs: {len(final_docs)}")
155
-
156
  if len(final_docs) == state["n_docs"]:
157
- docs = filter_docs(final_docs)
158
- save_summaries_to_json(docs)
159
 
160
- policy_groups = extract_policies_from_docs(docs)
161
- policies_support, policies_object = generate_policy_output(policy_groups)
162
 
163
- batch_executive = batch_generate_executive_summaries(docs)
164
- executive = reduce_chain.invoke({"context": "\n\n".join(batch_executive)})
165
 
166
- return {
167
- "executive": executive,
168
- "documents": final_docs,
169
- "policies_support": format_themes(policies_support),
170
- "policies_object": format_themes(policies_object),
171
- }
 
 
 
 
15
  logger = logging.getLogger(__name__)
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def save_summaries_to_json(docs):
19
  """Saves summaries to JSON files.
20
 
 
29
  "entities": doc["entities"],
30
  "themes": list(doc["themes"]),
31
  "summary": doc["summary"].model_dump()["summary"],
32
+ "policies": doc["policies"],
33
+ "notes": doc["notes"],
34
+ "refinement_attempts": doc["refinement_attempts"],
 
 
35
  "hallucination": doc["hallucination"].model_dump(),
36
+ "is_hallucinated": doc["is_hallucinated"],
37
+ "failed": doc["failed"],
38
  }
39
  for doc in docs
40
  ]
 
44
  json.dump(doc, f)
45
 
46
 
47
+ def extract_policies_from_docs(docs):
48
+ policies = {"doc_id": [], "themes": [], "policies": [], "details": [], "stance": []}
49
+ for doc in docs:
50
+ if not doc["summary"].policies or not doc["summary"].notes:
51
+ continue
52
+ # TODO: Test when this is sometimes empty
53
+ assert len(doc["summary"].policies) == len(doc["summary"].notes), __import__(
54
+ "ipdb"
55
+ ).set_trace()
56
+ try:
57
+ for policy, note in zip(doc["summary"].policies, doc["summary"].notes):
58
+ for theme, p in THEMES_AND_POLICIES.items():
59
+ if policy in p:
60
+ policies["doc_id"].append(doc["document"].metadata["index"])
61
+ policies["themes"].append(theme)
62
+ policies["policies"].append(policy)
63
+ policies["details"].append(note)
64
+ policies["stance"].append(
65
+ doc["document"].metadata["representations_support/object"]
66
+ )
67
+ except Exception:
68
+ __import__("ipdb").set_trace()
69
+ return pl.DataFrame(policies)
70
+
71
+
72
+ def add_doc_id(final_docs):
73
+ out_docs = []
74
+ for doc in final_docs:
75
+ doc["summary"].summary = (
76
+ f"Document ID: [{doc['document'].metadata['index']}]\n\n{doc['summary'].summary}"
77
+ )
78
+ out_docs.append(doc)
79
+ return out_docs
80
+
81
+
82
  def batch_generate_executive_summaries(summaries):
83
  """Processes summaries to generate final responses.
84
 
 
96
  batch_size = 50
97
  for i in range(0, len(summaries_text), batch_size):
98
  logger.warning(
99
+ f"Processing batches... {int(i/50)+1}/{(len(summaries_text)//batch_size)+1}"
100
  )
101
  batch = summaries_text[i : i + batch_size]
102
  response = reduce_chain.invoke({"context": batch})
 
105
 
106
 
107
  def generate_policy_output(policy_groups):
108
+ out = []
109
+ for policy in (
110
+ policy_groups.group_by(["themes", "policies", "stance"])
111
+ .agg(pl.col("details"), pl.col("doc_id"))
112
+ .rows(named=True)
113
+ ):
114
+ logger.warning(f"Processing policies: {policy['policies']}...")
115
+ reduced = policy_chain.invoke(
116
+ {
117
+ "theme": policy["themes"],
118
+ "policy": policy["policies"],
119
+ "stance": policy["stance"],
120
+ "details": policy["details"],
121
+ "doc_id": policy["doc_id"],
122
+ }
123
  )
124
+ out.append(policy | reduced.dict())
125
+ return pl.DataFrame(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  def generate_final_report(state: OverallState):
129
+ final_docs = [doc for doc in state["documents"] if doc["processed"]]
 
 
 
130
  if len(final_docs) == state["n_docs"]:
131
+ logging.warning(f"Generating final report... ({len(final_docs)} documents)")
132
+ return final_output(final_docs)
133
 
 
 
134
 
135
+ def final_output(final_docs):
136
+ docs = [doc for doc in final_docs if not doc["failed"]]
137
 
138
+ docs = add_doc_id(docs)
139
+
140
+ policy_groups = extract_policies_from_docs(docs)
141
+ policies = generate_policy_output(policy_groups)
142
+
143
+ batch_executive = batch_generate_executive_summaries(docs)
144
+ executive = reduce_chain.invoke({"context": "\n\n".join(batch_executive)})
145
+
146
+ return {"executive": executive, "documents": docs, "policies": policies}
planning_ai/states.py CHANGED
@@ -1,6 +1,7 @@
1
  from pathlib import Path
2
  from typing import Annotated, TypedDict
3
 
 
4
  from langchain_core.documents import Document
5
  from pydantic import BaseModel
6
 
@@ -14,16 +15,19 @@ class DocumentState(TypedDict):
14
 
15
  entities: list[dict]
16
  themes: set[str]
 
17
  summary: BaseModel
18
  hallucination: HallucinationChecker
19
 
20
- iteration: int
 
 
 
21
 
22
 
23
  class OverallState(TypedDict):
 
24
  executive: str
25
- documents: Annotated[list[DocumentState], filename_reducer]
26
- policies_support: str
27
- policies_object: str
28
 
29
  n_docs: int
 
1
  from pathlib import Path
2
  from typing import Annotated, TypedDict
3
 
4
+ import polars as pl
5
  from langchain_core.documents import Document
6
  from pydantic import BaseModel
7
 
 
15
 
16
  entities: list[dict]
17
  themes: set[str]
18
+
19
  summary: BaseModel
20
  hallucination: HallucinationChecker
21
 
22
+ is_hallucinated: bool
23
+ refinement_attempts: int
24
+ failed: bool
25
+ processed: bool
26
 
27
 
28
  class OverallState(TypedDict):
29
+ documents: Annotated[list, filename_reducer]
30
  executive: str
31
+ policies: pl.DataFrame
 
 
32
 
33
  n_docs: int