feat: add scores for theme selection to allow filtering
planning_ai/chains/map_chain.py
CHANGED
@@ -2,7 +2,7 @@ from enum import Enum, auto
 from typing import Optional, Set, Type
 
 from langchain_core.prompts import ChatPromptTemplate
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 
 from planning_ai.common.utils import Paths
 from planning_ai.llms.llm import LLM
@@ -13,13 +13,13 @@ with open(Paths.PROMPTS / "map.txt", "r") as f:
 
 
 def create_policy_enum(
-    policy_groups: …
+    policy_groups: list[str], name: str = "DynamicPolicyEnum"
 ) -> Enum:
     """
     Create a dynamic enum for policies based on the given policy groups.
 
     Args:
-        policy_groups (…
+        policy_groups (list[str]): A set of policy group names.
         name (str): Name of the enum to be created.
 
     Returns:
@@ -39,29 +39,24 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
         Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
     """
 
-
-
-
-        # policy=(policy_enum, ...),
-        policy=(str, ...),
-        note=(str, ...),
-        __config__={"extra": "forbid"},
-    )
+    class Policy(BaseModel):
+        policy: policy_enum
+        note: str
 
     return create_model(
         "DynamicBriefSummary",
         summary=(str, ...),
-        policies=(…
+        policies=(list[Policy], ...),
         __module__=__name__,
         __config__={"extra": "forbid"},
     )
 
 
 def create_dynamic_map_chain(themes, prompt: str):
-    policy_groups = …
+    policy_groups = []
     for theme in themes:
         if theme in THEMES_AND_POLICIES:
-            policy_groups.…
+            policy_groups.extend(THEMES_AND_POLICIES[theme])
 
     PolicyEnum = create_policy_enum(policy_groups)
     DynamicBriefSummary = create_brief_summary_model(PolicyEnum)
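For context, a minimal sketch of the pattern the added lines rely on: a nested Policy model whose policy field is annotated with an enum class chosen at runtime, wrapped by a pydantic create_model call for the outer summary schema. ExamplePolicyEnum and build_brief_summary_model below are illustrative stand-ins, not the repository's code (the diff does not show how create_policy_enum constructs its enum):

from enum import Enum
from typing import Type

from pydantic import BaseModel, create_model


# Stand-in for the enum that create_policy_enum would build at runtime.
class ExamplePolicyEnum(Enum):
    climate_resilience = "Climate resilience"
    green_infrastructure = "Green infrastructure"


def build_brief_summary_model(policy_enum: Type[Enum]) -> Type[BaseModel]:
    # Mirrors the nested Policy model added in the diff above: its `policy`
    # field is constrained to members of the dynamically supplied enum.
    class Policy(BaseModel):
        policy: policy_enum
        note: str

    return create_model(
        "DynamicBriefSummary",
        summary=(str, ...),
        policies=(list[Policy], ...),
        __module__=__name__,
        __config__={"extra": "forbid"},
    )


DynamicBriefSummary = build_brief_summary_model(ExamplePolicyEnum)
summary = DynamicBriefSummary(
    summary="Supports stronger flood protection policies.",
    policies=[{"policy": "Climate resilience", "note": "Explicit support."}],
)
print(summary.policies[0].policy)  # -> ExamplePolicyEnum.climate_resilience

Because the enum is generated per request, the structured output can only name policies belonging to the themes selected for that document.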
planning_ai/chains/themes_chain.py
CHANGED
@@ -2,13 +2,13 @@ from enum import Enum
 from typing import Optional
 
 from langchain_core.prompts import ChatPromptTemplate
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from planning_ai.common.utils import Paths
 from planning_ai.llms.llm import LLM
 
 
-class Themes(Enum):
+class Theme(Enum):
     climate_change = "Climate Change"
     biodiversity = "Biodiversity and Green Spaces"
     wellbeing = "Wellbeing and Social Inclusion"
@@ -18,8 +18,13 @@ class Themes(Enum):
     infrastructure = "Infrastructure"
 
 
+class ThemeScore(BaseModel):
+    theme: Theme
+    score: int
+
+
 class ThemeSelector(BaseModel):
-    themes: Optional[list[…
+    themes: Optional[list[ThemeScore]]
 
 
 with open(Paths.PROMPTS / "themes.txt", "r") as f:
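To show the new schema in isolation (only two of the enum members are reproduced here, and the prompt/LLM wiring is omitted), each selected theme is now paired with an integer relevance score, which the map node later uses for filtering:

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class Theme(Enum):
    climate_change = "Climate Change"
    biodiversity = "Biodiversity and Green Spaces"


class ThemeScore(BaseModel):
    theme: Theme
    score: int  # relevance score assigned to each selected theme


class ThemeSelector(BaseModel):
    themes: Optional[list[ThemeScore]]


# Validating the kind of payload a structured-output call would return:
result = ThemeSelector.model_validate(
    {"themes": [{"theme": "Climate Change", "score": 4}]}
)
print(result.themes[0].theme, result.themes[0].score)  # Theme.climate_change 4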
planning_ai/nodes/hallucination_node.py
CHANGED
@@ -83,7 +83,8 @@ def fix_hallucination(state: DocumentState):
     hallucinations.
     """
     logger.warning(f"Fixing hallucinations for document {state['filename']}")
-
+    themes = [theme["theme"].value for theme in state["themes"]]
+    fix_chain = create_dynamic_map_chain(themes, fix_template)
    try:
         response = fix_chain.invoke(
             {
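Both this node and the map node now build their chain per document from the themes stored in state, instead of using a module-level chain. A small illustration of the theme-value extraction (the Theme enum here is a two-member stand-in for the one in themes_chain.py):

from enum import Enum


class Theme(Enum):
    climate_change = "Climate Change"
    biodiversity = "Biodiversity and Green Spaces"


# state["themes"] holds dicts produced by ThemeScore.model_dump(), so each
# "theme" entry is still an enum member and .value recovers its display string.
state = {"themes": [{"theme": Theme.climate_change, "score": 4},
                    {"theme": Theme.biodiversity, "score": 3}]}

themes = [theme["theme"].value for theme in state["themes"]]
print(themes)  # ['Climate Change', 'Biodiversity and Green Spaces']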
planning_ai/nodes/map_node.py
CHANGED
@@ -1,3 +1,4 @@
+import numpy as np
 import spacy
 from langgraph.types import Send
 from presidio_analyzer import AnalyzerEngine
@@ -14,19 +15,24 @@ anonymizer = AnonymizerEngine()
 nlp = spacy.load("en_core_web_lg")
 
 
-
 def retrieve_themes(state: DocumentState) -> DocumentState:
     try:
         result = themes_chain.invoke({"document": state["document"].page_content})
         if not result.themes:
-            state["themes"] = …
+            state["themes"] = []
             return state
-        themes = [theme.…
+        themes = [theme.model_dump() for theme in result.themes]
     except Exception as e:
         logger.error(f"Theme selection error: {e}")
         themes = []
-
-        state["themes"] = …
+    state["themes"] = themes
+    state["themes"] = [d for d in state["themes"] if d["score"] > 2]
+    state["score"] = np.mean([theme["score"] for theme in state["themes"]])
+    if state["score"] < 3:
+        state["processed"] = True
+        state["failed"] = True
+
+    logger.info(f"Document {state['filename']} theme score: {state['score']}")
     return state
 
 
@@ -98,7 +104,8 @@ def generate_summary(state: DocumentState) -> dict:
         ]
     }
 
-
+    themes = [theme["theme"].value for theme in state["themes"]]
+    map_chain = create_dynamic_map_chain(themes=themes, prompt=map_template)
     try:
         response = map_chain.invoke({"context": state["document"].page_content})
     except Exception as e:
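A self-contained sketch of the new scoring and filtering step in retrieve_themes; the state dict is a stand-in for DocumentState (theme values simplified to plain strings), and the thresholds mirror the diff above (keep themes scored above 2, fail documents whose mean score is below 3):

import numpy as np

# Hypothetical document state after theme selection.
state = {
    "filename": "example.pdf",
    "themes": [
        {"theme": "Climate Change", "score": 4},
        {"theme": "Infrastructure", "score": 1},  # dropped below: score <= 2
    ],
}

# Keep only themes scored above 2, then average the remaining scores.
state["themes"] = [d for d in state["themes"] if d["score"] > 2]
state["score"] = np.mean([d["score"] for d in state["themes"]])
if state["score"] < 3:
    # Low-relevance documents are marked processed/failed and skipped downstream.
    state["processed"] = True
    state["failed"] = True

print(state["score"], state.get("failed", False))  # 4.0 False

One edge case worth noting: if every theme is filtered out, np.mean over an empty list returns nan (with a RuntimeWarning), and nan < 3 evaluates to False, so such a document would not be flagged by this particular check.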
planning_ai/nodes/reduce_node.py
CHANGED
@@ -23,7 +23,7 @@ def save_summaries_to_json(docs):
             **doc["document"].metadata,
             "filename": doc["filename"],
             "entities": doc["entities"],
-            "themes": …
+            "themes": doc["themes"].model_dump(),
             "summary": doc["summary"].model_dump()["summary"],
             "policies": doc["policies"],
             "notes": doc["notes"],
@@ -47,10 +47,10 @@ def extract_policies_from_docs(docs):
             continue
         for policy in doc["summary"].policies:
             for theme, p in THEMES_AND_POLICIES.items():
-                if policy.policy in p:
+                if policy.policy.name in p:
                     policies["doc_id"].append(doc["doc_id"])
                     policies["themes"].append(theme)
-                    policies["policies"].append(policy.policy)
+                    policies["policies"].append(policy.policy.name)
                     policies["details"].append(policy.note)
                     policies["stance"].append(
                         doc["document"].metadata["representations_support/object"]
@@ -131,7 +131,6 @@ def generate_final_report(state: OverallState):
 
 def final_output(final_docs):
     docs = [doc for doc in final_docs if not doc["failed"]]
-
     docs = add_doc_id(docs)
 
     policy_groups = extract_policies_from_docs(docs)
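Because the dynamic schema stores each extracted policy as an enum member rather than a plain string, the lookup in extract_policies_from_docs now goes through .name. A tiny illustration, with a made-up enum and THEMES_AND_POLICIES mapping:

from enum import Enum

# Made-up stand-ins: one dynamically built policy enum and one theme-to-policies mapping.
PolicyEnum = Enum("DynamicPolicyEnum", {"climate_resilience": "Climate resilience"})
THEMES_AND_POLICIES = {"Climate Change": ["climate_resilience"]}

policy = PolicyEnum.climate_resilience
for theme, group in THEMES_AND_POLICIES.items():
    if policy.name in group:  # .name, because `policy` is an Enum member, not a str
        print(theme, policy.name)  # Climate Change climate_resilience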