DrishtiSharma committed on
Commit
d040e4c
·
verified ·
1 Parent(s): c9bbc9b

Create graph4.py

Browse files
Files changed (1) hide show
  1. lab/graph4.py +162 -0
lab/graph4.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __import__('pysqlite3') # Workaround for sqlite3 error on live Streamlit.
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') # Workaround for sqlite3 error on live Streamlit.
4
+ import graphviz
5
+ import traceback
6
+ import tempfile
7
+ from langgraph.graph import StateGraph, END
8
+ from langchain_openai import ChatOpenAI
9
+ from pydantic import BaseModel, Field
10
+ from typing import TypedDict, List, Literal, Dict, Any
11
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain.memory import ConversationBufferMemory
14
+ from pdf_writer import generate_pdf
15
+ from crew import CrewClass, Essay
16
+
17
+
18
class GraphState(TypedDict, total=False):
    """Shared state flowing through the LangGraph workflow.

    ``total=False`` because each node returns only the keys it updates
    (e.g. ``answer`` returns just ``response``; ``write_essay`` returns
    ``response``, ``essay`` and ``pdf_name``) rather than the full state.
    This is a typing-only change; runtime behavior is identical.
    """

    topic: str            # latest user input / request routed by router_query
    response: str         # assistant reply surfaced to the UI
    documents: List[str]  # supporting documents (declared but unused by current nodes)
    essay: Dict[str, Any] # structured essay produced by the crew / editor
    pdf_name: str         # filename of the PDF generated from the essay
24
+
25
+
26
class RouteQuery(BaseModel):
    """Route a user query to direct answer or research."""

    # NOTE(review): the class docstring above and the Field description below
    # are serialized into the structured-output JSON schema sent to the LLM,
    # so they are runtime behavior — do not edit them casually.
    way: Literal["edit_essay", "write_essay", "answer"] = Field(
        ...,  # required: the router must always choose exactly one branch
        description="Given a user question, choose to route it to write_essay, edit_essay, or answer",
    )
33
+
34
+
35
class EssayWriter:
    """Conversational essay assistant built on LangGraph.

    Each user message is routed by an LLM to one of three terminal nodes:
    ``answer`` (direct chat reply), ``write_essay`` (full essay produced by
    the CrewAI crew), or ``edit_essay`` (LLM edit of the last essay). Essay
    nodes also render a PDF via ``generate_pdf`` for download.
    """

    def __init__(self, model_name: str = "gpt-4o-mini-2024-07-18"):
        """Set up models, memory, prompts, and compile the routing graph.

        Args:
            model_name: OpenAI chat model to use for both the router/answer
                model (temperature 0, deterministic) and the essay crew
                (temperature 0.5, mildly creative). Defaults to the model
                that was previously hard-coded, so existing callers are
                unaffected.
        """
        self.model = ChatOpenAI(model=model_name, temperature=0)
        self.crew = CrewClass(llm=ChatOpenAI(model=model_name, temperature=0.5))

        self.memory = ConversationBufferMemory()
        self.essay = {}  # last essay produced by write_essay / edit_essay
        self.router_prompt = """
        You are a router, and your duty is to route the user to the correct expert.
        Always check conversation history and consider your move based on it.
        If the topic is something about memory or daily talk, route the user to the answer expert.
        If the topic starts with something like "Can you write" or the user requests an article or essay, route the user to the write_essay expert.
        If the topic is about editing an essay, route the user to the edit_essay expert.

        \nConversation History: {memory}
        \nTopic: {topic}
        """

        self.simple_answer_prompt = """
        You are an expert, and you are providing a simple answer to the user's question.

        \nConversation History: {memory}
        \nTopic: {topic}
        """

        # Wire the three expert nodes into a graph whose entry point is
        # chosen per-request by router_query.
        builder = StateGraph(GraphState)

        builder.add_node("answer", self.answer)
        builder.add_node("write_essay", self.write_essay)
        builder.add_node("edit_essay", self.edit_essay)

        builder.set_conditional_entry_point(self.router_query, {
            "write_essay": "write_essay",
            "answer": "answer",
            "edit_essay": "edit_essay",
        })

        # Every expert is terminal: each invocation traverses exactly one node.
        builder.add_edge("write_essay", END)
        builder.add_edge("edit_essay", END)
        builder.add_edge("answer", END)

        self.graph = builder.compile()
        self.save_workflow_graph()  # best-effort PNG snapshot of the graph

    def router_query(self, state: GraphState) -> str:
        """Choose the entry node for this request.

        Returns one of ``"write_essay"``, ``"edit_essay"``, ``"answer"``,
        matching the keys of the conditional entry-point mapping.
        """
        print("**ROUTER**")
        prompt = PromptTemplate.from_template(self.router_prompt)
        memory = self.memory.load_memory_variables({})

        # Structured output constrains the LLM so result.way is always one
        # of the three allowed literals.
        router_query = self.model.with_structured_output(RouteQuery)
        chain = prompt | router_query
        result: RouteQuery = chain.invoke({"topic": state["topic"], "memory": memory})

        print("Router Result: ", result.way)
        return result.way

    def answer(self, state: GraphState) -> Dict[str, Any]:
        """Produce a direct chat answer and record the turn in memory."""
        print("**ANSWER**")
        prompt = PromptTemplate.from_template(self.simple_answer_prompt)
        memory = self.memory.load_memory_variables({})
        chain = prompt | self.model | StrOutputParser()
        result = chain.invoke({"topic": state["topic"], "memory": memory})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": result})
        return {"response": result}

    def write_essay(self, state: GraphState) -> Dict[str, Any]:
        """Write a fresh essay with the crew, save it, and render a PDF."""
        print("**ESSAY COMPLETION**")
        # Generate the essay using the crew
        self.essay = self.crew.kickoff({"topic": state["topic"]})
        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }

    def edit_essay(self, state: GraphState) -> Dict[str, Any]:
        """Apply the user's requested edits to the last essay and re-render the PDF.

        NOTE(review): if no essay has been written yet, ``self.essay`` is an
        empty dict and the LLM is asked to edit an empty JSON file — the
        router prompt is expected to prevent reaching this node in that case.
        """
        print("**ESSAY EDIT**")
        memory = self.memory.load_memory_variables({})

        user_request = state["topic"]
        # Parse the model output back into the Essay schema as a dict.
        parser = JsonOutputParser(pydantic_object=Essay)
        prompt = PromptTemplate(
            template=(
                "Edit the JSON file as the user requested, and return the new JSON file."
                "\n Request: {user_request} "
                "\n Conversation History: {memory}"
                "\n JSON File: {essay}"
                " \n{format_instructions}"
            ),
            input_variables=["memory", "user_request", "essay"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )

        chain = prompt | self.model | parser

        # Update the essay with the edits
        self.essay = chain.invoke({"user_request": user_request, "memory": memory, "essay": self.essay})

        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})

        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your edited essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }

    def save_workflow_graph(self) -> None:
        """Generate and save a dynamic LangGraph visualization to a fixed location."""
        try:
            graph_path = "/tmp/graph.png"

            # Generate the mermaid diagram and save it to a fixed file
            with open(graph_path, "wb") as f:
                f.write(self.graph.get_graph().draw_mermaid_png())

            print(f"✅ Workflow visualization saved at: {graph_path}")

        # Broad catch is deliberate: visualization is best-effort and must
        # never break construction of the writer itself.
        except Exception as e:
            print(f"❌ Error generating graph: {e}")