Spaces:
Running
Running
File size: 1,376 Bytes
e931b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import functools
import json
from typing import List

from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever, Document
from langchain.tools import Tool

from backend.chat_bot.json_decoder import CustomJSONEncoder
class RetrieverInput(BaseModel):
    """Argument schema for the retriever tool: a single required query string."""

    # Required field (explicit `...`); the description is surfaced to the LLM.
    query: str = Field(..., description="query to look up in retriever")
def create_retriever_tool(
    retriever: BaseRetriever,
    tool_name: str,
    description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    The sync (``func``) and async (``coroutine``) entry points both return the
    same thing: a JSON string holding the list of retrieved documents, each
    rendered with ``Document.dict()`` and encoded with ``CustomJSONEncoder``.

    Args:
        retriever: The retriever to use for the retrieval.
        tool_name: The name for the tool. This will be passed to the language
            model, so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the
            language model, so should be descriptive.

    Returns:
        Tool class to pass to an agent.
    """
    get_docs = retriever.get_relevant_documents
    aget_docs = retriever.aget_relevant_documents

    def _serialize(docs: List[Document]) -> str:
        # Single encoding path shared by the sync and async wrappers so both
        # hand the agent identical output.
        return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

    # functools.wraps sets __wrapped__, so inspect.signature() reports the
    # retriever method's real parameters (including `callbacks`). LangChain's
    # Tool inspects the signature to decide whether to forward callbacks.
    @functools.wraps(get_docs)
    def wrapped_retrieve(*args, **kwargs) -> str:
        return _serialize(get_docs(*args, **kwargs))

    @functools.wraps(aget_docs)
    async def awrapped_retrieve(*args, **kwargs) -> str:
        # Previously the raw coroutine was passed through, so async callers
        # received List[Document] instead of the JSON string.
        return _serialize(await aget_docs(*args, **kwargs))

    return Tool(
        name=tool_name,
        description=description,
        func=wrapped_retrieve,
        coroutine=awrapped_retrieve,
        args_schema=RetrieverInput,
    )
|