Spaces:
Build error
Build error
add themes and policies processing
Browse files- planning_ai/nodes/reduce_node.py +67 -65
planning_ai/nodes/reduce_node.py
CHANGED
@@ -15,26 +15,24 @@ logging.basicConfig(
|
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
|
18 |
-
def
|
19 |
-
policies = {"themes": [], "policies": [], "details": []}
|
20 |
-
for
|
21 |
-
if not
|
22 |
continue
|
23 |
-
for policy in
|
24 |
for theme, p in THEMES_AND_POLICIES.items():
|
25 |
if policy.policy.name in p:
|
26 |
policies["themes"].append(theme)
|
27 |
policies["policies"].append(policy.policy.name)
|
28 |
-
policies["details"].append(
|
|
|
|
|
|
|
|
|
|
|
29 |
df = pl.DataFrame(policies)
|
30 |
-
|
31 |
-
grouped = df.group_by(["themes", "policies"]).agg(pl.col("details"))
|
32 |
-
return grouped
|
33 |
-
|
34 |
-
|
35 |
-
def markdown_bullets(summaries):
|
36 |
-
policies = extract_policies_from_summaries(summaries)
|
37 |
-
grouped = policies.group_by(["themes", "policies"]).agg(pl.col("details"))
|
38 |
return grouped
|
39 |
|
40 |
|
@@ -42,17 +40,22 @@ def filter_final_documents(state: OverallState):
|
|
42 |
return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
|
43 |
|
44 |
|
45 |
-
def
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
-
def save_summaries_to_json(
|
56 |
"""Saves summaries to JSON files.
|
57 |
|
58 |
Args:
|
@@ -61,19 +64,19 @@ def save_summaries_to_json(summaries):
|
|
61 |
out = [
|
62 |
{
|
63 |
"document": doc["document"].model_dump()["page_content"],
|
|
|
64 |
"filename": doc["filename"],
|
65 |
"entities": doc["entities"],
|
66 |
-
"theme_docs": [d.model_dump() for d in doc["theme_docs"]],
|
67 |
"themes": list(doc["themes"]),
|
68 |
"summary": doc["summary"].model_dump()["summary"],
|
69 |
"policies": [
|
70 |
{"policy": policy["policy"].name, "note": policy["note"]}
|
71 |
-
for policy in doc["summary"].model_dump().get("policies", [])
|
72 |
],
|
73 |
"iteration": doc["iteration"],
|
74 |
"hallucination": doc["hallucination"].model_dump(),
|
75 |
}
|
76 |
-
for doc in
|
77 |
]
|
78 |
for doc in out:
|
79 |
filename = Path(str(doc["filename"])).stem
|
@@ -90,11 +93,16 @@ def batch_generate_executive_summaries(summaries):
|
|
90 |
Returns:
|
91 |
list: A list of final responses.
|
92 |
"""
|
93 |
-
summaries_text = [
|
|
|
|
|
|
|
94 |
final_responses = []
|
95 |
batch_size = 50
|
96 |
for i in range(0, len(summaries_text), batch_size):
|
97 |
-
logger.warning(
|
|
|
|
|
98 |
batch = summaries_text[i : i + batch_size]
|
99 |
response = reduce_chain.invoke({"context": batch})
|
100 |
final_responses.append(response)
|
@@ -102,68 +110,62 @@ def batch_generate_executive_summaries(summaries):
|
|
102 |
|
103 |
|
104 |
def generate_policy_output(policy_groups):
|
105 |
-
|
106 |
-
|
107 |
-
Args:
|
108 |
-
pols (pl.DataFrame): A DataFrame with grouped policies.
|
109 |
-
|
110 |
-
Returns:
|
111 |
-
list: A list of policy outputs.
|
112 |
-
"""
|
113 |
-
pol_out = []
|
114 |
for _, policy in policy_groups.group_by(["themes", "policies"]):
|
115 |
logger.warning("Processing policies.")
|
116 |
bullets = "* " + "* \n".join(policy["details"][0])
|
117 |
pchain_out = policy_chain.invoke(
|
118 |
{"policy": policy["policies"][0], "bullet_points": bullets}
|
119 |
)
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
|
130 |
def format_themes(policies):
|
131 |
-
"""Formats themes and policies into a markdown string.
|
132 |
-
|
133 |
-
Args:
|
134 |
-
policies (list): A list of policy outputs.
|
135 |
-
|
136 |
-
Returns:
|
137 |
-
str: A formatted markdown string of themes and policies.
|
138 |
-
"""
|
139 |
themes = ""
|
140 |
for theme, policies in pl.DataFrame(policies).group_by("theme"):
|
141 |
-
themes += f"
|
142 |
for row in policies.iter_rows(named=True):
|
143 |
-
themes += f"\n
|
144 |
themes += f"{row['points']}\n"
|
145 |
themes += "\n"
|
146 |
return themes
|
147 |
|
148 |
|
149 |
-
def
|
150 |
logger.warning("Generating final summary")
|
151 |
final_docs = filter_final_documents(state)
|
152 |
logger.warning(f"Number of final docs: {len(final_docs)}")
|
153 |
|
154 |
if len(final_docs) == state["n_docs"]:
|
155 |
-
|
156 |
-
save_summaries_to_json(
|
157 |
|
158 |
-
|
159 |
-
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
themes = format_themes(policy_outputs)
|
164 |
|
165 |
return {
|
166 |
-
"
|
167 |
"documents": final_docs,
|
168 |
-
"
|
|
|
169 |
}
|
|
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
|
18 |
+
def extract_policies_from_docs(docs):
    """Collect recognised policy mentions from ``docs`` into a grouped frame.

    Every policy listed on ``doc["summary"].policies`` is checked against the
    ``THEMES_AND_POLICIES`` mapping; one record is emitted for each theme whose
    entry contains the policy name.

    Args:
        docs: Iterable of document dicts providing ``"summary"`` (an object
            with a ``policies`` attribute) and ``"document"`` (an object whose
            ``metadata`` mapping holds ``"index"`` and
            ``"representations_support/object"``).

    Returns:
        pl.DataFrame: One row per (themes, policies, stance) combination, with
        the matching notes collected into a list-valued ``details`` column.
    """
    columns = {"themes": [], "policies": [], "details": [], "stance": []}
    for doc in docs:
        # A missing or empty policy list simply contributes no records.
        for mention in doc["summary"].policies or []:
            for theme, theme_policies in THEMES_AND_POLICIES.items():
                if mention.policy.name not in theme_policies:
                    continue
                columns["themes"].append(theme)
                columns["policies"].append(mention.policy.name)
                # Each note carries its source document index as a citation.
                columns["details"].append(
                    f"{mention.note} [{doc['document'].metadata['index']}]"
                )
                columns["stance"].append(
                    doc["document"].metadata["representations_support/object"]
                )
    frame = pl.DataFrame(columns)
    return frame.group_by(["themes", "policies", "stance"]).agg(pl.col("details"))
|
37 |
|
38 |
|
|
|
40 |
return [doc for doc in state["documents"] if doc["hallucination"].score == 1]
|
41 |
|
42 |
|
43 |
+
def filter_docs(final_docs):
    """Keep usable documents and tag each summary with its document ID.

    A document is kept when its summary text is not ``"INVALID"``, its
    ``"themes"`` entry is not an empty set, and its ``"iteration"`` is not 99.

    Each kept document's ``summary.summary`` is prefixed **in place** with
    ``"Document ID: [<index>]"`` followed by a blank line. The prefix is only
    added when it is not already present, so running this function again over
    the same documents does not stack duplicate headers.

    Args:
        final_docs: Iterable of document dicts with ``"summary"``,
            ``"themes"``, ``"iteration"`` and ``"document"`` entries.

    Returns:
        list: The documents that passed the filter, in their original order.
    """
    out_docs = []
    for doc in final_docs:
        summary = doc["summary"]
        if (
            summary.summary == "INVALID"
            or doc["themes"] == set()
            or doc["iteration"] == 99
        ):
            continue
        header = f"Document ID: [{doc['document'].metadata['index']}]\n\n"
        text = summary.summary
        # The summary object is shared with the caller's state, so guard
        # against prefixing it a second time on repeated invocations.
        if not (isinstance(text, str) and text.startswith(header)):
            summary.summary = f"{header}{text}"
        out_docs.append(doc)
    return out_docs
|
56 |
|
57 |
|
58 |
+
def save_summaries_to_json(docs):
|
59 |
"""Saves summaries to JSON files.
|
60 |
|
61 |
Args:
|
|
|
64 |
out = [
|
65 |
{
|
66 |
"document": doc["document"].model_dump()["page_content"],
|
67 |
+
**doc["document"].metadata,
|
68 |
"filename": doc["filename"],
|
69 |
"entities": doc["entities"],
|
|
|
70 |
"themes": list(doc["themes"]),
|
71 |
"summary": doc["summary"].model_dump()["summary"],
|
72 |
"policies": [
|
73 |
{"policy": policy["policy"].name, "note": policy["note"]}
|
74 |
+
for policy in (doc["summary"].model_dump().get("policies", []) or [])
|
75 |
],
|
76 |
"iteration": doc["iteration"],
|
77 |
"hallucination": doc["hallucination"].model_dump(),
|
78 |
}
|
79 |
+
for doc in docs
|
80 |
]
|
81 |
for doc in out:
|
82 |
filename = Path(str(doc["filename"])).stem
|
|
|
93 |
Returns:
|
94 |
list: A list of final responses.
|
95 |
"""
|
96 |
+
summaries_text = [
|
97 |
+
f"Document ID: {[s['document'].metadata['index']]} {s['summary'].summary}"
|
98 |
+
for s in summaries
|
99 |
+
]
|
100 |
final_responses = []
|
101 |
batch_size = 50
|
102 |
for i in range(0, len(summaries_text), batch_size):
|
103 |
+
logger.warning(
|
104 |
+
f"Processing batches... {i/50}/{len(summaries_text)//batch_size}"
|
105 |
+
)
|
106 |
batch = summaries_text[i : i + batch_size]
|
107 |
response = reduce_chain.invoke({"context": batch})
|
108 |
final_responses.append(response)
|
|
|
110 |
|
111 |
|
112 |
def generate_policy_output(policy_groups):
    """Summarise each grouped policy with ``policy_chain`` and split by stance.

    Args:
        policy_groups (pl.DataFrame): Frame with ``themes``, ``policies``,
            ``stance`` and list-valued ``details`` columns, as produced by
            ``extract_policies_from_docs`` (one row per theme/policy/stance).

    Returns:
        tuple[list[dict], list[dict]]: ``(policies_support, policies_object)``.
        Each entry has ``theme``, ``policy`` and ``points`` keys, where
        ``points`` is the ``policy_chain`` output. Rows whose stance is exactly
        ``"Support"`` go to the first list; every other row goes to the second.
    """
    policies_support = []
    policies_object = []
    # Walk the rows directly. Regrouping by (themes, policies) would merge the
    # support and object rows of a policy into one group, and reading only the
    # first row of that group silently drops the other stance.
    for row in policy_groups.iter_rows(named=True):
        logger.warning("Processing policies.")
        # One markdown bullet per line, with the marker at the start of each.
        bullets = "\n".join(f"* {detail}" for detail in row["details"])
        pchain_out = policy_chain.invoke(
            {"policy": row["policies"], "bullet_points": bullets}
        )
        entry = {
            "theme": row["themes"],
            "policy": row["policies"],
            "points": pchain_out,
        }
        if row["stance"] == "Support":
            policies_support.append(entry)
        else:
            policies_object.append(entry)
    return policies_support, policies_object
|
138 |
|
139 |
|
140 |
def format_themes(policies):
    """Format policy outputs as markdown grouped under theme headings.

    Args:
        policies (list): Policy output dicts with ``theme``, ``policy`` and
            ``points`` keys.

    Returns:
        str: Markdown with a ``###`` heading per theme and a ``####`` heading
        plus the policy's points for each policy under it. An empty input
        yields an empty string.
    """
    # An empty list would build a frame without a "theme" column and make
    # group_by raise. len() is used instead of truthiness so a DataFrame
    # argument is accepted as well.
    if len(policies) == 0:
        return ""

    frame = pl.DataFrame(policies)
    sections = []
    # maintain_order keeps themes in first-seen order so the report is stable
    # between runs.
    for theme_key, theme_rows in frame.group_by("theme", maintain_order=True):
        parts = [f"### {theme_key[0]}\n\n"]
        for row in theme_rows.iter_rows(named=True):
            parts.append(f"\n#### {row['policy']}\n\n")
            parts.append(f"{row['points']}\n")
        parts.append("\n")
        sections.append("".join(parts))
    return "".join(sections)
|
149 |
|
150 |
|
151 |
+
def generate_final_report(state: OverallState):
    """Build the final report once every expected document has been processed.

    The documents that pass ``filter_final_documents`` are counted against
    ``state["n_docs"]``. Only when the counts match are the documents filtered,
    saved to JSON, summarised per policy and reduced to an executive summary.

    NOTE(review): the source indentation was lost; the ``return`` is assumed to
    belong to the ``if`` branch, since the values it uses only exist there.

    Args:
        state (OverallState): Graph state providing ``"documents"`` and
            ``"n_docs"``.

    Returns:
        dict | None: A dict with ``executive``, ``documents``,
        ``policies_support`` and ``policies_object`` keys, or ``None`` while
        the number of final documents differs from ``state["n_docs"]``.
    """
    logger.warning("Generating final summary")
    final_docs = filter_final_documents(state)
    logger.warning(f"Number of final docs: {len(final_docs)}")

    # Not every document has been processed yet: nothing to report.
    if len(final_docs) != state["n_docs"]:
        return None

    usable_docs = filter_docs(final_docs)
    save_summaries_to_json(usable_docs)

    supporting, objecting = generate_policy_output(
        extract_policies_from_docs(usable_docs)
    )

    partial_summaries = batch_generate_executive_summaries(usable_docs)
    combined_context = "\n\n".join(partial_summaries)

    return {
        "executive": reduce_chain.invoke({"context": combined_context}),
        "documents": final_docs,
        "policies_support": format_themes(supporting),
        "policies_object": format_themes(objecting),
    }
|