Spaces:
Build error
Build error
add policies to summary
Browse files- planning_ai/chains/map_chain.py +61 -29
planning_ai/chains/map_chain.py
CHANGED
@@ -1,62 +1,94 @@
|
|
1 |
from enum import Enum
|
2 |
-
from typing import Literal
|
3 |
|
|
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
5 |
from pydantic import BaseModel, Field
|
6 |
|
7 |
from planning_ai.common.utils import Paths
|
8 |
from planning_ai.llms.llm import LLM
|
|
|
9 |
|
10 |
-
with open(Paths.PROMPTS / "
|
11 |
-
|
12 |
|
|
|
|
|
13 |
|
14 |
-
class Theme(str, Enum):
|
15 |
-
climate = "Climate change"
|
16 |
-
biodiversity = "Biodiversity and green spaces"
|
17 |
-
wellbeing = "Wellbeing and social inclusion"
|
18 |
-
great_places = "Great places"
|
19 |
-
jobs = "Jobs"
|
20 |
-
homes = "Homes"
|
21 |
-
infrastructure = "Infrastructure"
|
22 |
|
23 |
-
|
24 |
-
|
|
|
|
|
25 |
|
26 |
|
27 |
class Place(BaseModel):
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
class BriefSummary(BaseModel):
|
33 |
"""A summary of the response with generated metadata"""
|
34 |
|
35 |
-
summary: str = Field(
|
36 |
-
stance: Literal["SUPPORT", "OPPOSE", "MIXED", "NEUTRAL"] = Field(
|
37 |
...,
|
38 |
-
description=
|
|
|
|
|
|
|
39 |
)
|
40 |
themes: list[Theme] = Field(
|
41 |
-
...,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
)
|
43 |
places: list[Place] = Field(
|
44 |
...,
|
45 |
-
description=
|
|
|
|
|
|
|
|
|
46 |
)
|
47 |
-
|
48 |
...,
|
49 |
-
description=
|
|
|
|
|
|
|
|
|
50 |
)
|
51 |
|
52 |
-
def __str__(self) -> str:
|
53 |
-
return (f"Summary:\n\n{self.summary}\n\n" "Related Themes:\n\n") + "\n".join(
|
54 |
-
[f"{idx+1}: {theme}" for (idx, theme) in enumerate(self.themes)]
|
55 |
-
)
|
56 |
-
|
57 |
|
58 |
-
SLLM = LLM.with_structured_output(BriefSummary, strict=
|
59 |
|
|
|
|
|
60 |
map_prompt = ChatPromptTemplate.from_messages([("system", map_template)])
|
61 |
map_chain = map_prompt | SLLM
|
62 |
|
@@ -70,4 +102,4 @@ if __name__ == "__main__":
|
|
70 |
"""
|
71 |
|
72 |
result = map_chain.invoke({"context": test_document})
|
73 |
-
|
|
|
1 |
from enum import Enum
|
|
|
2 |
|
3 |
+
from langchain.output_parsers import RetryOutputParser
|
4 |
from langchain_core.prompts import ChatPromptTemplate
|
5 |
+
from langchain_core.runnables import RunnableLambda
|
6 |
from pydantic import BaseModel, Field
|
7 |
|
8 |
from planning_ai.common.utils import Paths
|
9 |
from planning_ai.llms.llm import LLM
|
10 |
+
from planning_ai.themes import PolicySelection, Theme
|
11 |
|
12 |
+
with open(Paths.PROMPTS / "themes.txt", "r") as f:
|
13 |
+
themes_txt = f.read()
|
14 |
|
15 |
+
with open(Paths.PROMPTS / "map.txt", "r") as f:
|
16 |
+
map_template = f"{themes_txt}\n\n {f.read()}"
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
class Sentiment(Enum):
|
20 |
+
POSITIVE = "positive"
|
21 |
+
NEGATIVE = "negative"
|
22 |
+
NEUTRAL = "neutral"
|
23 |
|
24 |
|
25 |
class Place(BaseModel):
|
26 |
+
"""Represents a geographical location mentioned in the response with associated sentiment."""
|
27 |
+
|
28 |
+
place: str = Field(
|
29 |
+
...,
|
30 |
+
description=(
|
31 |
+
"The name of the geographical location mentioned in the response. "
|
32 |
+
"This can be a city, town, region, or any identifiable place."
|
33 |
+
),
|
34 |
+
)
|
35 |
+
sentiment: Sentiment = Field(
|
36 |
+
...,
|
37 |
+
description=(
|
38 |
+
"The sentiment associated with the mentioned place, categorized as 'positive', 'negative', or 'neutral'. "
|
39 |
+
"Assess sentiment based on the context in which the place is mentioned, considering both positive and negative connotations."
|
40 |
+
),
|
41 |
+
)
|
42 |
|
43 |
|
44 |
class BriefSummary(BaseModel):
|
45 |
"""A summary of the response with generated metadata"""
|
46 |
|
47 |
+
summary: str = Field(
|
|
|
48 |
...,
|
49 |
+
description=(
|
50 |
+
"A concise summary of the response, capturing the main points and overall sentiment. "
|
51 |
+
"The summary should reflect the key arguments and conclusions presented in the response."
|
52 |
+
),
|
53 |
)
|
54 |
themes: list[Theme] = Field(
|
55 |
+
...,
|
56 |
+
description=(
|
57 |
+
"A list of themes associated with the response. Themes are overarching topics or "
|
58 |
+
"categories that the response addresses, such as 'Climate change' or 'Infrastructure'. "
|
59 |
+
"Identify themes based on the content and context of the response."
|
60 |
+
),
|
61 |
+
)
|
62 |
+
policies: list[PolicySelection] = Field(
|
63 |
+
...,
|
64 |
+
description=(
|
65 |
+
"A list of policies associated with the response, each accompanied by directly related "
|
66 |
+
"information as bullet points. Bullet points should provide specific details or examples "
|
67 |
+
"that illustrate how the policy is relevant to the response."
|
68 |
+
),
|
69 |
)
|
70 |
places: list[Place] = Field(
|
71 |
...,
|
72 |
+
description=(
|
73 |
+
"All places mentioned in the response, with the sentiment categorized as 'positive', 'negative', or 'neutral'. "
|
74 |
+
"A place can be a city, region, or any geographical location. Assess sentiment based on the context "
|
75 |
+
"in which the place is mentioned, considering both positive and negative connotations."
|
76 |
+
),
|
77 |
)
|
78 |
+
is_constructive: bool = Field(
|
79 |
...,
|
80 |
+
description=(
|
81 |
+
"A flag indicating whether the response is constructive. A response is considered constructive if it "
|
82 |
+
"provides actionable suggestions or feedback, addresses specific themes or policies, and is presented "
|
83 |
+
"in a coherent and logical manner."
|
84 |
+
),
|
85 |
)
|
86 |
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
+
SLLM = LLM.with_structured_output(BriefSummary, strict=False)
|
89 |
|
90 |
+
# TODO: Split out the policy stuff from this class. Find policies later based on
|
91 |
+
# what themes are already identified (should improve accuracy)
|
92 |
map_prompt = ChatPromptTemplate.from_messages([("system", map_template)])
|
93 |
map_chain = map_prompt | SLLM
|
94 |
|
|
|
102 |
"""
|
103 |
|
104 |
result = map_chain.invoke({"context": test_document})
|
105 |
+
__import__("pprint").pprint(dict(result))
|