DrishtiSharma committed on
Commit
45f45d3
·
verified ·
1 Parent(s): 279951c

Create graph.py

Browse files
Files changed (1) hide show
  1. lab/graph.py +146 -0
lab/graph.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __import__('pysqlite3') # Workaround for sqlite3 error on live Streamlit.
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') # Workaround for sqlite3 error on live Streamlit.
4
+
5
+ from langgraph.graph import StateGraph, END
6
+ from langchain_openai import ChatOpenAI
7
+ from pydantic import BaseModel, Field
8
+ from typing import TypedDict, List, Literal, Dict, Any
9
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.memory import ConversationBufferMemory
12
+ from pdf_writer import generate_pdf
13
+
14
+ from crew import CrewClass, Essay
15
+
16
+
17
+ class GraphState(TypedDict):
18
+ topic: str
19
+ response: str
20
+ documents: List[str]
21
+ essay: Dict[str, Any]
22
+ pdf_name: str
23
+
24
+
25
+ class RouteQuery(BaseModel):
26
+ """Route a user query to direct answer or research."""
27
+
28
+ way: Literal["edit_essay", "write_essay", "answer"] = Field(
29
+ ...,
30
+ description="Given a user question, choose to route it to write_essay, edit_essay, or answer",
31
+ )
32
+
33
+
34
class EssayWriter:
    """Conversational essay writer built on a three-node LangGraph.

    Every incoming message is routed by an LLM to one of:
      * ``answer``      -- plain conversational reply,
      * ``write_essay`` -- generate a new essay via the CrewAI crew,
      * ``edit_essay``  -- edit the previously generated essay.

    Each node ends the graph run (all edges go to END) and records the
    exchange in a ``ConversationBufferMemory`` so the router can use history.
    """

    # Single source of truth for the model identifier so the router/answer
    # model and the crew's model cannot silently drift apart.
    MODEL_NAME = "gpt-4o-mini-2024-07-18"

    def __init__(self):
        self.model = ChatOpenAI(model=self.MODEL_NAME, temperature=0)
        self.crew = CrewClass(llm=ChatOpenAI(model=self.MODEL_NAME, temperature=0.5))

        self.memory = ConversationBufferMemory()
        self.essay = {}  # last generated/edited essay; starts empty until write_essay runs
        self.router_prompt = """
        You are a router, and your duty is to route the user to the correct expert.
        Always check conversation history and consider your move based on it.
        If the topic is something about memory or daily talk, route the user to the answer expert.
        If the topic starts with something like "Can you write" or the user requests an article or essay, route the user to the write_essay expert.
        If the topic is about editing an essay, route the user to the edit_essay expert.

        \nConversation History: {memory}
        \nTopic: {topic}
        """

        self.simple_answer_prompt = """
        You are an expert, and you are providing a simple answer to the user's question.

        \nConversation History: {memory}
        \nTopic: {topic}
        """

        builder = StateGraph(GraphState)

        builder.add_node("answer", self.answer)
        builder.add_node("write_essay", self.write_essay)
        builder.add_node("edit_essay", self.edit_essay)

        # The router's return value (a node name) selects the entry node.
        builder.set_conditional_entry_point(self.router_query, {
            "write_essay": "write_essay",
            "answer": "answer",
            "edit_essay": "edit_essay",
        })

        builder.add_edge("write_essay", END)
        builder.add_edge("edit_essay", END)
        builder.add_edge("answer", END)

        self.graph = builder.compile()
        # Diagram export contacts the remote mermaid.ink service by default;
        # a network hiccup here must not prevent the app from starting.
        try:
            self.graph.get_graph().draw_mermaid_png(output_file_path="graph.png")
        except Exception as exc:
            print(f"Skipping graph diagram export: {exc}")

    def router_query(self, state: GraphState) -> str:
        """Classify the topic and return the name of the node to run next.

        Uses structured output (``RouteQuery``) so the LLM can only answer
        with one of the three wired node names.
        """
        print("**ROUTER**")
        prompt = PromptTemplate.from_template(self.router_prompt)
        memory = self.memory.load_memory_variables({})

        router_query = self.model.with_structured_output(RouteQuery)
        chain = prompt | router_query
        result: RouteQuery = chain.invoke({"topic": state["topic"], "memory": memory})

        print("Router Result: ", result.way)
        return result.way

    def answer(self, state: GraphState) -> Dict[str, Any]:
        """Produce a simple conversational reply and record it in memory."""
        print("**ANSWER**")
        prompt = PromptTemplate.from_template(self.simple_answer_prompt)
        memory = self.memory.load_memory_variables({})
        chain = prompt | self.model | StrOutputParser()
        result = chain.invoke({"topic": state["topic"], "memory": memory})

        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": result})
        return {"response": result}

    def write_essay(self, state: GraphState) -> Dict[str, Any]:
        """Generate a new essay with the crew, persist it, and render a PDF."""
        print("**ESSAY COMPLETION**")
        # Generate the essay using the crew
        self.essay = self.crew.kickoff({"topic": state["topic"]})
        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})
        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }

    def edit_essay(self, state: GraphState) -> Dict[str, Any]:
        """Apply the user's requested edits to the stored essay and re-render the PDF.

        NOTE(review): if no essay has been written yet, ``self.essay`` is still
        ``{}`` and the LLM is asked to edit an empty JSON object — confirm the
        router cannot reach this node before ``write_essay`` has run.
        """
        print("**ESSAY EDIT**")
        memory = self.memory.load_memory_variables({})

        user_request = state["topic"]
        # Parse the LLM output back into the Essay schema as a plain dict.
        parser = JsonOutputParser(pydantic_object=Essay)
        prompt = PromptTemplate(
            template=(
                "Edit the JSON file as the user requested, and return the new JSON file."
                "\n Request: {user_request} "
                "\n Conversation History: {memory}"
                "\n JSON File: {essay}"
                " \n{format_instructions}"
            ),
            input_variables=["memory", "user_request", "essay"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )

        chain = prompt | self.model | parser

        # Update the essay with the edits
        self.essay = chain.invoke({"user_request": user_request, "memory": memory, "essay": self.essay})

        # Save the conversation context
        self.memory.save_context(inputs={"input": state["topic"]}, outputs={"output": str(self.essay)})

        # Generate the PDF and return essay content for preview
        pdf_name = generate_pdf(self.essay)
        return {
            "response": "Here is your edited essay! You can review it below before downloading.",
            "essay": self.essay,
            "pdf_name": pdf_name,
        }