cjber committed on
Commit
3f35591
·
1 Parent(s): 691ae52

add themes and policies processing

Browse files
Files changed (1) hide show
  1. planning_ai/nodes/reduce_node.py +67 -65
planning_ai/nodes/reduce_node.py CHANGED
@@ -15,26 +15,24 @@ logging.basicConfig(
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
- def extract_policies_from_summaries(summaries):
19
- policies = {"themes": [], "policies": [], "details": []}
20
- for summary in summaries:
21
- if not summary["summary"].policies:
22
  continue
23
- for policy in summary["summary"].policies:
24
  for theme, p in THEMES_AND_POLICIES.items():
25
  if policy.policy.name in p:
26
  policies["themes"].append(theme)
27
  policies["policies"].append(policy.policy.name)
28
- policies["details"].append(policy.note)
 
 
 
 
 
29
  df = pl.DataFrame(policies)
30
-
31
- grouped = df.group_by(["themes", "policies"]).agg(pl.col("details"))
32
- return grouped
33
-
34
-
35
- def markdown_bullets(summaries):
36
- policies = extract_policies_from_summaries(summaries)
37
- grouped = policies.group_by(["themes", "policies"]).agg(pl.col("details"))
38
  return grouped
39
 
40
 
@@ -42,17 +40,22 @@ def filter_final_documents(state: OverallState):
42
  return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
43
 
44
 
45
- def filter_summaries(final_docs, state: OverallState):
46
- return [
47
- doc
48
- for id, doc in zip(range(state["n_docs"]), final_docs)
49
- if doc["summary"].summary != "INVALID"
50
- and doc["themes"] != set()
51
- and doc["iteration"] != 99
52
- ]
 
 
 
 
 
53
 
54
 
55
- def save_summaries_to_json(summaries):
56
  """Saves summaries to JSON files.
57
 
58
  Args:
@@ -61,19 +64,19 @@ def save_summaries_to_json(summaries):
61
  out = [
62
  {
63
  "document": doc["document"].model_dump()["page_content"],
 
64
  "filename": doc["filename"],
65
  "entities": doc["entities"],
66
- "theme_docs": [d.model_dump() for d in doc["theme_docs"]],
67
  "themes": list(doc["themes"]),
68
  "summary": doc["summary"].model_dump()["summary"],
69
  "policies": [
70
  {"policy": policy["policy"].name, "note": policy["note"]}
71
- for policy in doc["summary"].model_dump().get("policies", [])
72
  ],
73
  "iteration": doc["iteration"],
74
  "hallucination": doc["hallucination"].model_dump(),
75
  }
76
- for doc in summaries
77
  ]
78
  for doc in out:
79
  filename = Path(str(doc["filename"])).stem
@@ -90,11 +93,16 @@ def batch_generate_executive_summaries(summaries):
90
  Returns:
91
  list: A list of final responses.
92
  """
93
- summaries_text = [s["summary"].summary for s in summaries]
 
 
 
94
  final_responses = []
95
  batch_size = 50
96
  for i in range(0, len(summaries_text), batch_size):
97
- logger.warning("Processing batches.")
 
 
98
  batch = summaries_text[i : i + batch_size]
99
  response = reduce_chain.invoke({"context": batch})
100
  final_responses.append(response)
@@ -102,68 +110,62 @@ def batch_generate_executive_summaries(summaries):
102
 
103
 
104
  def generate_policy_output(policy_groups):
105
- """Generates policy output from grouped policies.
106
-
107
- Args:
108
- pols (pl.DataFrame): A DataFrame with grouped policies.
109
-
110
- Returns:
111
- list: A list of policy outputs.
112
- """
113
- pol_out = []
114
  for _, policy in policy_groups.group_by(["themes", "policies"]):
115
  logger.warning("Processing policies.")
116
  bullets = "* " + "* \n".join(policy["details"][0])
117
  pchain_out = policy_chain.invoke(
118
  {"policy": policy["policies"][0], "bullet_points": bullets}
119
  )
120
- pol_out.append(
121
- {
122
- "theme": policy["themes"][0],
123
- "policy": policy["policies"][0],
124
- "points": pchain_out,
125
- }
126
- )
127
- return pol_out
 
 
 
 
 
 
 
 
 
128
 
129
 
130
  def format_themes(policies):
131
- """Formats themes and policies into a markdown string.
132
-
133
- Args:
134
- policies (list): A list of policy outputs.
135
-
136
- Returns:
137
- str: A formatted markdown string of themes and policies.
138
- """
139
  themes = ""
140
  for theme, policies in pl.DataFrame(policies).group_by("theme"):
141
- themes += f"# {theme[0]}\n\n"
142
  for row in policies.iter_rows(named=True):
143
- themes += f"\n## {row['policy']}\n\n"
144
  themes += f"{row['points']}\n"
145
  themes += "\n"
146
  return themes
147
 
148
 
149
- def generate_final_summary(state: OverallState):
150
  logger.warning("Generating final summary")
151
  final_docs = filter_final_documents(state)
152
  logger.warning(f"Number of final docs: {len(final_docs)}")
153
 
154
  if len(final_docs) == state["n_docs"]:
155
- summaries = filter_summaries(final_docs, state)
156
- save_summaries_to_json(summaries)
157
 
158
- final_responses = batch_generate_executive_summaries(summaries)
159
- final_response = reduce_chain.invoke({"context": "\n\n".join(final_responses)})
160
 
161
- policy_groups = markdown_bullets(summaries)
162
- policy_outputs = generate_policy_output(policy_groups)
163
- themes = format_themes(policy_outputs)
164
 
165
  return {
166
- "final_summary": final_response,
167
  "documents": final_docs,
168
- "policies": themes,
 
169
  }
 
15
  logger = logging.getLogger(__name__)
16
 
17
 
18
+ def extract_policies_from_docs(docs):
19
+ policies = {"themes": [], "policies": [], "details": [], "stance": []}
20
+ for doc in docs:
21
+ if not doc["summary"].policies:
22
  continue
23
+ for policy in doc["summary"].policies:
24
  for theme, p in THEMES_AND_POLICIES.items():
25
  if policy.policy.name in p:
26
  policies["themes"].append(theme)
27
  policies["policies"].append(policy.policy.name)
28
+ policies["details"].append(
29
+ f"{policy.note} [{doc['document'].metadata['index']}]"
30
+ )
31
+ policies["stance"].append(
32
+ doc["document"].metadata["representations_support/object"]
33
+ )
34
  df = pl.DataFrame(policies)
35
+ grouped = df.group_by(["themes", "policies", "stance"]).agg(pl.col("details"))
 
 
 
 
 
 
 
36
  return grouped
37
 
38
 
 
40
  return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
41
 
42
 
43
def filter_docs(final_docs):
    """Drop unusable documents and prefix each summary with its document ID.

    A document is kept only when its summary is not "INVALID", it has at
    least one theme, and its iteration counter is not the 99 sentinel.
    Kept documents have their summary text mutated in place to start with
    ``Document ID: [<index>]``.

    Args:
        final_docs (list[dict]): Hallucination-filtered documents.

    Returns:
        list[dict]: The surviving documents (same objects, summaries tagged).
    """
    kept = []
    for doc in final_docs:
        summary_obj = doc["summary"]
        # Guard clauses: same tests as before, inverted via De Morgan.
        if (
            summary_obj.summary == "INVALID"
            or doc["themes"] == set()
            or doc["iteration"] == 99
        ):
            continue
        doc_id = doc["document"].metadata["index"]
        summary_obj.summary = f"Document ID: [{doc_id}]\n\n{summary_obj.summary}"
        kept.append(doc)
    return kept
56
 
57
 
58
+ def save_summaries_to_json(docs):
59
  """Saves summaries to JSON files.
60
 
61
  Args:
 
64
  out = [
65
  {
66
  "document": doc["document"].model_dump()["page_content"],
67
+ **doc["document"].metadata,
68
  "filename": doc["filename"],
69
  "entities": doc["entities"],
 
70
  "themes": list(doc["themes"]),
71
  "summary": doc["summary"].model_dump()["summary"],
72
  "policies": [
73
  {"policy": policy["policy"].name, "note": policy["note"]}
74
+ for policy in (doc["summary"].model_dump().get("policies", []) or [])
75
  ],
76
  "iteration": doc["iteration"],
77
  "hallucination": doc["hallucination"].model_dump(),
78
  }
79
+ for doc in docs
80
  ]
81
  for doc in out:
82
  filename = Path(str(doc["filename"])).stem
 
93
  Returns:
94
  list: A list of final responses.
95
  """
96
+ summaries_text = [
97
+ f"Document ID: {[s['document'].metadata['index']]} {s['summary'].summary}"
98
+ for s in summaries
99
+ ]
100
  final_responses = []
101
  batch_size = 50
102
  for i in range(0, len(summaries_text), batch_size):
103
+ logger.warning(
104
+ f"Processing batches... {i/50}/{len(summaries_text)//batch_size}"
105
+ )
106
  batch = summaries_text[i : i + batch_size]
107
  response = reduce_chain.invoke({"context": batch})
108
  final_responses.append(response)
 
110
 
111
 
112
def generate_policy_output(policy_groups):
    """Generate LLM-written point summaries for each (theme, policy) group.

    Args:
        policy_groups (pl.DataFrame): Grouped policies as produced by
            ``extract_policies_from_docs`` — rows with ``themes``,
            ``policies``, ``stance`` and an aggregated ``details`` list.

    Returns:
        tuple[list, list]: ``(policies_support, policies_object)``; each item
        is a dict with ``theme``, ``policy`` and the chain's ``points``.
    """
    policies_support = []
    policies_object = []
    for _, policy in policy_groups.group_by(["themes", "policies"]):
        logger.warning("Processing policies.")
        # Bug fix: the separator must put "* " at the START of each following
        # line. The previous "* \n" separator left the marker dangling at the
        # end of the prior line, malforming every bullet after the first.
        bullets = "* " + "\n* ".join(policy["details"][0])
        pchain_out = policy_chain.invoke(
            {"policy": policy["policies"][0], "bullet_points": bullets}
        )
        entry = {
            "theme": policy["themes"][0],
            "policy": policy["policies"][0],
            "points": pchain_out,
        }
        # NOTE(review): this group_by ignores stance while the upstream
        # grouping includes it, so a policy with both "Support" and "Object"
        # rows is routed by whichever row happens to come first — confirm
        # this is intended.
        if policy["stance"][0] == "Support":
            policies_support.append(entry)
        else:
            policies_object.append(entry)
    return policies_support, policies_object
138
 
139
 
140
def format_themes(policies):
    """Format policy outputs into a markdown string grouped by theme.

    Args:
        policies (list): Policy output dicts with ``theme``, ``policy`` and
            ``points`` keys, as produced by ``generate_policy_output``.

    Returns:
        str: Markdown with a ``###`` heading per theme and a ``####`` heading
        per policy; empty string when there are no policies.
    """
    # Robustness: pl.DataFrame([]) has no "theme" column, so group_by would
    # raise on an empty stance bucket — return an empty section instead.
    if not policies:
        return ""
    themes = ""
    # Loop variable renamed from "policies" to stop shadowing the parameter.
    for theme, theme_rows in pl.DataFrame(policies).group_by("theme"):
        themes += f"### {theme[0]}\n\n"
        for row in theme_rows.iter_rows(named=True):
            themes += f"\n#### {row['policy']}\n\n"
            themes += f"{row['points']}\n"
        themes += "\n"
    return themes
149
 
150
 
151
def generate_final_report(state: OverallState):
    """Build the final report once every document has been processed.

    Filters hallucination-free documents, saves their summaries to JSON,
    derives stance-split policy groupings, and reduces batched executive
    summaries into a single executive summary.

    Args:
        state (OverallState): Graph state with ``documents`` and ``n_docs``.

    Returns:
        dict | None: Payload with ``executive``, ``documents``,
        ``policies_support`` and ``policies_object``; ``None`` while not all
        documents have passed the hallucination check yet.
    """
    logger.warning("Generating final summary")
    final_docs = filter_final_documents(state)
    logger.warning(f"Number of final docs: {len(final_docs)}")

    # Guard clause: wait until every document has cleared the filter
    # (previously an enclosing `if`; implicit None return preserved).
    if len(final_docs) != state["n_docs"]:
        return None

    docs = filter_docs(final_docs)
    save_summaries_to_json(docs)

    policy_groups = extract_policies_from_docs(docs)
    policies_support, policies_object = generate_policy_output(policy_groups)

    batch_executive = batch_generate_executive_summaries(docs)
    executive = reduce_chain.invoke({"context": "\n\n".join(batch_executive)})

    return {
        "executive": executive,
        "documents": final_docs,
        "policies_support": format_themes(policies_support),
        "policies_object": format_themes(policies_object),
    }