solalatus commited on
Commit
ae72833
·
1 Parent(s): 79f7395

better tutoring flow

Browse files
Files changed (1) hide show
  1. agent.py +72 -1
agent.py CHANGED
@@ -17,6 +17,11 @@ from langchain.memory import ConversationBufferWindowMemory
17
  import random
18
  from pydantic import Extra
19
 
 
 
 
 
 
20
 
21
  import promptlayer
22
  from langchain.callbacks import PromptLayerCallbackHandler
@@ -67,9 +72,67 @@ class GetrandomTool(BaseTool):
67
 
68
  return text
69
 
70
- def _arun(self, value: Union[int, float]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  raise NotImplementedError("This tool does not support async")
72
 
 
 
73
  class QMLAgent():
74
 
75
  def __init__(self):
@@ -98,6 +161,8 @@ class QMLAgent():
98
  random_tool.indexer = index
99
  random_tool.index_max = index.describe_index_stats()["total_vector_count"]
100
 
 
 
101
 
102
  tools = [
103
  Tool(
@@ -119,6 +184,12 @@ class QMLAgent():
119
  #return_direct=False
120
 
121
  ),
 
 
 
 
 
 
122
  ]
123
 
124
  memory = ConversationBufferWindowMemory(k=os.environ["MEMORY_LENGTH"], memory_key="chat_history", return_messages=True)
 
17
  import random
18
  from pydantic import Extra
19
 
20
+ from langchain.chains import LLMChain
21
+ from langchain import PromptTemplate
22
+
23
+
24
+
25
 
26
  import promptlayer
27
  from langchain.callbacks import PromptLayerCallbackHandler
 
72
 
73
  return text
74
 
75
+ def _arun(self, question: str):
76
+ raise NotImplementedError("This tool does not support async")
77
+
78
+
79
+ #Tutor tool?
80
+ #A chain with retrieval for answering, and a constant input summary of the
81
+ #tutoring flow so far
82
+
83
+ class TutoringTool(BaseTool):
84
+ name = "TutoringTool"
85
+
86
+ description = """This tool is capable of generating tutoring questions.
87
+ It has to be called with a summary of the previous tutoring discussion steps,
88
+ or in case of a new tutoring session, with a randomly chosen piece of material.
89
+ As for it's output, it has to be kept at it is, sent bakc to the user."""
90
+
91
+ class Config:
92
+ extra = Extra.allow
93
+
94
+ def _run(self, question: str):
95
+ # initiate chain and prompts
96
+ # Would be waaay more elegant to have it at init time, but...
97
+ prompt_template = """
98
+ You act as a knowledgeable tutor. Based on some previous [apropos]
99
+ (a question, a piace of material, or a summary of a tutoring session)
100
+ and some [relevant documents] generate a tutoring question,
101
+ that helps in systematically think about the topic at hand and
102
+ for which the answer is deepening the knowledge of the subject, getting closser to an aswer.
103
+ You should NOT answer the question at hand, just either ask a helping question
104
+ or confirm if an aswer is correct! This should be [your output].
105
+
106
+ [apropos]
107
+ {apropos}
108
+
109
+ [relevant_documents]
110
+ {relevant_documents}
111
+
112
+ [your output]:
113
+ """
114
+
115
+ prompt = PromptTemplate(
116
+ input_variables=["apropos", "relevant_documents"],
117
+ template=prompt_template,
118
+ )
119
+
120
+ # do retrieval
121
+ relevant_documents = self.retriever.get_relevant_documents(question)
122
+
123
+ # concat the two
124
+ # execute a chain
125
+ llm = ChatOpenAI(model_name=os.environ["CHAT_MODEL"])
126
+ chain = LLMChain(llm=llm, prompt=prompt)
127
+
128
+ result = chain.run(apropos=question, relevant_documents=relevant_documents)
129
+ return result
130
+
131
+ def _arun(self, question: str):
132
  raise NotImplementedError("This tool does not support async")
133
 
134
+
135
+
136
  class QMLAgent():
137
 
138
  def __init__(self):
 
161
  random_tool.indexer = index
162
  random_tool.index_max = index.describe_index_stats()["total_vector_count"]
163
 
164
+ tutoring_tool = TutoringTool()
165
+ tutoring_tool.retriever = retriever
166
 
167
  tools = [
168
  Tool(
 
184
  #return_direct=False
185
 
186
  ),
187
+ Tool.from_function(
188
+ name="Tutoring",
189
+ func=tutoring_tool._run,
190
+ description=tutoring_tool.description,
191
+ return_direct=True
192
+ ),
193
  ]
194
 
195
  memory = ConversationBufferWindowMemory(k=os.environ["MEMORY_LENGTH"], memory_key="chat_history", return_messages=True)