import logging

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_core.messages import HumanMessage

logger = logging.getLogger(__name__)


def strip_prompt(info):
    """Extract the text that follows the last ``"[/INST] "`` marker.

    Args:
        info: An object exposing a string ``content`` attribute (the output
            of the ``llm`` step in the supervisor chain).

    Returns:
        ``{"next": <text after the last marker>}`` when the marker is present.
        When the marker is absent, ``info`` itself is returned unchanged.

    NOTE(review): the two return shapes differ (mapping vs. the original
    object); this is preserved for backward compatibility, so consumers must
    handle both.
    """
    # Debug-level logging instead of an unconditional print to stdout.
    logger.debug("supervisor raw output: %s", info)

    eot_token = "[/INST] "
    # rpartition splits on the LAST occurrence; `found` is empty when the
    # marker does not occur at all.
    _, found, tail = info.content.rpartition(eot_token)
    if not found:
        return info
    # The incoming message is left untouched; only the extracted tail is
    # handed on.
    return {"next": tail}


class Supervisor:
    """Build and run a prompt chain that asks an LLM which worker acts next.

    The prompt lists the worker names and asks the model to pick one of
    ``["FINISH"] + members``. The chain is ``prompt | llm | strip_prompt``.
    """

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    def __init__(self, llm, members):
        """Create a supervisor for ``members`` backed by ``llm``.

        Args:
            llm: The runnable placed between the prompt and ``strip_prompt``
                in the chain; its output must expose ``.content``.
            members: Iterable of worker names (strings).
        """
        self._llm = llm
        # Per-instance copy: the roster must not be shared between
        # Supervisor instances or aliased to the caller's list.
        self.members = list(members)
        self._build_chain()

    def _build_chain(self):
        """(Re)create ``self.prompt`` and ``self.chain`` from the current roster."""
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("human", self.system_prompt),
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(
            options=str(self.get_options()),
            members=", ".join(self.members),
        )
        self.chain = self.prompt | self._llm | RunnableLambda(strip_prompt)

    def add_member(self, member):
        """Append ``member`` to the roster and rebuild the prompt and chain.

        The member and option lists are baked into the prompt when it is
        built, so the chain is rebuilt here; a chain previously obtained from
        ``get_chain()`` keeps the old roster.
        """
        self.members.append(member)
        self._build_chain()

    def get_members(self):
        """Return the list of worker names (the live internal list)."""
        return self.members

    def get_options(self):
        """Return a new list: ``"FINISH"`` followed by every worker name."""
        return ["FINISH"] + self.members

    def get_chain(self):
        """Return the current ``prompt | llm | strip_prompt`` runnable."""
        return self.chain

    def invoke(self, query):
        """Run the chain on a single human message and return its result.

        Args:
            query: Text of the user request.

        Returns:
            Whatever ``strip_prompt`` returns for the model output.
        """
        # Pass the placeholder variable explicitly rather than relying on
        # single-variable input coercion.
        return self.chain.invoke({"messages": [HumanMessage(content=query)]})