cjber committed on
Commit
73fcee0
·
1 Parent(s): 92ba901

add policies to summary

Browse files
Files changed (1) hide show
  1. planning_ai/chains/map_chain.py +61 -29
planning_ai/chains/map_chain.py CHANGED
@@ -1,62 +1,94 @@
1
  from enum import Enum
2
- from typing import Literal
3
 
 
4
  from langchain_core.prompts import ChatPromptTemplate
 
5
  from pydantic import BaseModel, Field
6
 
7
  from planning_ai.common.utils import Paths
8
  from planning_ai.llms.llm import LLM
 
9
 
10
- with open(Paths.PROMPTS / "map.txt", "r") as f:
11
- map_template = f.read()
12
 
 
 
13
 
14
class Theme(str, Enum):
    """Closed set of planning themes a response can be tagged with.

    Subclasses ``str`` so members compare/serialize as their plain string
    values (useful when the structured-output schema is rendered for the LLM).
    """

    climate = "Climate change"
    biodiversity = "Biodiversity and green spaces"
    wellbeing = "Wellbeing and social inclusion"
    great_places = "Great places"
    jobs = "Jobs"
    homes = "Homes"
    infrastructure = "Infrastructure"

    def __repr__(self) -> str:
        # Show the human-readable label (e.g. "Jobs") instead of "Theme.jobs".
        return self.value
 
 
25
 
26
 
27
class Place(BaseModel):
    """A geographical location mentioned in a response, with a sentiment score."""

    # Name of the place as it appears in the response text.
    place: str = Field(..., description="Place mentioned in the response.")
    # Sentiment toward the place on a 1-10 scale (model-assigned).
    sentiment: int = Field(..., description="Related sentiment ranked 1 to 10.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
class BriefSummary(BaseModel):
    """A summary of the response with generated metadata"""

    # Free-text summary produced by the LLM.
    summary: str = Field(..., description="A summary of the response.")
    # Overall position of the respondent; constrained to four labels so the
    # structured-output schema rejects anything else.
    stance: Literal["SUPPORT", "OPPOSE", "MIXED", "NEUTRAL"] = Field(
        ...,
        description="Overall stance of the response. Either SUPPORT, OPPOSE, MIXED, or NEUTRAL.",
    )
    # Zero or more themes drawn from the Theme enum.
    themes: list[Theme] = Field(
        ..., description="A list of themes associated with the response."
    )
    # Every place referenced, each with its 1-10 sentiment score (see Place).
    places: list[Place] = Field(
        ...,
        description="All places mentioned in the response, with the positivity of the related sentiment ranked 1 to 10",
    )
    # Constructiveness rating on a 1-10 scale (model-assigned).
    rating: int = Field(
        ...,
        description="How constructive the response is, from a rating of 1 to 10.",
    )

    def __str__(self) -> str:
        # Human-readable rendering: the summary followed by a numbered
        # (1-based) list of the associated themes.
        return (f"Summary:\n\n{self.summary}\n\n" "Related Themes:\n\n") + "\n".join(
            [f"{idx+1}: {theme}" for (idx, theme) in enumerate(self.themes)]
        )
56
-
57
 
58
# Structured-output wrapper: constrains the LLM to emit a BriefSummary.
# strict=True enforces exact schema conformance.
SLLM = LLM.with_structured_output(BriefSummary, strict=True)

# The map step of the map-reduce summarisation: system prompt -> structured LLM.
map_prompt = ChatPromptTemplate.from_messages([("system", map_template)])
map_chain = map_prompt | SLLM
62
 
@@ -70,4 +102,4 @@ if __name__ == "__main__":
70
  """
71
 
72
  result = map_chain.invoke({"context": test_document})
73
- print(result)
 
1
  from enum import Enum
 
2
 
3
+ from langchain.output_parsers import RetryOutputParser
4
  from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.runnables import RunnableLambda
6
  from pydantic import BaseModel, Field
7
 
8
  from planning_ai.common.utils import Paths
9
  from planning_ai.llms.llm import LLM
10
+ from planning_ai.themes import PolicySelection, Theme
11
 
12
# Load the themes reference text so it can be prepended to the map prompt.
with open(Paths.PROMPTS / "themes.txt", "r") as f:
    themes_txt = f.read()

# Map prompt = themes text + map instructions.
# NOTE(review): there is a literal space after "\n\n" in the template below —
# presumably unintentional, but preserved; confirm before changing.
with open(Paths.PROMPTS / "map.txt", "r") as f:
    map_template = f"{themes_txt}\n\n {f.read()}"
17
 
 
 
 
 
 
 
 
 
18
 
19
class Sentiment(Enum):
    """Three-way sentiment category attached to each mentioned place."""

    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"
23
 
24
 
25
class Place(BaseModel):
    """Represents a geographical location mentioned in the response with associated sentiment."""

    # Name of the location as referenced in the response text.
    place: str = Field(
        ...,
        description=(
            "The name of the geographical location mentioned in the response. "
            "This can be a city, town, region, or any identifiable place."
        ),
    )
    # Categorical sentiment (positive/negative/neutral) — replaces the earlier
    # 1-10 integer scale.
    sentiment: Sentiment = Field(
        ...,
        description=(
            "The sentiment associated with the mentioned place, categorized as 'positive', 'negative', or 'neutral'. "
            "Assess sentiment based on the context in which the place is mentioned, considering both positive and negative connotations."
        ),
    )
42
 
43
 
44
class BriefSummary(BaseModel):
    """A summary of the response with generated metadata"""

    # Free-text summary produced by the LLM.
    summary: str = Field(
        ...,
        description=(
            "A concise summary of the response, capturing the main points and overall sentiment. "
            "The summary should reflect the key arguments and conclusions presented in the response."
        ),
    )
    # Themes drawn from the shared Theme enum (planning_ai.themes).
    themes: list[Theme] = Field(
        ...,
        description=(
            "A list of themes associated with the response. Themes are overarching topics or "
            "categories that the response addresses, such as 'Climate change' or 'Infrastructure'. "
            "Identify themes based on the content and context of the response."
        ),
    )
    # Policies (planning_ai.themes.PolicySelection) relevant to the response,
    # each with supporting bullet points.
    policies: list[PolicySelection] = Field(
        ...,
        description=(
            "A list of policies associated with the response, each accompanied by directly related "
            "information as bullet points. Bullet points should provide specific details or examples "
            "that illustrate how the policy is relevant to the response."
        ),
    )
    # Every place referenced, each with categorical sentiment (see Place).
    places: list[Place] = Field(
        ...,
        description=(
            "All places mentioned in the response, with the sentiment categorized as 'positive', 'negative', or 'neutral'. "
            "A place can be a city, region, or any geographical location. Assess sentiment based on the context "
            "in which the place is mentioned, considering both positive and negative connotations."
        ),
    )
    # Boolean flag — replaces the earlier 1-10 "rating" field.
    is_constructive: bool = Field(
        ...,
        description=(
            "A flag indicating whether the response is constructive. A response is considered constructive if it "
            "provides actionable suggestions or feedback, addresses specific themes or policies, and is presented "
            "in a coherent and logical manner."
        ),
    )
86
 
 
 
 
 
 
87
 
88
# Structured-output wrapper: constrains the LLM to emit a BriefSummary.
# NOTE(review): strict=False relaxes exact schema enforcement — presumably
# because the larger schema (policies/themes) trips strict mode; confirm.
SLLM = LLM.with_structured_output(BriefSummary, strict=False)

# TODO: Split out the policy stuff from this class. Find policies later based on
# what themes are already identified (should improve accuracy)
map_prompt = ChatPromptTemplate.from_messages([("system", map_template)])
map_chain = map_prompt | SLLM
94
 
 
102
  """
103
 
104
  result = map_chain.invoke({"context": test_document})
105
+ __import__("pprint").pprint(dict(result))