cjber committed on
Commit
3a57990
·
1 Parent(s): bed065c

feat: add scores for theme selection to allow filtering

Browse files
planning_ai/chains/map_chain.py CHANGED
@@ -2,7 +2,7 @@ from enum import Enum, auto
2
  from typing import Optional, Set, Type
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
- from pydantic import BaseModel, create_model
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
@@ -13,13 +13,13 @@ with open(Paths.PROMPTS / "map.txt", "r") as f:
13
 
14
 
15
  def create_policy_enum(
16
- policy_groups: Set[str], name: str = "DynamicPolicyEnum"
17
  ) -> Enum:
18
  """
19
  Create a dynamic enum for policies based on the given policy groups.
20
 
21
  Args:
22
- policy_groups (Set[str]): A set of policy group names.
23
  name (str): Name of the enum to be created.
24
 
25
  Returns:
@@ -39,29 +39,24 @@ def create_brief_summary_model(policy_enum: Enum) -> Type[BaseModel]:
39
  Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
40
  """
41
 
42
- # NOTE: For some reason GPT4o doesn't work if we use too much structure
43
- DynamicPolicy = create_model(
44
- "DynamicPolicy",
45
- # policy=(policy_enum, ...),
46
- policy=(str, ...),
47
- note=(str, ...),
48
- __config__={"extra": "forbid"},
49
- )
50
 
51
  return create_model(
52
  "DynamicBriefSummary",
53
  summary=(str, ...),
54
- policies=(Optional[list[DynamicPolicy]], ...),
55
  __module__=__name__,
56
  __config__={"extra": "forbid"},
57
  )
58
 
59
 
60
  def create_dynamic_map_chain(themes, prompt: str):
61
- policy_groups = set()
62
  for theme in themes:
63
  if theme in THEMES_AND_POLICIES:
64
- policy_groups.update(THEMES_AND_POLICIES[theme])
65
 
66
  PolicyEnum = create_policy_enum(policy_groups)
67
  DynamicBriefSummary = create_brief_summary_model(PolicyEnum)
 
2
  from typing import Optional, Set, Type
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
+ from pydantic import BaseModel, Field, create_model
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
 
13
 
14
 
15
  def create_policy_enum(
16
+ policy_groups: list[str], name: str = "DynamicPolicyEnum"
17
  ) -> Enum:
18
  """
19
  Create a dynamic enum for policies based on the given policy groups.
20
 
21
  Args:
22
+ policy_groups (list[str]): A list of policy group names.
23
  name (str): Name of the enum to be created.
24
 
25
  Returns:
 
39
  Type[BaseModel]: A dynamically generated Pydantic model for BriefSummary.
40
  """
41
 
42
+ class Policy(BaseModel):
43
+ policy: policy_enum
44
+ note: str
 
 
 
 
 
45
 
46
  return create_model(
47
  "DynamicBriefSummary",
48
  summary=(str, ...),
49
+ policies=(list[Policy], ...),
50
  __module__=__name__,
51
  __config__={"extra": "forbid"},
52
  )
53
 
54
 
55
  def create_dynamic_map_chain(themes, prompt: str):
56
+ policy_groups = []
57
  for theme in themes:
58
  if theme in THEMES_AND_POLICIES:
59
+ policy_groups.extend(THEMES_AND_POLICIES[theme])
60
 
61
  PolicyEnum = create_policy_enum(policy_groups)
62
  DynamicBriefSummary = create_brief_summary_model(PolicyEnum)
planning_ai/chains/themes_chain.py CHANGED
@@ -2,13 +2,13 @@ from enum import Enum
2
  from typing import Optional
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
- from pydantic import BaseModel
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
9
 
10
 
11
- class Themes(Enum):
12
  climate_change = "Climate Change"
13
  biodiversity = "Biodiversity and Green Spaces"
14
  wellbeing = "Wellbeing and Social Inclusion"
@@ -18,8 +18,13 @@ class Themes(Enum):
18
  infrastructure = "Infrastructure"
19
 
20
 
 
 
 
 
 
21
  class ThemeSelector(BaseModel):
22
- themes: Optional[list[Themes]]
23
 
24
 
25
  with open(Paths.PROMPTS / "themes.txt", "r") as f:
 
2
  from typing import Optional
3
 
4
  from langchain_core.prompts import ChatPromptTemplate
5
+ from pydantic import BaseModel, Field
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
9
 
10
 
11
+ class Theme(Enum):
12
  climate_change = "Climate Change"
13
  biodiversity = "Biodiversity and Green Spaces"
14
  wellbeing = "Wellbeing and Social Inclusion"
 
18
  infrastructure = "Infrastructure"
19
 
20
 
21
+ class ThemeScore(BaseModel):
22
+ theme: Theme
23
+ score: int
24
+
25
+
26
  class ThemeSelector(BaseModel):
27
+ themes: Optional[list[ThemeScore]]
28
 
29
 
30
  with open(Paths.PROMPTS / "themes.txt", "r") as f:
planning_ai/nodes/hallucination_node.py CHANGED
@@ -83,7 +83,8 @@ def fix_hallucination(state: DocumentState):
83
  hallucinations.
84
  """
85
  logger.warning(f"Fixing hallucinations for document {state['filename']}")
86
- fix_chain = create_dynamic_map_chain(state["themes"], fix_template)
 
87
  try:
88
  response = fix_chain.invoke(
89
  {
 
83
  hallucinations.
84
  """
85
  logger.warning(f"Fixing hallucinations for document {state['filename']}")
86
+ themes = [theme["theme"].value for theme in state["themes"]]
87
+ fix_chain = create_dynamic_map_chain(themes, fix_template)
88
  try:
89
  response = fix_chain.invoke(
90
  {
planning_ai/nodes/map_node.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import spacy
2
  from langgraph.types import Send
3
  from presidio_analyzer import AnalyzerEngine
@@ -14,19 +15,24 @@ anonymizer = AnonymizerEngine()
14
  nlp = spacy.load("en_core_web_lg")
15
 
16
 
17
-
18
  def retrieve_themes(state: DocumentState) -> DocumentState:
19
  try:
20
  result = themes_chain.invoke({"document": state["document"].page_content})
21
  if not result.themes:
22
- state["themes"] = set()
23
  return state
24
- themes = [theme.value for theme in result.themes]
25
  except Exception as e:
26
  logger.error(f"Theme selection error: {e}")
27
  themes = []
28
-
29
- state["themes"] = set(themes)
 
 
 
 
 
 
30
  return state
31
 
32
 
@@ -98,7 +104,8 @@ def generate_summary(state: DocumentState) -> dict:
98
  ]
99
  }
100
 
101
- map_chain = create_dynamic_map_chain(themes=state["themes"], prompt=map_template)
 
102
  try:
103
  response = map_chain.invoke({"context": state["document"].page_content})
104
  except Exception as e:
 
1
+ import numpy as np
2
  import spacy
3
  from langgraph.types import Send
4
  from presidio_analyzer import AnalyzerEngine
 
15
  nlp = spacy.load("en_core_web_lg")
16
 
17
 
 
18
  def retrieve_themes(state: DocumentState) -> DocumentState:
19
  try:
20
  result = themes_chain.invoke({"document": state["document"].page_content})
21
  if not result.themes:
22
+ state["themes"] = []
23
  return state
24
+ themes = [theme.model_dump() for theme in result.themes]
25
  except Exception as e:
26
  logger.error(f"Theme selection error: {e}")
27
  themes = []
28
+ state["themes"] = themes
29
+ state["themes"] = [d for d in state["themes"] if d["score"] > 2]
30
+ state["score"] = np.mean([theme["score"] for theme in state["themes"]])
31
+ if state["score"] < 3:
32
+ state["processed"] = True
33
+ state["failed"] = True
34
+
35
+ logger.info(f"Document {state['filename']} theme score: {state['score']}")
36
  return state
37
 
38
 
 
104
  ]
105
  }
106
 
107
+ themes = [theme["theme"].value for theme in state["themes"]]
108
+ map_chain = create_dynamic_map_chain(themes=themes, prompt=map_template)
109
  try:
110
  response = map_chain.invoke({"context": state["document"].page_content})
111
  except Exception as e:
planning_ai/nodes/reduce_node.py CHANGED
@@ -23,7 +23,7 @@ def save_summaries_to_json(docs):
23
  **doc["document"].metadata,
24
  "filename": doc["filename"],
25
  "entities": doc["entities"],
26
- "themes": list(doc["themes"]),
27
  "summary": doc["summary"].model_dump()["summary"],
28
  "policies": doc["policies"],
29
  "notes": doc["notes"],
@@ -47,10 +47,10 @@ def extract_policies_from_docs(docs):
47
  continue
48
  for policy in doc["summary"].policies:
49
  for theme, p in THEMES_AND_POLICIES.items():
50
- if policy.policy in p:
51
  policies["doc_id"].append(doc["doc_id"])
52
  policies["themes"].append(theme)
53
- policies["policies"].append(policy.policy)
54
  policies["details"].append(policy.note)
55
  policies["stance"].append(
56
  doc["document"].metadata["representations_support/object"]
@@ -131,7 +131,6 @@ def generate_final_report(state: OverallState):
131
 
132
  def final_output(final_docs):
133
  docs = [doc for doc in final_docs if not doc["failed"]]
134
-
135
  docs = add_doc_id(docs)
136
 
137
  policy_groups = extract_policies_from_docs(docs)
 
23
  **doc["document"].metadata,
24
  "filename": doc["filename"],
25
  "entities": doc["entities"],
26
+ "themes": doc["themes"].model_dump(),
27
  "summary": doc["summary"].model_dump()["summary"],
28
  "policies": doc["policies"],
29
  "notes": doc["notes"],
 
47
  continue
48
  for policy in doc["summary"].policies:
49
  for theme, p in THEMES_AND_POLICIES.items():
50
+ if policy.policy.name in p:
51
  policies["doc_id"].append(doc["doc_id"])
52
  policies["themes"].append(theme)
53
+ policies["policies"].append(policy.policy.name)
54
  policies["details"].append(policy.note)
55
  policies["stance"].append(
56
  doc["document"].metadata["representations_support/object"]
 
131
 
132
  def final_output(final_docs):
133
  docs = [doc for doc in final_docs if not doc["failed"]]
 
134
  docs = add_doc_id(docs)
135
 
136
  policy_groups = extract_policies_from_docs(docs)