cjber commited on
Commit
ceee7f6
·
1 Parent(s): 4800b83

test iterate chain

Browse files
Files changed (1) hide show
  1. planning_ai/chains/iterate_chain.py +100 -0
planning_ai/chains/iterate_chain.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Literal, TypedDict
2
+
3
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph.graph import END, START, StateGraph
8
+
9
+ from planning_ai.common.utils import Paths
10
+ from planning_ai.llms.llm import LLM
11
+
12
+ # Initial summary
13
+ summarize_prompt = ChatPromptTemplate(
14
+ [
15
+ ("human", "Write a concise summary of the following: {context}"),
16
+ ]
17
+ )
18
+ initial_summary_chain = summarize_prompt | LLM | StrOutputParser()
19
+
20
+ refine_template = """
21
+ Produce a final summary.
22
+
23
+ Existing summary up to this point:
24
+ {existing_answer}
25
+
26
+ New context:
27
+ ------------
28
+ {context}
29
+ ------------
30
+
31
+ Given the new context, refine the original summary.
32
+ """
33
+ refine_prompt = ChatPromptTemplate([("human", refine_template)])
34
+
35
+ refine_summary_chain = refine_prompt | LLM | StrOutputParser()
36
+
37
+
38
+ # We will define the state of the graph to hold the document
39
+ # contents and summary. We also include an index to keep track
40
+ # of our position in the sequence of documents.
41
+ class State(TypedDict):
42
+ contents: List[str]
43
+ index: int
44
+ summary: str
45
+
46
+
47
+ # We define functions for each node, including a node that generates
48
+ # the initial summary:
49
+ async def generate_initial_summary(state: State, config: RunnableConfig):
50
+ summary = await initial_summary_chain.ainvoke(
51
+ state["contents"][0],
52
+ config,
53
+ )
54
+ return {"summary": summary, "index": 1}
55
+
56
+
57
+ # And a node that refines the summary based on the next document
58
+ async def refine_summary(state: State, config: RunnableConfig):
59
+ content = state["contents"][state["index"]]
60
+ summary = await refine_summary_chain.ainvoke(
61
+ {"existing_answer": state["summary"], "context": content},
62
+ config,
63
+ )
64
+
65
+ return {"summary": summary, "index": state["index"] + 1}
66
+
67
+
68
+ # Here we implement logic to either exit the application or refine
69
+ # the summary.
70
+ def should_refine(state: State) -> Literal["refine_summary", END]:
71
+ if state["index"] >= len(state["contents"]):
72
+ return END
73
+ else:
74
+ return "refine_summary"
75
+
76
+
77
+ graph = StateGraph(State)
78
+ graph.add_node("generate_initial_summary", generate_initial_summary)
79
+ graph.add_node("refine_summary", refine_summary)
80
+
81
+ graph.add_edge(START, "generate_initial_summary")
82
+ graph.add_conditional_edges("generate_initial_summary", should_refine)
83
+ graph.add_conditional_edges("refine_summary", should_refine)
84
+ app = graph.compile()
85
+
86
+ loader = DirectoryLoader(
87
+ path=str(Paths.STAGING),
88
+ show_progress=True,
89
+ use_multithreading=True,
90
+ loader_cls=TextLoader,
91
+ recursive=True,
92
+ )
93
+ docs = [doc for doc in loader.load()[:20] if doc.page_content]
94
+
95
+ async for step in app.astream(
96
+ {"contents": [doc.page_content for doc in docs]},
97
+ stream_mode="values",
98
+ ):
99
+ if summary := step.get("summary"):
100
+ print(summary)