cjber commited on
Commit
aa05cc8
·
1 Parent(s): 83e127b

refactor: restructure graph creation to incorporate DocumentState and simplify node management

Browse files
Files changed (1) hide show
  1. planning_ai/graph.py +29 -30
planning_ai/graph.py CHANGED
@@ -1,4 +1,4 @@
1
- from langgraph.constants import Send
2
  from langgraph.graph import END, StateGraph
3
 
4
  from planning_ai.nodes.hallucination_node import (
@@ -8,45 +8,44 @@ from planning_ai.nodes.hallucination_node import (
8
  map_hallucinations,
9
  )
10
  from planning_ai.nodes.map_node import (
11
- all_summaries,
12
  collect_summaries,
13
  generate_summary,
14
  map_summaries,
15
  )
16
  from planning_ai.nodes.reduce_node import generate_final_summary
17
- from planning_ai.states import OverallState
18
 
19
 
20
  def create_graph():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  graph = StateGraph(OverallState)
22
- graph.add_node("generate_summary", generate_summary)
23
- graph.add_node("check_hallucination", check_hallucination)
24
- graph.add_node("fix_hallucination", fix_hallucination)
25
- graph.add_node("collect_summaries", collect_summaries)
26
- graph.add_node("generate_final_summary", generate_final_summary)
27
-
28
- graph.set_conditional_entry_point(map_summaries, ["generate_summary"])
29
- graph.add_conditional_edges(
30
- "generate_summary", map_hallucinations, ["check_hallucination"]
31
- )
32
- graph.add_conditional_edges(
33
- "check_hallucination",
34
- map_fix_hallucinations,
35
- ["fix_hallucination", "collect_summaries"],
36
- )
37
- graph.add_conditional_edges(
38
- "fix_hallucination",
39
- map_hallucinations,
40
- ["check_hallucination", "collect_summaries"],
41
- )
42
- graph.add_conditional_edges(
43
- "collect_summaries",
44
- all_summaries,
45
- {True: "generate_final_summary", False: "collect_summaries"},
46
- )
47
- graph.add_edge("generate_final_summary", END)
48
 
49
  return graph.compile()
50
 
51
 
52
- print(create_graph().get_graph().draw_ascii())
 
 
1
+ from langgraph.constants import START, Send
2
  from langgraph.graph import END, StateGraph
3
 
4
  from planning_ai.nodes.hallucination_node import (
 
8
  map_hallucinations,
9
  )
10
  from planning_ai.nodes.map_node import (
 
11
  collect_summaries,
12
  generate_summary,
13
  map_summaries,
14
  )
15
  from planning_ai.nodes.reduce_node import generate_final_summary
16
+ from planning_ai.states import DocumentState, OverallState
17
 
18
 
19
  def create_graph():
20
+ subgraph = StateGraph(DocumentState)
21
+ subgraph.add_node("generate_summary", generate_summary)
22
+ # subgraph.add_node("check_hallucination", check_hallucination)
23
+ # subgraph.add_node("fix_hallucination", fix_hallucination)
24
+
25
+ subgraph.add_conditional_edges(START, map_summaries, ["generate_summary"])
26
+ # subgraph.add_conditional_edges(
27
+ # "generate_summary", map_hallucinations, ["check_hallucination"]
28
+ # )
29
+ # subgraph.add_conditional_edges(
30
+ # "check_hallucination", map_fix_hallucinations, ["fix_hallucination"]
31
+ # )
32
+ # subgraph.add_conditional_edges(
33
+ # "fix_hallucination", map_hallucinations, ["check_hallucination"]
34
+ # )
35
+ subgraph = subgraph.compile()
36
+
37
  graph = StateGraph(OverallState)
38
+ graph.add_node("summary_graph", subgraph)
39
+ # graph.add_node("collect_summaries", collect_summaries)
40
+
41
+ graph.add_edge(START, "summary_graph")
42
+ # graph.add_conditional_edges(
43
+ # "summary_graph", map_hallucinations, ["collect_summaries"]
44
+ # )
45
+ # graph.add_edge("generate_final_summary", END)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  return graph.compile()
48
 
49
 
50
+ # graph = create_graph()
51
+ # graph.get_graph().draw_png("test.png")