import json from typing import List from langchain.pydantic_v1 import BaseModel, Field from langchain.schema import BaseRetriever, Document from langchain.tools import Tool from backend.chat_bot.json_decoder import CustomJSONEncoder class RetrieverInput(BaseModel): query: str = Field(description="query to look up in retriever") def create_retriever_tool( retriever: BaseRetriever, tool_name: str, description: str ) -> Tool: """Create a tool to do retrieval of documents. Args: retriever: The retriever to use for the retrieval tool_name: The name for the tool. This will be passed to the language model, so should be unique and somewhat descriptive. description: The description for the tool. This will be passed to the language model, so should be descriptive. Returns: Tool class to pass to an agent """ def wrap(func): def wrapped_retrieve(*args, **kwargs): docs: List[Document] = func(*args, **kwargs) return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder) return wrapped_retrieve return Tool( name=tool_name, description=description, func=wrap(retriever.get_relevant_documents), coroutine=retriever.aget_relevant_documents, args_schema=RetrieverInput, )