Spaces:
Build error
Build error
refactor: restructure graph creation to incorporate DocumentState and simplify node management
Browse files- planning_ai/graph.py +29 -30
planning_ai/graph.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from langgraph.constants import Send
|
2 |
from langgraph.graph import END, StateGraph
|
3 |
|
4 |
from planning_ai.nodes.hallucination_node import (
|
@@ -8,45 +8,44 @@ from planning_ai.nodes.hallucination_node import (
|
|
8 |
map_hallucinations,
|
9 |
)
|
10 |
from planning_ai.nodes.map_node import (
|
11 |
-
all_summaries,
|
12 |
collect_summaries,
|
13 |
generate_summary,
|
14 |
map_summaries,
|
15 |
)
|
16 |
from planning_ai.nodes.reduce_node import generate_final_summary
|
17 |
-
from planning_ai.states import OverallState
|
18 |
|
19 |
|
20 |
def create_graph():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
graph = StateGraph(OverallState)
|
22 |
-
graph.add_node("
|
23 |
-
graph.add_node("
|
24 |
-
|
25 |
-
graph.
|
26 |
-
graph.
|
27 |
-
|
28 |
-
|
29 |
-
graph.
|
30 |
-
"generate_summary", map_hallucinations, ["check_hallucination"]
|
31 |
-
)
|
32 |
-
graph.add_conditional_edges(
|
33 |
-
"check_hallucination",
|
34 |
-
map_fix_hallucinations,
|
35 |
-
["fix_hallucination", "collect_summaries"],
|
36 |
-
)
|
37 |
-
graph.add_conditional_edges(
|
38 |
-
"fix_hallucination",
|
39 |
-
map_hallucinations,
|
40 |
-
["check_hallucination", "collect_summaries"],
|
41 |
-
)
|
42 |
-
graph.add_conditional_edges(
|
43 |
-
"collect_summaries",
|
44 |
-
all_summaries,
|
45 |
-
{True: "generate_final_summary", False: "collect_summaries"},
|
46 |
-
)
|
47 |
-
graph.add_edge("generate_final_summary", END)
|
48 |
|
49 |
return graph.compile()
|
50 |
|
51 |
|
52 |
-
|
|
|
|
1 |
+
from langgraph.constants import START, Send
|
2 |
from langgraph.graph import END, StateGraph
|
3 |
|
4 |
from planning_ai.nodes.hallucination_node import (
|
|
|
8 |
map_hallucinations,
|
9 |
)
|
10 |
from planning_ai.nodes.map_node import (
|
|
|
11 |
collect_summaries,
|
12 |
generate_summary,
|
13 |
map_summaries,
|
14 |
)
|
15 |
from planning_ai.nodes.reduce_node import generate_final_summary
|
16 |
+
from planning_ai.states import DocumentState, OverallState
|
17 |
|
18 |
|
19 |
def create_graph():
|
20 |
+
subgraph = StateGraph(DocumentState)
|
21 |
+
subgraph.add_node("generate_summary", generate_summary)
|
22 |
+
# subgraph.add_node("check_hallucination", check_hallucination)
|
23 |
+
# subgraph.add_node("fix_hallucination", fix_hallucination)
|
24 |
+
|
25 |
+
subgraph.add_conditional_edges(START, map_summaries, ["generate_summary"])
|
26 |
+
# subgraph.add_conditional_edges(
|
27 |
+
# "generate_summary", map_hallucinations, ["check_hallucination"]
|
28 |
+
# )
|
29 |
+
# subgraph.add_conditional_edges(
|
30 |
+
# "check_hallucination", map_fix_hallucinations, ["fix_hallucination"]
|
31 |
+
# )
|
32 |
+
# subgraph.add_conditional_edges(
|
33 |
+
# "fix_hallucination", map_hallucinations, ["check_hallucination"]
|
34 |
+
# )
|
35 |
+
subgraph = subgraph.compile()
|
36 |
+
|
37 |
graph = StateGraph(OverallState)
|
38 |
+
graph.add_node("summary_graph", subgraph)
|
39 |
+
# graph.add_node("collect_summaries", collect_summaries)
|
40 |
+
|
41 |
+
graph.add_edge(START, "summary_graph")
|
42 |
+
# graph.add_conditional_edges(
|
43 |
+
# "summary_graph", map_hallucinations, ["collect_summaries"]
|
44 |
+
# )
|
45 |
+
# graph.add_edge("generate_final_summary", END)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
return graph.compile()
|
48 |
|
49 |
|
50 |
+
# graph = create_graph()
|
51 |
+
# graph.get_graph().draw_png("test.png")
|