Spaces:
Build error
Build error
test iterate chain
Browse files
planning_ai/chains/iterate_chain.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
from typing import List, Literal, TypedDict

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph

from planning_ai.common.utils import Paths
from planning_ai.llms.llm import LLM
|
11 |
+
|
12 |
+
# Initial summary: a one-shot chain that condenses a single document.
# The prompt expects a "context" variable carrying the document text.
summarize_prompt = ChatPromptTemplate(
    [
        ("human", "Write a concise summary of the following: {context}"),
    ]
)
# prompt -> LLM -> plain string output
initial_summary_chain = summarize_prompt | LLM | StrOutputParser()
|
19 |
+
|
20 |
+
# Refinement prompt: takes the summary produced so far ("existing_answer")
# plus one new document ("context") and asks the model to fold the new
# material into an updated summary.
refine_template = """
Produce a final summary.

Existing summary up to this point:
{existing_answer}

New context:
------------
{context}
------------

Given the new context, refine the original summary.
"""
refine_prompt = ChatPromptTemplate([("human", refine_template)])

# prompt -> LLM -> plain string output
refine_summary_chain = refine_prompt | LLM | StrOutputParser()
|
36 |
+
|
37 |
+
|
38 |
+
# We will define the state of the graph to hold the document
# contents and summary. We also include an index to keep track
# of our position in the sequence of documents.
class State(TypedDict):
    contents: List[str]  # raw text of every document to be summarised
    index: int  # position of the NEXT document to fold into the summary
    summary: str  # running summary, refined as documents are consumed
|
45 |
+
|
46 |
+
|
47 |
+
# We define functions for each node, including a node that generates
# the initial summary:
async def generate_initial_summary(state: State, config: RunnableConfig):
    """Summarise the first document and set the cursor to the second one.

    Returns a partial state update: the fresh ``summary`` plus ``index`` 1,
    so the next refinement step starts at ``contents[1]``.
    """
    first_document = state["contents"][0]
    summary = await initial_summary_chain.ainvoke(first_document, config)
    return {"summary": summary, "index": 1}
|
55 |
+
|
56 |
+
|
57 |
+
# And a node that refines the summary based on the next document
async def refine_summary(state: State, config: RunnableConfig):
    """Fold the document at the current index into the running summary.

    Returns a partial state update with the refined ``summary`` and the
    cursor advanced by one.
    """
    next_document = state["contents"][state["index"]]
    refine_inputs = {
        "existing_answer": state["summary"],
        "context": next_document,
    }
    new_summary = await refine_summary_chain.ainvoke(refine_inputs, config)
    return {"summary": new_summary, "index": state["index"] + 1}
|
66 |
+
|
67 |
+
|
68 |
+
# Here we implement logic to either exit the application or refine
# the summary.
def should_refine(state: State) -> Literal["refine_summary", END]:
    """Route back to the refine node while unprocessed documents remain."""
    documents_remain = state["index"] < len(state["contents"])
    return "refine_summary" if documents_remain else END
|
75 |
+
|
76 |
+
|
77 |
+
# Wire the graph: START -> initial summary, then loop through
# refine_summary until should_refine routes to END.
graph = StateGraph(State)
graph.add_node("generate_initial_summary", generate_initial_summary)
graph.add_node("refine_summary", refine_summary)

graph.add_edge(START, "generate_initial_summary")
# Both nodes share the same router: it either continues refining or ends.
graph.add_conditional_edges("generate_initial_summary", should_refine)
graph.add_conditional_edges("refine_summary", should_refine)
app = graph.compile()
|
85 |
+
|
86 |
+
# Load the staged text documents to summarise.
loader = DirectoryLoader(
    path=str(Paths.STAGING),
    show_progress=True,
    use_multithreading=True,
    loader_cls=TextLoader,
    recursive=True,
)
# Keep at most the first 20 loaded documents, dropping any with empty
# content (an empty first document would break generate_initial_summary).
# NOTE(review): loader.load() reads the whole directory before the [:20]
# slice is applied — presumably acceptable for this staging set; confirm.
docs = [doc for doc in loader.load()[:20] if doc.page_content]
|
94 |
+
|
95 |
+
# Drive the graph over the loaded documents, printing the summary as it
# is refined. A bare `async for` at module scope is a SyntaxError — the
# likely cause of the build error — so the loop must live inside a
# coroutine driven by asyncio.run().
async def _main() -> None:
    """Stream graph state updates and print each intermediate summary.

    ``stream_mode="values"`` yields the full state after every step; the
    very first yielded state (the input echo) has no "summary" key yet,
    hence the guard before printing.
    """
    async for step in app.astream(
        {"contents": [doc.page_content for doc in docs]},
        stream_mode="values",
    ):
        if summary := step.get("summary"):
            print(summary)


if __name__ == "__main__":
    asyncio.run(_main())
|